Fix BUG
This commit is contained in:
parent
16198fbb65
commit
bfd22bd5e3
|
@ -220,8 +220,12 @@ def train(flags):
|
||||||
nonlocal frames, position_frames, stats
|
nonlocal frames, position_frames, stats
|
||||||
while frames < flags.total_frames:
|
while frames < flags.total_frames:
|
||||||
batch = get_batch(batch_queues, position, flags, local_lock)
|
batch = get_batch(batch_queues, position, flags, local_lock)
|
||||||
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
|
if 'uni' in optimizers.keys():
|
||||||
optimizers['uni'], flags, position_lock)
|
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
|
||||||
|
optimizers['uni'], flags, position_lock)
|
||||||
|
else:
|
||||||
|
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
|
||||||
|
optimizers[position], flags, position_lock)
|
||||||
with lock:
|
with lock:
|
||||||
for k in _stats:
|
for k in _stats:
|
||||||
stats[k] = _stats[k]
|
stats[k] = _stats[k]
|
||||||
|
|
Loading…
Reference in New Issue