diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 2d9aa0c..872b1d7 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -220,8 +220,12 @@ def train(flags): nonlocal frames, position_frames, stats while frames < flags.total_frames: batch = get_batch(batch_queues, position, flags, local_lock) - _stats = learn(position, actor_model, learner_model.get_model(position), batch, - optimizers['uni'], flags, position_lock) + if 'uni' in optimizers.keys(): + _stats = learn(position, actor_model, learner_model.get_model(position), batch, + optimizers['uni'], flags, position_lock) + else: + _stats = learn(position, actor_model, learner_model.get_model(position), batch, + optimizers[position], flags, position_lock) with lock: for k in _stats: stats[k] = _stats[k]