This commit is contained in:
ZaneYork 2022-01-05 22:52:39 +08:00
parent 16198fbb65
commit bfd22bd5e3
1 changed files with 6 additions and 2 deletions

View File

@ -220,8 +220,12 @@ def train(flags):
nonlocal frames, position_frames, stats
while frames < flags.total_frames:
batch = get_batch(batch_queues, position, flags, local_lock)
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
optimizers['uni'], flags, position_lock)
if 'uni' in optimizers.keys():
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
optimizers['uni'], flags, position_lock)
else:
_stats = learn(position, actor_model, learner_model.get_model(position), batch,
optimizers[position], flags, position_lock)
with lock:
for k in _stats:
stats[k] = _stats[k]