onnx模式释放模型降低占用
This commit is contained in:
parent
e8e6bf8f59
commit
8ac4684607
|
@ -30,19 +30,21 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx
|
||||||
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
|
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
|
||||||
model.eval()
|
model.eval()
|
||||||
model_path = model_path + '.onnx'
|
model_path = model_path + '.onnx'
|
||||||
if use_onnx and not os.path.exists(model_path):
|
if use_onnx:
|
||||||
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
if not os.path.exists(model_path):
|
||||||
torch.onnx.export(
|
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||||
model,
|
torch.onnx.export(
|
||||||
onnx_params['args'],
|
model,
|
||||||
model_path,
|
onnx_params['args'],
|
||||||
export_params=True,
|
model_path,
|
||||||
opset_version=10,
|
export_params=True,
|
||||||
do_constant_folding=True,
|
opset_version=10,
|
||||||
input_names=onnx_params['input_names'],
|
do_constant_folding=True,
|
||||||
output_names=onnx_params['output_names'],
|
input_names=onnx_params['input_names'],
|
||||||
dynamic_axes=onnx_params['dynamic_axes']
|
output_names=onnx_params['output_names'],
|
||||||
)
|
dynamic_axes=onnx_params['dynamic_axes']
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue