diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index 3bead7c..8aeee5a 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -113,6 +113,104 @@ class FarmerLstmModel(nn.Module): x = self.dense6(x) return dict(values=x) +class LandlordLstmModelLegacy(nn.Module): + def __init__(self): + super().__init__() + self.lstm = nn.LSTM(432, 128, batch_first=True) + self.dense1 = nn.Linear(860 + 128, 1024) + self.dense2 = nn.Linear(1024, 1024) + self.dense3 = nn.Linear(1024, 768) + self.dense4 = nn.Linear(768, 512) + self.dense5 = nn.Linear(512, 256) + self.dense6 = nn.Linear(256, 1) + + def get_onnx_params(self, device): + return { + 'args': ( + torch.randn(1, 5, 432, requires_grad=True, device=device), + torch.randn(1, 887, requires_grad=True, device=device), + ), + 'input_names': ['z_batch','x_batch'], + 'output_names': ['values'], + 'dynamic_axes': { + 'z_batch': { + 0: "batch_size" + }, + 'x_batch': { + 0: "batch_size" + }, + 'values': { + 0: "batch_size" + } + } + } + + def forward(self, z, x): + lstm_out, (h_n, _) = self.lstm(z) + lstm_out = lstm_out[:,-1,:] + x = torch.cat([lstm_out,x], dim=-1) + x = self.dense1(x) + x = torch.relu(x) + x = self.dense2(x) + x = torch.relu(x) + x = self.dense3(x) + x = torch.relu(x) + x = self.dense4(x) + x = torch.relu(x) + x = self.dense5(x) + x = torch.relu(x) + x = self.dense6(x) + return dict(values=x) + +class FarmerLstmModelLegacy(nn.Module): + def __init__(self): + super().__init__() + self.lstm = nn.LSTM(432, 128, batch_first=True) + self.dense1 = nn.Linear(1192 + 128, 1024) + self.dense2 = nn.Linear(1024, 1024) + self.dense3 = nn.Linear(1024, 768) + self.dense4 = nn.Linear(768, 512) + self.dense5 = nn.Linear(512, 256) + self.dense6 = nn.Linear(256, 1) + + def get_onnx_params(self, device): + return { + 'args': ( + torch.randn(1, 5, 432, requires_grad=True, device=device), + torch.randn(1, 1219, requires_grad=True, device=device), + ), + 'input_names': ['z_batch','x_batch'], + 'output_names': ['values'], + 'dynamic_axes': { + 'z_batch': { + 0: "legal_actions" + }, + 'x_batch': { + 0: "legal_actions" + }, + 'values': { + 0: "batch_size" + } + } + } + + def forward(self, z, x): + lstm_out, (h_n, _) = self.lstm(z) + lstm_out = lstm_out[:,-1,:] + x = torch.cat([lstm_out,x], dim=-1) + x = self.dense1(x) + x = torch.relu(x) + x = self.dense2(x) + x = torch.relu(x) + x = self.dense3(x) + x = torch.relu(x) + x = self.dense4(x) + x = torch.relu(x) + x = self.dense5(x) + x = torch.relu(x) + x = self.dense6(x) + return dict(values=x) + # 用于ResNet18和34的残差块,用的是2个3x3的卷积 class BasicBlock(nn.Module): expansion = 1 @@ -214,6 +312,11 @@ model_dict['landlord'] = LandlordLstmModel model_dict['landlord_up'] = FarmerLstmModel model_dict['landlord_front'] = FarmerLstmModel model_dict['landlord_down'] = FarmerLstmModel +model_dict_legacy = {} +model_dict_legacy['landlord'] = LandlordLstmModelLegacy +model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy +model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy +model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy model_dict_new = {} model_dict_new['landlord'] = GeneralModel model_dict_new['landlord_up'] = GeneralModel diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 5bcf662..29d1d2e 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -41,7 +41,7 @@ log.setLevel(logging.INFO) Buffers = typing.Dict[str, typing.List[torch.Tensor]] def create_env(flags): - return Env(flags.objective, flags.old_model) + return Env(flags.objective, flags.old_model, flags.lagacy_model) def get_batch(b_queues, position, flags, lock): """ diff --git a/douzero/env/env.py b/douzero/env/env.py index 89c09ab..75b2bd2 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -33,7 +33,7 @@ class Env: Doudizhu multi-agent wrapper """ - def __init__(self, objective, old_model): + def __init__(self, objective, old_model, legacy_model=False): """ Objective is wp/adp/logadp. It indicates whether considers bomb in reward calculation. Here, we use dummy agents. @@ -108,7 +108,7 @@ class Env: self._env.info_sets[pos].player_id = pid self.infoset = self._game_infoset - return get_obs(self.infoset, self.use_general) + return get_obs(self.infoset, self.use_general, self.use_legacy) def step(self, action): """ @@ -234,7 +234,7 @@ class DummyAgent(object): self.action = action -def get_obs(infoset, use_general=True): +def get_obs(infoset, use_general=True, use_legacy = False): """ This function obtains observations with imperfect information from the infoset. It has three branches since we encode @@ -264,13 +264,13 @@ def get_obs(infoset, use_general=True): return _get_obs_general(infoset, infoset.player_position) else: if infoset.player_position == 'landlord': - return _get_obs_landlord(infoset) + return _get_obs_landlord(infoset, use_legacy) elif infoset.player_position == 'landlord_up': - return _get_obs_landlord_up(infoset) + return _get_obs_landlord_up(infoset, use_legacy) elif infoset.player_position == 'landlord_front': - return _get_obs_landlord_front(infoset) + return _get_obs_landlord_front(infoset, use_legacy) elif infoset.player_position == 'landlord_down': - return _get_obs_landlord_down(infoset) + return _get_obs_landlord_down(infoset, use_legacy) else: raise ValueError('') @@ -377,19 +377,23 @@ def _process_action_seq(sequence, length=20, new_model=True): return sequence -def _get_one_hot_bomb(bomb_num): +def _get_one_hot_bomb(bomb_num, use_legacy = False): """ A utility function to encode the number of bombs into one-hot representation. """ - one_hot = np.zeros(56) # 14 + 15 + 27 - one_hot[bomb_num[0]] = 1 - one_hot[14 + bomb_num[1]] = 1 - one_hot[29 + bomb_num[2]] = 1 + if use_legacy: + one_hot = np.zeros(29) + one_hot[bomb_num[0] + bomb_num[1]] = 1 + else: + one_hot = np.zeros(56) # 14 + 15 + 27 + one_hot[bomb_num[0]] = 1 + one_hot[14 + bomb_num[1]] = 1 + one_hot[29 + bomb_num[2]] = 1 return one_hot -def _get_obs_landlord(infoset): +def _get_obs_landlord(infoset, use_legacy = False): """ Obttain the landlord features. See Table 4 in https://arxiv.org/pdf/2106.06135.pdf @@ -448,7 +452,7 @@ def _get_obs_landlord(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -489,7 +493,7 @@ def _get_obs_landlord(infoset): } return obs -def _get_obs_landlord_up(infoset): +def _get_obs_landlord_up(infoset, use_legacy = False): """ Obttain the landlord_up features. See Table 5 in https://arxiv.org/pdf/2106.06135.pdf @@ -563,7 +567,7 @@ def _get_obs_landlord_up(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -610,7 +614,7 @@ def _get_obs_landlord_up(infoset): } return obs -def _get_obs_landlord_front(infoset): +def _get_obs_landlord_front(infoset, use_legacy = False): """ Obttain the landlord_front features. See Table 5 in https://arxiv.org/pdf/2106.06135.pdf @@ -684,7 +688,7 @@ def _get_obs_landlord_front(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -731,7 +735,7 @@ def _get_obs_landlord_front(infoset): } return obs -def _get_obs_landlord_down(infoset): +def _get_obs_landlord_down(infoset, use_legacy = False): """ Obttain the landlord_down features. See Table 5 in https://arxiv.org/pdf/2106.06135.pdf @@ -805,7 +809,7 @@ def _get_obs_landlord_down(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index 1c896f9..b987dd8 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -3,13 +3,16 @@ import numpy as np from douzero.env.env import get_obs -def _load_model(position, model_path, model_type): - from douzero.dmc.models import model_dict_new, model_dict, model_dict_new_legacy, model_dict_legacy +def _load_model(position, model_path, model_type, use_legacy): + from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy model = None if model_type == "general": model = model_dict_new[position]() else: - model = model_dict[position]() + if use_legacy: + model = model_dict_legacy[position]() + else: + model = model_dict[position]() model_state_dict = model.state_dict() if torch.cuda.is_available(): pretrained = torch.load(model_path, map_location='cuda:0') @@ -27,8 +30,9 @@ def _load_model(position, model_path, model_type): class DeepAgent: def __init__(self, position, model_path): + self.use_legacy = True if "legacy" in model_path else False self.model_type = "general" if "resnet" in model_path else "old" - self.model = _load_model(position, model_path, self.model_type) + self.model = _load_model(position, model_path, self.model_type, self.use_legacy) self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'} @@ -36,7 +40,7 @@ class DeepAgent: if len(infoset.legal_actions) == 1: return infoset.legal_actions[0] - obs = get_obs(infoset, self.model_type == "general") + obs = get_obs(infoset, self.model_type == "general", self.use_legacy) z_batch = torch.from_numpy(obs['z_batch']).float() x_batch = torch.from_numpy(obs['x_batch']).float()