提速降费

This commit is contained in:
zhiyang7 2021-12-14 08:53:50 +08:00
parent 364f882014
commit c82b834d89
1 changed files with 7 additions and 5 deletions

View File

@ -123,11 +123,13 @@ def act(i, device, batch_queues, model, flags):
type_buf[mul_obs["position"]].append(2) type_buf[mul_obs["position"]].append(2)
size[mul_obs["position"]] += 1 size[mul_obs["position"]] += 1
while True: while True:
if len(obs['legal_actions']) > 1:
with torch.no_grad(): with torch.no_grad():
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags) agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
_action_idx = int(agent_output['action'].cpu().detach().numpy()) _action_idx = int(agent_output['action'].cpu().detach().numpy())
action = obs['legal_actions'][_action_idx] action = obs['legal_actions'][_action_idx]
else:
action = obs['legal_actions'][0]
if flags.old_model and position != 'bidding': if flags.old_model and position != 'bidding':
obs_action_buf[position].append(_cards2tensor(action)) obs_action_buf[position].append(_cards2tensor(action))
obs_x_no_action[position].append(env_output['obs_x_no_action']) obs_x_no_action[position].append(env_output['obs_x_no_action'])