Skip to content

Commit bc8cd70

Browse files
author
lipengbo
committed
简化onnx
2 parents 0786fb8 + da7313a commit bc8cd70

2 files changed

Lines changed: 18 additions & 18 deletions

File tree

Models/Backend/onnx.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@ def __init__(self):
1313
pass
1414

1515
@staticmethod
16-
def convert(model, imgs, weights, dynamic, simplify):
16+
def convert(model, imgs, weights, dynamic):
1717
"""
1818
torch模型转为onnx模型
1919
2020
model: torch模型
2121
imgs: [B,C,H,W]Tensor
2222
weights: onnx权重保存路径
2323
dynamic: batch轴是否设为动态维度
24-
simplify: 是否简化onnx
2524
"""
2625
torch.onnx.export(
2726
model,
@@ -37,17 +36,18 @@ def convert(model, imgs, weights, dynamic, simplify):
3736
)
3837
model_onnx = onnx.load(weights) # load onnx model
3938
onnx.checker.check_model(model_onnx) # check onnx model
40-
if simplify:
41-
try:
42-
model_onnx, check = onnxsim.simplify(
43-
model_onnx,
44-
dynamic_input_shape=dynamic,
45-
input_shapes={"input": list(imgs.shape)} if dynamic else None,
46-
)
47-
assert check, "assert check failed"
48-
onnx.save(model_onnx, weights)
49-
except Exception as e:
50-
print(f"simplifer failure: {e}")
39+
40+
try:
41+
# 简化onnx
42+
model_onnx, check = onnxsim.simplify(
43+
model_onnx,
44+
dynamic_input_shape=dynamic,
45+
input_shapes={"input": list(imgs.shape)} if dynamic else None,
46+
)
47+
assert check, "assert check failed"
48+
onnx.save(model_onnx, weights)
49+
except Exception as e:
50+
print(f"simplifer failure: {e}")
5151

5252
print("*" * 28)
5353
print("ONNX export success, saved as %s" % weights)

export.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
# onnx
1818
parser.add_argument("--torch2onnx", action="store_true", help="(可选)转为onnx")
19-
parser.add_argument("--simplify", action="store_true", help="(可选)简化onnx")
2019
parser.add_argument("--dynamic", action="store_true", help="(可选)batch轴设为动态")
2120

2221
# tensorrt
@@ -67,8 +66,7 @@
6766
model=model,
6867
imgs=imgs,
6968
weights=onnx_weights,
70-
dynamic=cfg.dynamic,
71-
simplify=cfg.simplify,
69+
dynamic=cfg.dynamic
7270
)
7371
output_onnx = OnnxBackend.infer(weights=onnx_weights, imgs=imgs.numpy())
7472

@@ -108,8 +106,10 @@
108106

109107
from Models.Backend.mnn import MNNBackbend
110108

111-
MNNBackbend.convert(onnx_weights,mnn_weights,fp16=cfg.mnn_fp16)
112-
output_mnn = MNNBackbend.infer(mnn_weights, imgs.numpy(),output_shape=output_onnx.shape)
109+
MNNBackbend.convert(onnx_weights, mnn_weights, fp16=cfg.mnn_fp16)
110+
output_mnn = MNNBackbend.infer(
111+
mnn_weights, imgs.numpy(), output_shape=output_onnx.shape
112+
)
113113
# ==========================验证结果===============================
114114
print("\n", "*" * 28)
115115
if cfg.torch2script:

0 commit comments

Comments
 (0)