@@ -35,6 +35,7 @@ def str2bool(v):
3535 type = str , help = 'The output ONNX file, trained parameters inside' )
3636parser .add_argument ('--enable_dynamic_axes' , default = True ,
3737 type = str2bool , help = 'Enable dynamic axes for ONNX model.' )
38+ parser .add_argument ('--opset_version' , default = 11 , help = 'ONNX opset version to output.' )
3839args = parser .parse_args ()
3940
4041def check_keys (model , pretrained_state_dict ):
@@ -81,7 +82,7 @@ def load_model(model, pretrained_path, load_to_cpu):
8182 net .eval ()
8283
8384 print ('Finished loading model!' )
84-
85+
8586 img = torch .randn (1 , 3 , 480 , 640 , requires_grad = False )
8687 img = img .to (torch .device ('cpu' ))
8788
@@ -93,8 +94,8 @@ def load_model(model, pretrained_path, load_to_cpu):
9394 'conf' : {0 : 'batch_size' , 1 : 'num' , 2 : 'cls_data' },
9495 'iou' : {0 : 'batch_size' , 1 : 'num' , 2 : 'iou_data' }}
9596 output_path = os .path .join ('./onnx' , args .output_name + '.onnx' )
96- torch .onnx .export (net , img , output_path , input_names = input_names , output_names = output_names , dynamic_axes = dynamic_axes )
97+ torch .onnx .export (net , img , output_path , input_names = input_names , output_names = output_names , dynamic_axes = dynamic_axes , opset_version = args . opset_version )
9798 else :
9899 output_path = os .path .join ('./onnx' , args .output_name + '_' + str (args .image_dim ) + '.onnx' )
99- torch .onnx .export (net , img , output_path , input_names = input_names , output_names = output_names )
100- print ('Finished exporing model to ' + output_path )
100+ torch .onnx .export (net , img , output_path , input_names = input_names , output_names = output_names , opset_version = args . opset_version )
101+ print ('Finished exporing model to ' + output_path )
0 commit comments