提速降费
This commit is contained in:
parent
364f882014
commit
c82b834d89
|
@ -123,11 +123,13 @@ def act(i, device, batch_queues, model, flags):
|
|||
type_buf[mul_obs["position"]].append(2)
|
||||
size[mul_obs["position"]] += 1
|
||||
while True:
|
||||
|
||||
with torch.no_grad():
|
||||
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
||||
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
||||
action = obs['legal_actions'][_action_idx]
|
||||
if len(obs['legal_actions']) > 1:
|
||||
with torch.no_grad():
|
||||
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
||||
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
||||
action = obs['legal_actions'][_action_idx]
|
||||
else:
|
||||
action = obs['legal_actions'][0]
|
||||
if flags.old_model and position != 'bidding':
|
||||
obs_action_buf[position].append(_cards2tensor(action))
|
||||
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
||||
|
|
Loading…
Reference in New Issue