minor fix

This commit is contained in:
zhiyang7 2021-12-07 17:44:36 +08:00
parent 4a55570be6
commit ee409846f3
1 changed files with 1 additions and 1 deletions

View File

@ -40,7 +40,7 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
if position != "bidding":
episode_returns = batch['episode_return'][batch['done'] & (batch["obs_type"] == position_index[position])]
else:
episode_returns = batch['episode_return'][batch['done'] & ((batch["obs_type"] == 41) | (batch["obs_type"] == 42) | (batch["obs_type"] == 43))]
episode_returns = batch['episode_return'][batch['done'] & ((batch["obs_type"] == 41) | (batch["obs_type"] == 42) | (batch["obs_type"] == 43) | (batch["obs_type"] == 44))]
if len(episode_returns) > 0:
mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))
with lock: