Fix bug
This commit is contained in:
parent 05aa179ba6
commit 6dcfe074de
@@ -53,10 +53,16 @@ def get_batch(b_queues, position, flags, lock):
     buffer = []
     while len(buffer) < flags.batch_size:
         buffer.append(b_queue.get())
-    batch = {
-        key: torch.stack([m[key] for m in buffer], dim=1)
-        for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
-    }
+    if flags.old_model and position != 'bidding':
+        batch = {
+            key: torch.stack([m[key] for m in buffer], dim=1)
+            for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
+        }
+    else:
+        batch = {
+            key: torch.stack([m[key] for m in buffer], dim=1)
+            for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"]
+        }
     del buffer
     return batch
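For reference, a minimal, self-contained sketch (not from the repository) of what the changed collation does: each entry in buffer is a dict of tensors for one rollout, torch.stack joins them along a new dimension at dim=1, and obs_x_no_action is only kept when flags.old_model is set and the position is not 'bidding'. The batch size and tensor shapes below are illustrative assumptions.

import torch

# Illustrative stand-ins for flags.old_model, position, and flags.batch_size
# (assumed values, not taken from the repository).
old_model = True
position = 'playing'
batch_size = 4

# Each buffered entry holds one rollout's tensors; shapes here are made up.
buffer = [
    {
        "done": torch.zeros(32, dtype=torch.bool),
        "episode_return": torch.zeros(32),
        "target": torch.zeros(32),
        "obs_z": torch.zeros(32, 40),
        "obs_x_batch": torch.zeros(32, 60),
        "obs_x_no_action": torch.zeros(32, 50),
        "obs_type": torch.zeros(32),
    }
    for _ in range(batch_size)
]

# The fix selects the key set per branch: obs_x_no_action is only collated
# for the old model outside the 'bidding' position.
if old_model and position != 'bidding':
    keys = ["done", "episode_return", "target", "obs_z",
            "obs_x_batch", "obs_x_no_action", "obs_type"]
else:
    keys = ["done", "episode_return", "target", "obs_z",
            "obs_x_batch", "obs_type"]

batch = {key: torch.stack([m[key] for m in buffer], dim=1) for key in keys}
print(batch["obs_z"].shape)  # torch.Size([32, 4, 40]) -- stacked along dim=1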