This commit is contained in:
zhiyang7 2021-12-10 16:28:03 +08:00
parent 05aa179ba6
commit 6dcfe074de
1 changed files with 10 additions and 4 deletions

View File

@ -53,10 +53,16 @@ def get_batch(b_queues, position, flags, lock):
buffer = []
while len(buffer) < flags.batch_size:
buffer.append(b_queue.get())
batch = {
key: torch.stack([m[key] for m in buffer], dim=1)
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
}
if flags.old_model and position != 'bidding':
batch = {
key: torch.stack([m[key] for m in buffer], dim=1)
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
}
else:
batch = {
key: torch.stack([m[key] for m in buffer], dim=1)
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"]
}
del buffer
return batch