From bfd22bd5e3056f3bba018d5cc298fc3c299a7b83 Mon Sep 17 00:00:00 2001 From: ZaneYork Date: Wed, 5 Jan 2022 22:52:39 +0800 Subject: [PATCH] Fix BUG --- douzero/dmc/dmc.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 2d9aa0c..872b1d7 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -220,8 +220,12 @@ def train(flags): nonlocal frames, position_frames, stats while frames < flags.total_frames: batch = get_batch(batch_queues, position, flags, local_lock) - _stats = learn(position, actor_model, learner_model.get_model(position), batch, - optimizers['uni'], flags, position_lock) + if 'uni' in optimizers.keys(): + _stats = learn(position, actor_model, learner_model.get_model(position), batch, + optimizers['uni'], flags, position_lock) + else: + _stats = learn(position, actor_model, learner_model.get_model(position), batch, + optimizers[position], flags, position_lock) with lock: for k in _stats: stats[k] = _stats[k]