From c82b834d8999a57151b96907db524c9222665a14 Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Tue, 14 Dec 2021 08:53:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E9=80=9F=E9=99=8D=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 1b350b2..4002be0 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -123,11 +123,13 @@ def act(i, device, batch_queues, model, flags): type_buf[mul_obs["position"]].append(2) size[mul_obs["position"]] += 1 while True: - - with torch.no_grad(): - agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags) - _action_idx = int(agent_output['action'].cpu().detach().numpy()) - action = obs['legal_actions'][_action_idx] + if len(obs['legal_actions']) > 1: + with torch.no_grad(): + agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags) + _action_idx = int(agent_output['action'].cpu().detach().numpy()) + action = obs['legal_actions'][_action_idx] + else: + action = obs['legal_actions'][0] if flags.old_model and position != 'bidding': obs_action_buf[position].append(_cards2tensor(action)) obs_x_no_action[position].append(env_output['obs_x_no_action'])