@@ -13,15 +13,14 @@ def __init__(self):
1313 pass
1414
1515 @staticmethod
16- def convert (model , imgs , weights , dynamic , simplify ):
16+ def convert (model , imgs , weights , dynamic ):
1717 """
1818 torch模型转为onnx模型
1919
2020 model: torch模型
2121 imgs: [B,C,H,W]Tensor
2222 weights: onnx权重保存路径
2323 dynamic: batch轴是否设为动态维度
24- simplify: 是否简化onnx
2524 """
2625 torch .onnx .export (
2726 model ,
@@ -37,17 +36,18 @@ def convert(model, imgs, weights, dynamic, simplify):
3736 )
3837 model_onnx = onnx .load (weights ) # load onnx model
3938 onnx .checker .check_model (model_onnx ) # check onnx model
40- if simplify :
41- try :
42- model_onnx , check = onnxsim .simplify (
43- model_onnx ,
44- dynamic_input_shape = dynamic ,
45- input_shapes = {"input" : list (imgs .shape )} if dynamic else None ,
46- )
47- assert check , "assert check failed"
48- onnx .save (model_onnx , weights )
49- except Exception as e :
50- print (f"simplifer failure: { e } " )
39+
40+ try :
41+ # 简化onnx
42+ model_onnx , check = onnxsim .simplify (
43+ model_onnx ,
44+ dynamic_input_shape = dynamic ,
45+ input_shapes = {"input" : list (imgs .shape )} if dynamic else None ,
46+ )
47+ assert check , "assert check failed"
48+ onnx .save (model_onnx , weights )
49+ except Exception as e :
50+ print (f"simplifer failure: { e } " )
5151
5252 print ("*" * 28 )
5353 print ("ONNX export success, saved as %s" % weights )
0 commit comments