diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index 77bdd82..0db0943 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -385,6 +385,11 @@ model_dict_new['landlord'] = GeneralModel model_dict_new['landlord_up'] = GeneralModel model_dict_new['landlord_front'] = GeneralModel model_dict_new['landlord_down'] = GeneralModel +model_dict_new_lite = {} +model_dict_new_lite['landlord'] = GeneralModelLite +model_dict_new_lite['landlord_up'] = GeneralModelLite +model_dict_new_lite['landlord_front'] = GeneralModelLite +model_dict_new_lite['landlord_down'] = GeneralModelLite class OldModel: """ diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index 73162cc..daa4719 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -6,11 +6,14 @@ from onnxruntime.datasets import get_example from douzero.env.env import get_obs -def _load_model(position, model_path, model_type, use_legacy): - from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy +def _load_model(position, model_path, model_type, use_legacy, use_lite): + from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy model = None if model_type == "general": - model = model_dict_new[position]() + if use_lite: + model = model_dict_new_lite[position]() + else: + model = model_dict_new[position]() else: if use_legacy: model = model_dict_legacy[position]() @@ -51,7 +54,7 @@ class DeepAgent: self.use_legacy = True if "legacy" in model_path else False self.lite_model = True if "lite" in model_path else False self.model_type = "general" if "resnet" in model_path else "old" - self.model = _load_model(position, model_path, self.model_type, self.use_legacy) + self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model) self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider']) self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', @@ -62,10 +65,10 @@ class DeepAgent: obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model) - z_batch = torch.from_numpy(obs['z_batch']).float() - x_batch = torch.from_numpy(obs['x_batch']).float() - if torch.cuda.is_available(): - z_batch, x_batch = z_batch.cuda(), x_batch.cuda() + # z_batch = torch.from_numpy(obs['z_batch']).float() + # x_batch = torch.from_numpy(obs['x_batch']).float() + # if torch.cuda.is_available(): + # z_batch, x_batch = z_batch.cuda(), x_batch.cuda() # y_pred = self.model.forward(z_batch, x_batch)['values'] # y_pred = y_pred.detach().cpu().numpy() y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]