diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 175a1e4..8824df0 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -40,7 +40,7 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock): if position != "bidding": episode_returns = batch['episode_return'][batch['done'] & (batch["obs_type"] == position_index[position])] else: - episode_returns = batch['episode_return'][batch['done'] & ((batch["obs_type"] == 41) | (batch["obs_type"] == 42) | (batch["obs_type"] == 43))] + episode_returns = batch['episode_return'][batch['done'] & ((batch["obs_type"] == 41) | (batch["obs_type"] == 42) | (batch["obs_type"] == 43) | (batch["obs_type"] == 44))] if len(episode_returns) > 0: mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device)) with lock: