diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index df1fb4a..398bb54 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -30,19 +30,21 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx # torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt")) model.eval() model_path = model_path + '.onnx' - if use_onnx and not os.path.exists(model_path): - onnx_params = model.get_onnx_params(torch.device('cpu')) - torch.onnx.export( - model, - onnx_params['args'], - model_path, - export_params=True, - opset_version=10, - do_constant_folding=True, - input_names=onnx_params['input_names'], - output_names=onnx_params['output_names'], - dynamic_axes=onnx_params['dynamic_axes'] - ) + if use_onnx: + if not os.path.exists(model_path): + onnx_params = model.get_onnx_params(torch.device('cpu')) + torch.onnx.export( + model, + onnx_params['args'], + model_path, + export_params=True, + opset_version=10, + do_constant_folding=True, + input_names=onnx_params['input_names'], + output_names=onnx_params['output_names'], + dynamic_axes=onnx_params['dynamic_axes'] + ) + return None return model