提速降费

This commit is contained in:
zhiyang7 2021-12-14 08:53:50 +08:00
parent 364f882014
commit c82b834d89
1 changed files with 7 additions and 5 deletions

View File

@ -123,11 +123,13 @@ def act(i, device, batch_queues, model, flags):
type_buf[mul_obs["position"]].append(2)
size[mul_obs["position"]] += 1
while True:
with torch.no_grad():
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
_action_idx = int(agent_output['action'].cpu().detach().numpy())
action = obs['legal_actions'][_action_idx]
if len(obs['legal_actions']) > 1:
with torch.no_grad():
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
_action_idx = int(agent_output['action'].cpu().detach().numpy())
action = obs['legal_actions'][_action_idx]
else:
action = obs['legal_actions'][0]
if flags.old_model and position != 'bidding':
obs_action_buf[position].append(_cards2tensor(action))
obs_x_no_action[position].append(env_output['obs_x_no_action'])