From c970a59655d53c0ac4e286a72692dd3c08a5c85b Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Wed, 29 Dec 2021 17:13:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DBUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/evaluation/deep_agent.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index e1a8561..d522aa0 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -8,8 +8,8 @@ from douzero.env.env import get_obs def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False): from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite - model_path = model_path + '.onnx' - if use_onnx and os.path.exists(model_path): + model_path_onnx = model_path + '.onnx' + if use_onnx and os.path.exists(model_path_onnx): return None model = None if model_type == "general": @@ -30,14 +30,13 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict} model_state_dict.update(pretrained) model.load_state_dict(model_state_dict) - # torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt")) model.eval() - if use_onnx and not os.path.exists(model_path): + if use_onnx and not os.path.exists(model_path_onnx): onnx_params = model.get_onnx_params(torch.device('cpu')) torch.onnx.export( model, onnx_params['args'], - model_path, + model_path_onnx, export_params=True, opset_version=10, do_constant_folding=True,