From 9f2e8f74f33082a3422837bdc2cd9e115add62e7 Mon Sep 17 00:00:00 2001 From: ZaneYork Date: Tue, 14 Dec 2021 23:06:07 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8onnx=E8=BF=9B=E8=A1=8Cinfer?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E6=9C=AA=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/dmc.py | 2 +- douzero/dmc/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 1365313..9ecdb5d 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -53,7 +53,7 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock): if len(episode_returns) > 0: mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device)) with lock: - learner_outputs = model(obs_z, obs_x, return_value=True) + learner_outputs = model(obs_z, obs_x) if position == "bidding": loss = compute_loss(learner_outputs['values'], target) # pass diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index dd7c0f2..7e84fbb 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -118,7 +118,8 @@ def act(i, device, batch_queues, model, flags, onnx_frame): for p in positions: if p != 'bidding': model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, p) - model.set_onnx_model(p, os.path.abspath(model_path)) + if os.path.exists(model_path): + model.set_onnx_model(p, os.path.abspath(model_path)) for bid_obs in bid_obs_buffer: obs_z_buf["bidding"].append(bid_obs['z_batch'])