import torch
import numpy as np
import os
import onnxruntime
from onnxruntime.datasets import get_example  # retained import; no longer used in this module
from douzero.env.env import get_obs


def _onnx_path(model_path):
    """Return the path of the ONNX export stored next to checkpoint ``model_path``."""
    return model_path + '.onnx'


def _needs_export(onnx_path, model_path):
    """Return True when the ONNX file is missing or older than its source checkpoint."""
    if not os.path.exists(onnx_path):
        return True
    # A checkpoint rewritten after the last export would otherwise be ignored
    # and inference would keep running the stale network.
    return os.path.getmtime(onnx_path) < os.path.getmtime(model_path)


def _build_model(position, model_type, use_legacy, use_lite):
    """Instantiate an untrained network for ``position`` of the requested variant.

    ``model_type == "general"`` selects the new model tables (lite or full);
    any other value selects the old tables (legacy or current).
    """
    from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy

    if model_type == "general":
        table = model_dict_new_lite if use_lite else model_dict_new
    else:
        table = model_dict_legacy if use_legacy else model_dict
    return table[position]()


def _load_model(position, model_path, model_type, use_legacy, use_lite):
    """Load checkpoint ``model_path`` into a fresh model and ensure an ONNX export exists.

    Args:
        position: key into the model tables selecting which network to build.
        model_path: path to the torch checkpoint; the ONNX export is written
            to ``model_path + '.onnx'``.
        model_type: ``"general"`` for the new model family, anything else
            for the old one.
        use_legacy: with the old family, pick the legacy architecture.
        use_lite: with the general family, pick the lite architecture.

    Returns:
        The torch model in eval mode, moved to CUDA when available.
    """
    model = _build_model(position, model_type, use_legacy, use_lite)

    # Copy over only the checkpoint entries the current architecture knows;
    # keys absent from the checkpoint keep their freshly-initialised values.
    model_state_dict = model.state_dict()
    map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    pretrained = torch.load(model_path, map_location=map_location)
    pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
    model_state_dict.update(pretrained)
    model.load_state_dict(model_state_dict)
    model.eval()

    # Export BEFORE moving the model to CUDA. The dummy inputs are requested
    # for the CPU device, and tracing a CUDA model with CPU tensors fails with
    # a device-mismatch error. (Assumes the model constructors build on CPU,
    # which is PyTorch's default — TODO confirm for douzero.dmc.models.)
    onnx_path = _onnx_path(model_path)
    if _needs_export(onnx_path, model_path):
        onnx_params = model.get_onnx_params(torch.device('cpu'))
        torch.onnx.export(
            model,
            onnx_params['args'],
            onnx_path,
            export_params=True,
            opset_version=10,
            do_constant_folding=True,
            input_names=onnx_params['input_names'],
            output_names=onnx_params['output_names'],
            dynamic_axes=onnx_params['dynamic_axes'],
        )

    if torch.cuda.is_available():
        model.cuda()
    return model


class DeepAgent:
    """Agent that picks actions by scoring legal moves with an ONNX export of a DouZero model.

    The checkpoint variant is inferred from substrings of ``model_path``:
    ``"legacy"``, ``"lite"`` and ``"resnet"`` (the last one selects the
    ``"general"`` model family).
    """

    def __init__(self, position, model_path):
        self.use_legacy = "legacy" in model_path
        self.lite_model = "lite" in model_path
        self.model_type = "general" if "resnet" in model_path else "old"
        # The torch model is kept for callers that use it directly. act() scores
        # moves with ONNX Runtime on the CPU, using the export _load_model ensures.
        self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model)
        self.onnx_model = onnxruntime.InferenceSession(
            os.path.abspath(_onnx_path(model_path)),
            providers=['CPUExecutionProvider'],
        )
        # Maps environment card codes to display characters (not used by act()).
        self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                                 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                                 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}

    def act(self, infoset):
        """Return the legal action in ``infoset`` with the highest predicted value.

        When only one action is legal it is returned without running the model.
        """
        legal_actions = infoset.legal_actions
        if len(legal_actions) == 1:
            return legal_actions[0]

        obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
        # ONNX Runtime does not cast inputs. The former torch path applied
        # .float(), so the export presumably expects float32; asarray does not
        # copy when the arrays are already float32.
        feeds = {
            'z_batch': np.asarray(obs['z_batch'], dtype=np.float32),
            'x_batch': np.asarray(obs['x_batch'], dtype=np.float32),
        }
        y_pred = self.onnx_model.run(None, feeds)[0]
        # Assumes y_pred has one row per legal action (shape (N, 1)) — the
        # best row is the chosen action.
        best_action_index = int(np.argmax(y_pred, axis=0)[0])
        return legal_actions[best_action_index]