77#
88# Copyright Ⓒ 2023 Mukai (Tom Notch) Yu, Yao He
99#
10+ import csv
11+ import json
1012import os
13+ import warnings
1114
1215import cv2
16+ import magic
1317import numpy as np
18+ import toml
19+ import yaml
1420
1521# import torch
1622# import torch_tensorrt
1723
1824
19- # loaded_models = [] # list of loaded models for one total program run
25+ mime = magic . Magic ( mime = True , uncompress = True )
2026
27+ file_cache = {} # cache for files that are read
2128
22- # def load_model(model_path: str) -> torch.nn.Module:
23- # """ensure that only one instance of the model is loaded, later loadings will point to the same model instance loaded before
2429
25- # Args :
26- # model_path (str): path to the model
30+ def opencv_matrix_constructor ( loader , node ) :
31+ """Custom constructor for !!opencv-matrix tag."""
2732
28- # Returns:
29- # torch.nn.Module: model
30- # """
31- # model = torch.jit.load(model_path)
33+ # Parse the node as a dictionary
34+ matrix_data = loader .construct_mapping (node , deep = True )
3235
33- # for loaded_model in loaded_models:
34- # if (
35- # model.state_dict() == loaded_model.state_dict()
36- # ): # if the model is already loaded
37- # del model # delete the duplicate model
38- # return loaded_model # return the loaded model
36+ # Extract rows, cols, dt, and data
37+ rows = matrix_data ["rows" ]
38+ cols = matrix_data ["cols" ]
39+ dt = matrix_data ["dt" ]
40+ data = matrix_data ["data" ]
3941
40- # loaded_models.append(model) # model's a new model, add it to the list
41- # return model
42+ # Map OpenCV data types to NumPy data types
43+ dtype_map = {"u" : np .uint8 , "i" : np .int32 , "f" : np .float32 , "d" : np .float64 }
44+
45+ # Determine the NumPy data type
46+ dtype = dtype_map .get (dt , np .float64 )
47+
48+ # Convert data to a NumPy array and reshape
49+ matrix = np .array (data , dtype = dtype ).reshape ((rows , cols ))
50+
51+ return matrix
52+
53+
54+ yaml .add_constructor ("tag:yaml.org,2002:opencv-matrix" , opencv_matrix_constructor )
4255
4356
4457def print_dict (d : dict , indent : int = 0 ) -> None :
@@ -77,53 +90,41 @@ def parse_path(probe_path: str, base_path: str = None):
7790 if base_path is None :
7891 base_path = os .getcwd ()
7992 if os .path .isabs (expand_path ) and os .path .exists (expand_path ):
80- return expand_path
93+ return os . path . realpath ( expand_path )
8194 elif os .path .exists (os .path .join (base_path , probe_path )):
82- return os .path .join (base_path , probe_path )
95+ return os .path .realpath ( os . path . join (base_path , probe_path ) )
8396 else :
8497 return False
8598
8699
87- def get_item (node : cv2 . FileNode , yaml_base_path : str ):
88- """get an item from a cv2.FileNode, recursively parse the item if it is a Map, List or path
100+ def parse_content (node , yaml_base_path : str ):
101+ """recursively look into the leaf node of a yaml file, if the leaf node is a string, try to parse it as a path and read the file, could be an image or nested yaml config
89102
90103 Args:
91- node (cv2.FileNode) : file node to be parsed
104+ node: file node to be parsed
92105 yaml_base_path (str): the base path of the current yaml file
93106
94107 Returns:
95108 the read content
96109 """
97- if node .isNone (): # empty
98- return None
99- elif node .isMap (): # dict
100- keys = node .keys ()
101- if all (mat_key in keys for mat_key in ["rows" , "cols" , "dt" , "data" ]): # matrix
102- return node .mat ()
103- else : # key-value pairs
104- dict = {}
105- for key in keys :
106- dict [key ] = get_item (node .getNode (key ), yaml_base_path )
107- return dict
108- elif node .isSeq (): # list
109- list = []
110- for i in range (node .size ()):
111- list .append (get_item (node .at (i ), yaml_base_path ))
112- return list
113- elif node .isReal () or node .isInt (): # number
114- return node .real ()
115- elif node .isString (): # string
116- path = parse_path (
117- node .string (), yaml_base_path
118- ) # try parsing the string as path
110+
111+ if isinstance (node , dict ):
112+ for key , value in node .items ():
113+ node [key ] = parse_content (value , yaml_base_path )
114+ elif isinstance (node , list ):
115+ for index , value in enumerate (node ):
116+ node [index ] = parse_content (value , yaml_base_path )
117+ elif isinstance (node , str ):
118+ path = parse_path (node , yaml_base_path )
119119 if path :
120- return read_file (path )
121- else : # not a path
122- return node . string ()
120+ node = read_file (path )
121+
122+ return node
123123
124124
125125def read_file (path : str ):
126- """test the path, read a file, if it is a yaml file; parse it, if it is an image, read it; if it is a torchscript module, return the path; otherwise raise exception for file not supported.
126+ """test the path, read a file.
127+ supports multiple file type and interleaving config file types.
127128
128129 Args:
129130 path (str): path to the file, can be absolute or relative
@@ -133,12 +134,44 @@ def read_file(path: str):
133134 """
134135 if not os .path .exists (path ):
135136 raise FileNotFoundError (f"File { path } does not exist" )
136- if path .endswith (".yaml" ) or path .endswith (".yml" ):
137- yaml_file = cv2 .FileStorage (path , cv2 .FILE_STORAGE_READ )
138- return get_item (yaml_file .root (), os .path .dirname (path ))
139- elif path .endswith (".png" ) or path .endswith (".jpg" ) or path .endswith (".jpeg" ):
140- return cv2 .imread (path )
141- # elif path.endswith(".ts") or path.endswith(".trt"):
142- # return load_model(path)
137+
138+ if path in file_cache .keys ():
139+ return file_cache [path ]
140+
141+ # determine file type with python-magic, useful for massive image types
142+ mime_type = mime .from_file (path )
143+
144+ file_base_path = os .path .dirname (path )
145+
146+ if mime_type in {
147+ "application/yaml" ,
148+ "application/x-yaml" ,
149+ "text/yaml" ,
150+ } or path .endswith ((".yaml" , ".yml" )):
151+ with open (path , "r" ) as f :
152+ parsed_yaml = yaml .load (f .read (), Loader = yaml .FullLoader )
153+ parsed_content = parse_content (parsed_yaml , file_base_path )
154+ elif mime_type .startswith ("image/" ):
155+ parsed_content = cv2 .imread (path )
156+ elif mime_type in {"text/csv" , "application/csv" } or path .endswith (".csv" ):
157+ with open (path , "r" ) as f :
158+ reader = csv .reader (f )
159+ parsed_content = list (reader )
160+ elif mime_type in {"application/json" , "text/json" } or path .endswith (".json" ):
161+ with open (path , "r" ) as f :
162+ parsed_content = parse_content (json .load (f ), file_base_path )
163+ elif mime_type in {"application/toml" , "text/toml" } or path .endswith (".toml" ):
164+ parsed_content = parse_content (toml .load (path ), file_base_path )
165+ elif path .endswith (".npy" ):
166+ try :
167+ parsed_content = np .load (path , allow_pickle = True )
168+ except Exception as e :
169+ raise RuntimeError (f"Error loading NumPy array from { path } : { e } " )
170+ # elif mime_type in {'application/x-torchscript', 'application/x-tensorrt'} or path.endswith((".ts", ".trt")):
171+ # parsed_content = torch.jit.load(model_path)
143172 else :
144- raise Exception (f"File { path } is not supported" )
173+ warnings .warn (f"File { path } is not supported, reading as string" )
174+ parsed_content = path
175+
176+ file_cache [path ] = parsed_content
177+ return parsed_content
0 commit comments