diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py
index 2d654c0..4a9916d 100644
--- a/douzero/dmc/models.py
+++ b/douzero/dmc/models.py
@@ -491,6 +491,51 @@
 model_dict_new_lite['landlord_up'] = GeneralModelLite
 model_dict_new_lite['landlord_front'] = GeneralModelLite
 model_dict_new_lite['landlord_down'] = GeneralModelLite
+def forward_logic(self_model, position, z, x, return_value=False, flags=None):
+    legal_count = len(z)
+    if legal_count >= 80:
+        partition_count = int(legal_count / 40)
+        sub_z = np.array_split(z, partition_count)
+        sub_x = np.array_split(x, partition_count)
+    if flags.enable_onnx:
+        model = self_model.onnx_models[position]
+        if legal_count >= 80:
+            values = np.ndarray((legal_count, 1))
+            j = 0
+            for i in range(partition_count):
+                onnx_out = model.run(None, {'z_batch': sub_z[i], 'x_batch': sub_x[i]})
+                values[j:j+len(sub_z[i])] = onnx_out[0]
+                j += len(sub_z[i])
+        else:
+            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
+            values = onnx_out[0]
+    else:
+        if legal_count >= 80:
+            values = np.ndarray((legal_count, 1))
+            j = 0
+            for i in range(partition_count):
+                model = self_model.models[position]
+                model_out = model.forward(sub_z[i], sub_x[i])['values']
+                values[j:j+len(sub_z[i])] = model_out
+                j += len(sub_z[i])
+        else:
+            model = self_model.models[position]
+            values = model.forward(z, x)['values']
+    if return_value:
+        return dict(values = values.cpu().detach().numpy() if torch.is_tensor(values) else values)
+    else:
+        if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
+            if torch.is_tensor(values):
+                action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
+            else:
+                action = np.random.randint(0, values.shape[0], (1,))[0]
+        else:
+            if torch.is_tensor(values):
+                action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
+            else:
+                action = np.argmax(values, axis=0)[0]
+        return dict(action = action)
+
 class OldModel:
     """
     The wrapper for the three models. We also wrap several
@@ -531,27 +576,7 @@ class OldModel:
             self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None):
-        if flags.enable_onnx:
-            model = self.onnx_models[position]
-            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
-            values = onnx_out[0]
-        else:
-            model = self.models[position]
-            values = model.forward(z, x)['values']
-        if return_value:
-            return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
-        else:
-            if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
-                if flags.enable_onnx:
-                    action = np.random.randint(0, values.shape[0], (1,))[0]
-                else:
-                    action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
-            else:
-                if flags.enable_onnx:
-                    action = np.argmax(values, axis=0)[0]
-                else:
-                    action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
-            return dict(action = action)
+        return forward_logic(self, position, z, x, return_value, flags)
 
     def share_memory(self):
         if self.models['landlord'] is not None:
@@ -620,27 +645,7 @@ class Model:
             self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None, debug=False):
-        if flags.enable_onnx:
-            model = self.onnx_models[position]
-            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
-            values = onnx_out[0]
-        else:
-            model = self.models[position]
-            values = model.forward(z, x)['values']
-        if return_value:
-            return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
-        else:
-            if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
-                if flags.enable_onnx:
-                    action = np.random.randint(0, values.shape[0], (1,))[0]
-                else:
-                    action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
-            else:
-                if flags.enable_onnx:
-                    action = np.argmax(values, axis=0)[0]
-                else:
-                    action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
-            return dict(action = action)
+        return forward_logic(self, position, z, x, return_value, flags)
 
     def share_memory(self):
         if self.models['landlord'] is not None:
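
A note on the chunking scheme introduced above: when 80 or more legal moves must be evaluated, the batch is split into roughly 40-row chunks with np.array_split and the per-chunk outputs are written back into a preallocated buffer. A minimal standalone sketch of that reassembly logic follows; run_stub and the input shapes are hypothetical stand-ins for illustration, not part of the diff.

    import numpy as np

    def run_stub(z_batch, x_batch):
        # Hypothetical stand-in for the per-chunk model call
        # (onnxruntime session.run or model.forward in the diff).
        return np.zeros((len(z_batch), 1))

    legal_count = 90                      # >= 80 triggers the chunked path
    z = np.zeros((legal_count, 5, 162))   # assumed shapes, illustration only
    x = np.zeros((legal_count, 80))

    partition_count = int(legal_count / 40)   # 2 chunks here
    sub_z = np.array_split(z, partition_count)
    sub_x = np.array_split(x, partition_count)

    values = np.ndarray((legal_count, 1))  # uninitialized; every row is filled below
    j = 0
    for i in range(partition_count):
        values[j:j + len(sub_z[i])] = run_stub(sub_z[i], sub_x[i])
        j += len(sub_z[i])

    assert j == legal_count  # chunks of 45 + 45 reassemble the full batch

Two details worth noting: np.array_split, unlike np.split, tolerates uneven divisions (legal_count = 85 yields chunks of 43 and 42), which is why the loop advances j by the actual chunk length rather than a fixed stride; and np.ndarray allocates an uninitialized buffer, which is safe here only because every row is overwritten.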
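
The refactored helper also switches the value/action extraction from branching on flags.enable_onnx to branching on torch.is_tensor(values). That matters because the chunked non-ONNX path now accumulates results into a NumPy buffer rather than returning a torch tensor, so the old flag-based dispatch would misidentify the value type. A small self-contained sketch of the shared epsilon-greedy tail, under that reading of the diff (the function name is mine, not from the source):

    import numpy as np
    import torch

    def pick_action(values, exp_epsilon=0.0):
        # Epsilon-greedy selection over candidate values, mirroring the tail
        # of forward_logic: handles both NumPy arrays and torch tensors.
        if exp_epsilon > 0 and np.random.rand() < exp_epsilon:
            if torch.is_tensor(values):
                return torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
            return np.random.randint(0, values.shape[0], (1,))[0]
        if torch.is_tensor(values):
            return torch.argmax(values, dim=0)[0].cpu().detach().numpy()
        return np.argmax(values, axis=0)[0]

    print(pick_action(np.array([[0.1], [0.9], [0.3]])))      # -> 1
    print(pick_action(torch.tensor([[0.1], [0.9], [0.3]])))  # -> 1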