Add "lite" LSTM model variants and let OldModel select them.

* dmc.py: train() forwards flags.lite_model to OldModel for both the
  per-device actor models and the learner model.
* models.py: add LandlordLstmModelLite and FarmerLstmModelLite (LSTM with
  hidden size 128; dense1..dense5 are 768 wide, dense6 emits one value),
  and add a lite_model keyword (default False) to OldModel.__init__ that
  picks the Lite classes instead of the regular ones.
  The ONNX dynamic-axis label for dimension 0 is "batch_size" on every
  input and output of both Lite models, and the unused h_n binding in
  forward() is dropped.
* env.py: in _get_obs_landlord, the third argument passed to
  _get_one_hot_array for landlord_up's cards-left count becomes 8 (was 15)
  when compressed_form is set.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. If context lines fail to match,
apply with `git apply --ignore-whitespace`.

diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index 9212eee..d1652d1 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -94,7 +94,7 @@ def train(flags):
     models = {}
     for device in device_iterator:
         if flags.old_model:
-            model = OldModel(device="cpu", flags = flags)
+            model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
         else:
             model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
         model.share_memory()
@@ -109,7 +109,7 @@ def train(flags):
 
     # Learner model for training
     if flags.old_model:
-        learner_model = OldModel(device=flags.training_device)
+        learner_model = OldModel(device=flags.training_device, lite_model = flags.lite_model)
     else:
         learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)
 
diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py
index 0db0943..4bac5fe 100644
--- a/douzero/dmc/models.py
+++ b/douzero/dmc/models.py
@@ -113,6 +113,106 @@ class FarmerLstmModel(nn.Module):
         x = self.dense6(x)
         return dict(values=x)
 
+
+class LandlordLstmModelLite(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.lstm = nn.LSTM(276, 128, batch_first=True)
+        self.dense1 = nn.Linear(590 + 128, 768)
+        self.dense2 = nn.Linear(768, 768)
+        self.dense3 = nn.Linear(768, 768)
+        self.dense4 = nn.Linear(768, 768)
+        self.dense5 = nn.Linear(768, 768)
+        self.dense6 = nn.Linear(768, 1)
+
+    def get_onnx_params(self, device):
+        return {
+            'args': (
+                torch.randn(1, 5, 276, requires_grad=True, device=device),
+                torch.randn(1, 590, requires_grad=True, device=device),
+            ),
+            'input_names': ['z_batch','x_batch'],
+            'output_names': ['values'],
+            'dynamic_axes': {
+                'z_batch': {
+                    0: "batch_size"
+                },
+                'x_batch': {
+                    0: "batch_size"
+                },
+                'values': {
+                    0: "batch_size"
+                }
+            }
+        }
+
+    def forward(self, z, x):
+        lstm_out, _ = self.lstm(z)
+        lstm_out = lstm_out[:,-1,:]
+        x = torch.cat([lstm_out,x], dim=-1)
+        x = self.dense1(x)
+        x = torch.relu(x)
+        x = self.dense2(x)
+        x = torch.relu(x)
+        x = self.dense3(x)
+        x = torch.relu(x)
+        x = self.dense4(x)
+        x = torch.relu(x)
+        x = self.dense5(x)
+        x = torch.relu(x)
+        x = self.dense6(x)
+        return dict(values=x)
+
+class FarmerLstmModelLite(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.lstm = nn.LSTM(276, 128, batch_first=True)
+        self.dense1 = nn.Linear(798 + 128, 768)
+        self.dense2 = nn.Linear(768, 768)
+        self.dense3 = nn.Linear(768, 768)
+        self.dense4 = nn.Linear(768, 768)
+        self.dense5 = nn.Linear(768, 768)
+        self.dense6 = nn.Linear(768, 1)
+
+    def get_onnx_params(self, device):
+        return {
+            'args': (
+                torch.randn(1, 5, 276, requires_grad=True, device=device),
+                torch.randn(1, 798, requires_grad=True, device=device),
+            ),
+            'input_names': ['z_batch','x_batch'],
+            'output_names': ['values'],
+            'dynamic_axes': {
+                'z_batch': {
+                    0: "batch_size"
+                },
+                'x_batch': {
+                    0: "batch_size"
+                },
+                'values': {
+                    0: "batch_size"
+                }
+            }
+        }
+
+    def forward(self, z, x):
+        lstm_out, _ = self.lstm(z)
+        lstm_out = lstm_out[:,-1,:]
+        x = torch.cat([lstm_out,x], dim=-1)
+        x = self.dense1(x)
+        x = torch.relu(x)
+        x = self.dense2(x)
+        x = torch.relu(x)
+        x = self.dense3(x)
+        x = torch.relu(x)
+        x = self.dense4(x)
+        x = torch.relu(x)
+        x = self.dense5(x)
+        x = torch.relu(x)
+        x = self.dense6(x)
+        return dict(values=x)
+
+
 class LandlordLstmModelLegacy(nn.Module):
     def __init__(self):
         super().__init__()
@@ -396,7 +496,7 @@ class OldModel:
     The wrapper for the three models. We also wrap several
     interfaces such as share_memory, eval, etc.
     """
-    def __init__(self, device=0, flags=None):
+    def __init__(self, device=0, flags=None, lite_model = False):
         self.models = {}
         self.onnx_models = {}
         self.flags = flags
@@ -409,8 +509,8 @@ class OldModel:
                 self.models[position] = None
         else:
             for position in positions[1:]:
-                self.models[position] = FarmerLstmModel().to(self.device)
-            self.models['landlord'] = LandlordLstmModel().to(self.device)
+                self.models[position] = FarmerLstmModelLite().to(self.device) if lite_model else FarmerLstmModel().to(self.device)
+            self.models['landlord'] = LandlordLstmModelLite().to(self.device) if lite_model else LandlordLstmModel().to(self.device)
         self.onnx_models = {
             'landlord': None,
             'landlord_up': None,
diff --git a/douzero/env/env.py b/douzero/env/env.py
index e6ff0cf..4032ee2 100644
--- a/douzero/env/env.py
+++ b/douzero/env/env.py
@@ -496,7 +496,7 @@ def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
         my_action_batch[j, :] = _cards2array(action, compressed_form)
 
     landlord_up_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_up'], 25, 15 if compressed_form else 0)
+        infoset.num_cards_left_dict['landlord_up'], 25, 8 if compressed_form else 0)
     landlord_up_num_cards_left_batch = np.repeat(
         landlord_up_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)