diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 1b350b2..4002be0 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -123,11 +123,13 @@ def act(i, device, batch_queues, model, flags): type_buf[mul_obs["position"]].append(2) size[mul_obs["position"]] += 1 while True: - - with torch.no_grad(): - agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags) - _action_idx = int(agent_output['action'].cpu().detach().numpy()) - action = obs['legal_actions'][_action_idx] + if len(obs['legal_actions']) > 1: + with torch.no_grad(): + agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags) + _action_idx = int(agent_output['action'].cpu().detach().numpy()) + action = obs['legal_actions'][_action_idx] + else: + action = obs['legal_actions'][0] if flags.old_model and position != 'bidding': obs_action_buf[position].append(_cards2tensor(action)) obs_x_no_action[position].append(env_output['obs_x_no_action'])