diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 7232128..ea003f6 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -53,10 +53,16 @@ def get_batch(b_queues, position, flags, lock): buffer = [] while len(buffer) < flags.batch_size: buffer.append(b_queue.get()) - batch = { - key: torch.stack([m[key] for m in buffer], dim=1) - for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"] - } + if flags.old_model and position != 'bidding': + batch = { + key: torch.stack([m[key] for m in buffer], dim=1) + for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"] + } + else: + batch = { + key: torch.stack([m[key] for m in buffer], dim=1) + for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"] + } del buffer return batch