修bug
This commit is contained in:
parent
05aa179ba6
commit
6dcfe074de
|
@ -53,10 +53,16 @@ def get_batch(b_queues, position, flags, lock):
|
||||||
buffer = []
|
buffer = []
|
||||||
while len(buffer) < flags.batch_size:
|
while len(buffer) < flags.batch_size:
|
||||||
buffer.append(b_queue.get())
|
buffer.append(b_queue.get())
|
||||||
|
if flags.old_model and position != 'bidding':
|
||||||
batch = {
|
batch = {
|
||||||
key: torch.stack([m[key] for m in buffer], dim=1)
|
key: torch.stack([m[key] for m in buffer], dim=1)
|
||||||
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
|
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
batch = {
|
||||||
|
key: torch.stack([m[key] for m in buffer], dim=1)
|
||||||
|
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"]
|
||||||
|
}
|
||||||
del buffer
|
del buffer
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue