77
88# from ..models import MODELS
99from models import MODELS
10- from utils import METRICS
10+ from utils import METRICS , DATALOADERS
1111
1212parser = argparse .ArgumentParser ("Benchmarks for OpenCV Zoo." )
1313parser .add_argument ('--cfg' , '-c' , type = str ,
1414 help = 'Benchmarking on the given config.' )
1515args = parser .parse_args ()
1616
17- def build_from_cfg (cfg , registery , key = 'name' ):
18- obj_name = cfg .pop (key )
19- obj = registery .get (obj_name )
20- return obj (** cfg )
17+ def build_from_cfg (cfg , registery , key = None , name = None ):
18+ if key is not None :
19+ obj_name = cfg .pop (key )
20+ obj = registery .get (obj_name )
21+ return obj (** cfg )
22+ elif name is not None :
23+ obj = registery .get (name )
24+ return obj (** cfg )
25+ else :
26+ raise NotImplementedError ()
2127
2228def prepend_pythonpath (cfg ):
2329 for k , v in cfg .items ():
@@ -27,62 +33,26 @@ def prepend_pythonpath(cfg):
2733 if 'path' in k .lower ():
2834 cfg [k ] = os .path .join (os .environ ['PYTHONPATH' ], v )
2935
30- class Data :
31- def __init__ (self , ** kwargs ):
32- self ._path = kwargs .pop ('path' , None )
33- assert self ._path , 'Benchmark[\' data\' ][\' path\' ] cannot be empty.'
34-
35- self ._files = kwargs .pop ('files' , None )
36- if not self ._files :
37- print ('Benchmark[\' data\' ][\' files\' ] is empty, loading all images by default.' )
38- self ._files = list ()
39- for filename in os .listdir (self ._path ):
40- if filename .endswith ('jpg' ) or filename .endswith ('png' ):
41- self ._files .append (filename )
42-
43- self ._use_label = kwargs .pop ('useLabel' , False )
44- if self ._use_label :
45- self ._labels = self ._load_label ()
46-
47- self ._to_rgb = kwargs .pop ('toRGB' , False )
48- self ._resize = tuple (kwargs .pop ('resize' , []))
49- self ._center_crop = kwargs .pop ('centerCrop' , None )
50-
51- def _load_label (self ):
52- labels = dict .fromkeys (self ._files , None )
53- for filename in self ._files :
54- labels [filename ] = np .loadtxt (os .path .join (self ._path , '{}.txt' .format (filename [:- 4 ])), ndmin = 2 )
55- return labels
56-
57- def __getitem__ (self , idx ):
58- image = cv .imread (os .path .join (self ._path , self ._files [idx ]))
59-
60- if self ._to_rgb :
61- image = cv .cvtColor (image , cv .COLOR_BGR2RGB )
62- if self ._resize :
63- image = cv .resize (image , self ._resize )
64- if self ._center_crop :
65- h , w , _ = image .shape
66- w_crop = int ((w - self ._center_crop ) / 2. )
67- assert w_crop >= 0
68- h_crop = int ((h - self ._center_crop ) / 2. )
69- assert h_crop >= 0
70- image = image [w_crop :w - w_crop , h_crop :h - h_crop , :]
71-
72- if self ._use_label :
73- return self ._files [idx ], image , self ._labels [self ._files [idx ]]
74- else :
75- return self ._files [idx ], image
76-
7736class Benchmark :
7837 def __init__ (self , ** kwargs ):
38+ self ._type = kwargs .pop ('type' , None )
39+ if self ._type is None :
40+ self ._type = 'Base'
41+ print ('Benchmark[\' type\' ] is omitted, set to \' Base\' by default.' )
42+
7943 self ._data_dict = kwargs .pop ('data' , None )
8044 assert self ._data_dict , 'Benchmark[\' data\' ] cannot be empty and must have path and files.'
81- self ._data = Data (** self ._data_dict )
45+ if 'type' in self ._data_dict :
46+ self ._dataloader = build_from_cfg (self ._data_dict , registery = DATALOADERS , key = 'type' )
47+ else :
48+ self ._dataloader = build_from_cfg (self ._data_dict , registery = DATALOADERS , name = self ._type )
8249
8350 self ._metric_dict = kwargs .pop ('metric' , None )
84- # self._metric = Metric(**self._metric_dict)
85- self ._metric = build_from_cfg (self ._metric_dict , registery = METRICS , key = 'type' )
51+ assert self ._metric_dict , 'Benchmark[\' metric\' ] cannot be empty.'
52+ if 'type' in self ._metric_dict :
53+ self ._metric = build_from_cfg (self ._metric_dict , registery = METRICS , key = 'type' )
54+ else :
55+ self ._metric = build_from_cfg (self ._metric_dict , registery = METRICS , name = self ._type )
8656
8757 backend_id = kwargs .pop ('backend' , 'default' )
8858 available_backends = dict (
@@ -115,8 +85,15 @@ def run(self, model):
11585 model .setBackend (self ._backend )
11686 model .setTarget (self ._target )
11787
118- for data in self ._data :
119- self ._benchmark_results [data [0 ]] = self ._metric .forward (model , * data [1 :])
88+ if 'video' in self ._dataloader .name .lower ():
89+ model .init (self ._dataloader .getROI ())
90+
91+ for data in self ._dataloader :
92+ filename , img = data [:2 ]
93+ size = [img .shape [1 ], img .shape [0 ]]
94+ if filename not in self ._benchmark_results :
95+ self ._benchmark_results [filename ] = dict ()
96+ self ._benchmark_results [filename ][str (size )] = self ._metric .forward (model , * data [1 :])
12097
12198 def printResults (self ):
12299 for imgName , results in self ._benchmark_results .items ():
@@ -138,7 +115,7 @@ def printResults(self):
138115 benchmark = Benchmark (** cfg ['Benchmark' ])
139116
140117 # Instantiate model
141- model = build_from_cfg (cfg = cfg ['Model' ], registery = MODELS )
118+ model = build_from_cfg (cfg = cfg ['Model' ], registery = MODELS , key = 'name' )
142119
143120 # Run benchmarking
144121 print ('Benchmarking {}:' .format (model .name ))
0 commit comments