From b2dd98cc304b844ae8f9d8e2e66b247cd56eca2b Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Wed, 29 Dec 2021 16:56:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=B5=84=E6=BA=90=E5=8D=A0?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/evaluation/deep_agent.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index 398bb54..e1a8561 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -8,6 +8,9 @@ from douzero.env.env import get_obs def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False): from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite + model_path = model_path + '.onnx' + if use_onnx and os.path.exists(model_path): + return None model = None if model_type == "general": if use_lite: @@ -29,21 +32,19 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx model.load_state_dict(model_state_dict) # torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt")) model.eval() - model_path = model_path + '.onnx' - if use_onnx: - if not os.path.exists(model_path): - onnx_params = model.get_onnx_params(torch.device('cpu')) - torch.onnx.export( - model, - onnx_params['args'], - model_path, - export_params=True, - opset_version=10, - do_constant_folding=True, - input_names=onnx_params['input_names'], - output_names=onnx_params['output_names'], - dynamic_axes=onnx_params['dynamic_axes'] - ) + if use_onnx and not os.path.exists(model_path): + onnx_params = model.get_onnx_params(torch.device('cpu')) + torch.onnx.export( + model, + onnx_params['args'], + model_path, + export_params=True, + opset_version=10, + do_constant_folding=True, + input_names=onnx_params['input_names'], + output_names=onnx_params['output_names'], + dynamic_axes=onnx_params['dynamic_axes'] + ) return None return model