提速降费
This commit is contained in:
parent
364f882014
commit
c82b834d89
|
@ -123,11 +123,13 @@ def act(i, device, batch_queues, model, flags):
|
||||||
type_buf[mul_obs["position"]].append(2)
|
type_buf[mul_obs["position"]].append(2)
|
||||||
size[mul_obs["position"]] += 1
|
size[mul_obs["position"]] += 1
|
||||||
while True:
|
while True:
|
||||||
|
if len(obs['legal_actions']) > 1:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
||||||
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
||||||
action = obs['legal_actions'][_action_idx]
|
action = obs['legal_actions'][_action_idx]
|
||||||
|
else:
|
||||||
|
action = obs['legal_actions'][0]
|
||||||
if flags.old_model and position != 'bidding':
|
if flags.old_model and position != 'bidding':
|
||||||
obs_action_buf[position].append(_cards2tensor(action))
|
obs_action_buf[position].append(_cards2tensor(action))
|
||||||
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
||||||
|
|
Loading…
Reference in New Issue