onnx模式释放模型降低占用

This commit is contained in:
zhiyang7 2021-12-29 16:54:30 +08:00
parent e8e6bf8f59
commit 8ac4684607
1 changed files with 15 additions and 13 deletions

View File

@ -30,19 +30,21 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt")) # torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
model.eval() model.eval()
model_path = model_path + '.onnx' model_path = model_path + '.onnx'
if use_onnx and not os.path.exists(model_path): if use_onnx:
onnx_params = model.get_onnx_params(torch.device('cpu')) if not os.path.exists(model_path):
torch.onnx.export( onnx_params = model.get_onnx_params(torch.device('cpu'))
model, torch.onnx.export(
onnx_params['args'], model,
model_path, onnx_params['args'],
export_params=True, model_path,
opset_version=10, export_params=True,
do_constant_folding=True, opset_version=10,
input_names=onnx_params['input_names'], do_constant_folding=True,
output_names=onnx_params['output_names'], input_names=onnx_params['input_names'],
dynamic_axes=onnx_params['dynamic_axes'] output_names=onnx_params['output_names'],
) dynamic_axes=onnx_params['dynamic_axes']
)
return None
return model return model