diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index ed01c08..781ef97 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -219,11 +219,15 @@ def train(flags):
             torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
             if flags.enable_onnx and position:
                 model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
-                onnx_params = learner_model.get_model(position).get_onnx_params()
+                device = torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device)
+                onnx_params = learner_model.get_model(position).get_onnx_params(device)
                 torch.onnx.export(
                     learner_model.get_model(position),
                     onnx_params['args'],
                     model_path,
+                    export_params=True,
+                    opset_version=10,
+                    do_constant_folding=True,
                     input_names=onnx_params['input_names'],
                     output_names=onnx_params['output_names'],
                     dynamic_axes=onnx_params['dynamic_axes']
diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py
index b888884..6efde17 100644
--- a/douzero/dmc/models.py
+++ b/douzero/dmc/models.py
@@ -26,20 +26,23 @@ class LandlordLstmModel(nn.Module):
         self.dense5 = nn.Linear(512, 256)
         self.dense6 = nn.Linear(256, 1)
 
-    def get_onnx_params(self):
+    def get_onnx_params(self, device):
         return {
             'args': (
-                torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32),
-                torch.tensor(np.zeros((1, 887)), dtype=torch.float32),
+                torch.randn(1, 5, 432, requires_grad=True, device=device),
+                torch.randn(1, 887, requires_grad=True, device=device),
             ),
             'input_names': ['z_batch','x_batch'],
             'output_names': ['values'],
             'dynamic_axes': {
                 'z_batch': {
-                    0: "legal_actions"
+                    0: "batch_size"
                 },
                 'x_batch': {
-                    0: "legal_actions"
+                    0: "batch_size"
+                },
+                'values': {
+                    0: "batch_size"
                 }
             }
         }
@@ -72,20 +75,23 @@ class FarmerLstmModel(nn.Module):
         self.dense5 = nn.Linear(512, 256)
         self.dense6 = nn.Linear(256, 1)
 
-    def get_onnx_params(self):
+    def get_onnx_params(self, device):
         return {
             'args': (
-                torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32),
-                torch.tensor(np.zeros((1, 1219)), dtype=torch.float32)
+                torch.randn(1, 5, 432, requires_grad=True, device=device),
+                torch.randn(1, 1219, requires_grad=True, device=device),
             ),
             'input_names': ['z_batch','x_batch'],
             'output_names': ['values'],
             'dynamic_axes': {
                 'z_batch': {
-                    0: "legal_actions"
+                    0: "batch_size"
                 },
                 'x_batch': {
-                    0: "legal_actions"
+                    0: "batch_size"
+                },
+                'values': {
+                    0: "batch_size"
                 }
             }
         }
@@ -107,62 +113,6 @@ class FarmerLstmModel(nn.Module):
         x = self.dense6(x)
         return dict(values=x)
 
-class LandlordLstmModelLegacy(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.lstm = nn.LSTM(432, 128, batch_first=True)
-        self.dense1 = nn.Linear(860 + 128, 1024)
-        self.dense2 = nn.Linear(1024, 1024)
-        self.dense3 = nn.Linear(1024, 768)
-        self.dense4 = nn.Linear(768, 512)
-        self.dense5 = nn.Linear(512, 256)
-        self.dense6 = nn.Linear(256, 1)
-
-    def forward(self, z, x):
-        lstm_out, (h_n, _) = self.lstm(z)
-        lstm_out = lstm_out[:,-1,:]
-        x = torch.cat([lstm_out,x], dim=-1)
-        x = self.dense1(x)
-        x = torch.relu(x)
-        x = self.dense2(x)
-        x = torch.relu(x)
-        x = self.dense3(x)
-        x = torch.relu(x)
-        x = self.dense4(x)
-        x = torch.relu(x)
-        x = self.dense5(x)
-        x = torch.relu(x)
-        x = self.dense6(x)
-        return dict(values=x)
-
-class FarmerLstmModelLegacy(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.lstm = nn.LSTM(432, 128, batch_first=True)
-        self.dense1 = nn.Linear(1192 + 128, 1024)
-        self.dense2 = nn.Linear(1024, 1024)
-        self.dense3 = nn.Linear(1024, 768)
-        self.dense4 = nn.Linear(768, 512)
-        self.dense5 = nn.Linear(512, 256)
-        self.dense6 = nn.Linear(256, 1)
-
-    def forward(self, z, x):
-        lstm_out, (h_n, _) = self.lstm(z)
-        lstm_out = lstm_out[:,-1,:]
-        x = torch.cat([lstm_out,x], dim=-1)
-        x = self.dense1(x)
-        x = torch.relu(x)
-        x = self.dense2(x)
-        x = torch.relu(x)
-        x = self.dense3(x)
-        x = torch.relu(x)
-        x = self.dense4(x)
-        x = torch.relu(x)
-        x = self.dense5(x)
-        x = torch.relu(x)
-        x = self.dense6(x)
-        return dict(values=x)
-
 # Residual block for ResNet-18 and ResNet-34, built from two 3x3 convolutions
 class BasicBlock(nn.Module):
     expansion = 1
@@ -221,20 +171,23 @@ class GeneralModel(nn.Module):
             self.in_planes = planes * block.expansion
         return nn.Sequential(*layers)
 
-    def get_onnx_params(self):
+    def get_onnx_params(self, device=None):
         return {
             'args': (
-                torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32, device='cuda:0'),
-                torch.tensor(np.zeros((1, 80)), dtype=torch.float32, device='cuda:0')
+                torch.randn(1, 40, 108, requires_grad=True, device=device),
+                torch.randn(1, 56, requires_grad=True, device=device),
             ),
             'input_names': ['z_batch','x_batch'],
             'output_names': ['values'],
             'dynamic_axes': {
                 'z_batch': {
-                    0: "legal_actions"
+                    0: "batch_size"
                 },
                 'x_batch': {
-                    0: "legal_actions"
+                    0: "batch_size"
+                },
+                'values': {
+                    0: "batch_size"
                 }
             }
         }
@@ -261,69 +214,11 @@ model_dict['landlord'] = LandlordLstmModel
 model_dict['landlord_up'] = FarmerLstmModel
 model_dict['landlord_front'] = FarmerLstmModel
 model_dict['landlord_down'] = FarmerLstmModel
-model_dict_legacy = {}
-model_dict_legacy['landlord'] = LandlordLstmModelLegacy
-model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
-model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
-model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
 model_dict_new = {}
 model_dict_new['landlord'] = GeneralModel
 model_dict_new['landlord_up'] = GeneralModel
 model_dict_new['landlord_front'] = GeneralModel
 model_dict_new['landlord_down'] = GeneralModel
-model_dict_lstm = {}
-model_dict_lstm['landlord'] = GeneralModel
-model_dict_lstm['landlord_up'] = GeneralModel
-model_dict_lstm['landlord_front'] = GeneralModel
-model_dict_lstm['landlord_down'] = GeneralModel
-
-class General_Model:
-    """
-    The wrapper for the three models. We also wrap several
-    interfaces such as share_memory, eval, etc.
-    """
-    def __init__(self, device=0):
-        self.models = {}
-        if not device == "cpu":
-            device = 'cuda:' + str(device)
-        # model = GeneralModel().to(torch.device(device))
-        self.models['landlord'] = GeneralModel1().to(torch.device(device))
-        self.models['landlord_up'] = GeneralModel1().to(torch.device(device))
-        self.models['landlord_front'] = GeneralModel1().to(torch.device(device))
-        self.models['landlord_down'] = GeneralModel1().to(torch.device(device))
-
-    def forward(self, position, z, x, return_value=False, flags=None, debug=False):
-        model = self.models[position]
-        values = model.forward(z, x)['values']
-        if return_value:
-            return dict(values=values)
-        else:
-            if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
-                action = torch.randint(values.shape[0], (1,))[0]
-            else:
-                action = torch.argmax(values,dim=0)[0]
-            return dict(action=action)
-
-    def share_memory(self):
-        self.models['landlord'].share_memory()
-        self.models['landlord_up'].share_memory()
-        self.models['landlord_front'].share_memory()
-        self.models['landlord_down'].share_memory()
-
-    def eval(self):
-        self.models['landlord'].eval()
-        self.models['landlord_up'].eval()
-        self.models['landlord_front'].eval()
-        self.models['landlord_down'].eval()
-
-    def parameters(self, position):
-        return self.models[position].parameters()
-
-    def get_model(self, position):
-        return self.models[position]
-
-    def get_models(self):
-        return self.models
 
 class OldModel:
     """
@@ -334,11 +229,12 @@ class OldModel:
         self.models = {}
         if not device == "cpu":
             device = 'cuda:' + str(device)
-        self.models['landlord'] = LandlordLstmModel().to(torch.device(device))
-        self.models['landlord_up'] = FarmerLstmModel().to(torch.device(device))
-        self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
-        self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
+        self.device = torch.device(device)
+        self.models['landlord'] = LandlordLstmModel().to(self.device)
+        self.models['landlord_up'] = FarmerLstmModel().to(self.device)
+        self.models['landlord_front'] = FarmerLstmModel().to(self.device)
+        self.models['landlord_down'] = FarmerLstmModel().to(self.device)
         self.onnx_models = {
             'landlord': None,
             'landlord_up': None,
             'landlord_front': None,
@@ -349,7 +245,7 @@ class OldModel:
         self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
 
     def get_onnx_params(self, position):
-        self.models[position].get_onnx_params()
+        return self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None):
         model = self.onnx_models[position]
@@ -403,6 +299,7 @@ class Model:
         self.flags = flags
         if not device == "cpu":
             device = 'cuda:' + str(device)
+        self.device = torch.device(device)
         # model = GeneralModel().to(torch.device(device))
         positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
         if flags is not None and flags.enable_onnx:
@@ -410,7 +307,7 @@ class Model:
                 self.models[position] = None
         else:
             for position in positions:
-                self.models[position] = GeneralModel().to(torch.device(device))
+                self.models[position] = GeneralModel().to(self.device)
         self.onnx_models = {
             'landlord': None,
             'landlord_up': None,
@@ -425,7 +322,7 @@ class Model:
         self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
 
     def get_onnx_params(self, position):
-        self.models[position].get_onnx_params()
+        return self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None, debug=False):
         if self.flags.enable_onnx and len(self.onnx_models) == 0:
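
The sketch below is a minimal end-to-end check of the export path added above, assuming the LSTM landlord head (z: [N, 5, 432], x: [N, 887]) and CPU-only execution; the output path /tmp/model_landlord.onnx is an illustrative assumption. The input/output names, shapes, and export arguments are taken directly from the diff.

import numpy as np
import onnxruntime
import torch

from douzero.dmc.models import LandlordLstmModel

# Export the same way the dmc.py hunk does, but on CPU and to a scratch path.
model = LandlordLstmModel()
params = model.get_onnx_params(torch.device('cpu'))
torch.onnx.export(
    model,
    params['args'],
    '/tmp/model_landlord.onnx',
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=params['input_names'],
    output_names=params['output_names'],
    dynamic_axes=params['dynamic_axes'],
)

# The "batch_size" dynamic axes should let any leading dimension through.
# At inference time DouZero feeds one row per legal action (hence the old
# axis name "legal_actions"), so this dimension varies from call to call.
session = onnxruntime.InferenceSession('/tmp/model_landlord.onnx',
                                       providers=['CPUExecutionProvider'])
for n in (1, 7, 32):
    values, = session.run(['values'], {
        'z_batch': np.zeros((n, 5, 432), dtype=np.float32),
        'x_batch': np.zeros((n, 887), dtype=np.float32),
    })
    assert values.shape == (n, 1)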