diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index a91e522..4a3c53d 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -120,7 +120,11 @@ def train(flags):
     ]
     frames, stats = 0, {k: 0 for k in stat_keys}
     position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
-    position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
+    if flags.unified_model:
+        lock = threading.Lock()
+        position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock}
+    else:
+        position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
 
     def sync_onnx_model(frames):
         p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
@@ -153,11 +157,17 @@ def train(flags):
         checkpoint_states = torch.load(
             checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
         )
-        for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
-            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
-            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
+        if flags.unified_model:
+            learner_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
+            optimizers['uni'].load_state_dict(checkpoint_states["optimizer_state_dict"]['uni'])
             if not flags.enable_onnx:
-                actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+                actor_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
+        else:
+            for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
+                learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+                optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
+                if not flags.enable_onnx:
+                    actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
         stats = checkpoint_states["stats"]
         frames = checkpoint_states["frames"]
 
@@ -210,7 +220,7 @@ def train(flags):
         while frames < flags.total_frames:
             batch = get_batch(batch_queues, position, flags, local_lock)
             _stats = learn(position, actor_model, learner_model.get_model(position), batch,
-                           optimizers[position], flags, position_lock)
+                           optimizers['uni'] if flags.unified_model else optimizers[position], flags, position_lock)
             with lock:
                 for k in _stats:
                     stats[k] = _stats[k]
@@ -248,7 +258,7 @@ def train(flags):
         }, checkpointpath + '.new')
         # Save the weights for evaluation purpose
-        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
+        for position in (['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']):
             model_weights_dir = os.path.expandvars(os.path.expanduser(
                 '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
             torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
 
diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py
index 67266e1..7caa44d 100644
--- a/douzero/dmc/utils.py
+++ b/douzero/dmc/utils.py
@@ -104,12 +104,20 @@ def create_optimizers(flags, learner_model):
     """
     positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
     optimizers = {}
-    for position in positions:
+    if flags.unified_model:
+        position = 'uni'
         optimizer = RAdam(
             learner_model.parameters(position),
             lr=flags.learning_rate,
             eps=flags.epsilon)
         optimizers[position] = optimizer
+    else:
+        for position in positions:
+            optimizer = RAdam(
+                learner_model.parameters(position),
+                lr=flags.learning_rate,
+                eps=flags.epsilon)
+            optimizers[position] = optimizer
     return optimizers
 
 
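Two review notes on what this patch assumes but does not show. First, a `unified_model` flag must be registered wherever the other DMC flags are declared; that registration is not part of this diff. Second, the model wrapper has to resolve the `'uni'` key (and, for the `learn(...)` call above, the per-position keys) to one shared network. The sketch below spells out those assumptions; `SharedModel`, `_net`, and the argparse line are illustrative names, not the repository's actual code:

```python
# Assumed flag registration (NOT part of this diff; the real one would live
# alongside the other DMC flags, e.g. in douzero/dmc/arguments.py):
# parser.add_argument('--unified_model', action='store_true',
#                     help='Train one shared network for all four positions')

import threading
from torch import nn

POSITIONS = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']

class SharedModel:
    """Illustrative wrapper: 'uni' and every position key resolve to the same
    underlying network, which is what lets dmc.py run the forward pass via
    get_model(position) while stepping optimizers['uni']."""

    def __init__(self, net: nn.Module):
        self._net = net  # the one network shared by every position

    def get_model(self, key):
        # Any key ('uni' or a position name) returns the shared network.
        return self._net

    def parameters(self, key):
        return self._net.parameters()

# With one shared network, per-position locks no longer isolate anything:
# every learner thread steps the same parameters, hence the single Lock
# shared across all four entries of position_locks in the first hunk.
shared_lock = threading.Lock()
position_locks = {p: shared_lock for p in POSITIONS}
```

One side effect worth flagging: checkpoints written with `unified_model` enabled store a single `'uni'` entry in `model_state_dict` and `optimizer_state_dict`, so they cannot be resumed with the flag off, and per-position checkpoints cannot be resumed with it on; the branched load logic in the second hunk mirrors that split.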