116 lines
5.0 KiB
Python
116 lines
5.0 KiB
Python
import torch
|
|
import numpy as np
|
|
import os
|
|
import onnxruntime
|
|
from onnxruntime.datasets import get_example
|
|
|
|
from douzero.env.env import get_obs
|
|
|
|
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_unified, use_onnx=False):
    """Build the network for ``position`` and load its checkpoint weights.

    Args:
        position: Key used to look up the model class in the selected
            ``douzero.dmc.models`` registry.
        model_path: Path of the PyTorch checkpoint. The ONNX file is
            expected at ``model_path + '.onnx'``.
        model_type: ``"general"`` selects the ``model_dict_new*`` registries;
            any other value selects the older registries.
        use_legacy: For non-general models, pick ``model_dict_legacy``.
        use_lite: Pick the "lite" registry of the selected family.
        use_unified: For non-general, non-legacy models, pick the unified
            registry (only the lite variant has a dedicated registry).
        use_onnx: When true, the caller wants to run inference from the ONNX
            file instead of the PyTorch module.

    Returns:
        The model in eval mode, or ``None`` when ``use_onnx`` is true (the
        ONNX file either already existed or has just been exported).
    """
    from douzero.dmc.models import (
        model_dict,
        model_dict_legacy,
        model_dict_lite,
        model_dict_new,
        model_dict_new_lite,
        model_dict_uni_lite,
    )

    onnx_file = model_path + '.onnx'
    # An existing ONNX export means no PyTorch model is needed at all.
    if use_onnx and os.path.exists(onnx_file):
        return None

    # Pick the registry first so only the chosen one is ever indexed.
    if model_type == "general":
        registry = model_dict_new_lite if use_lite else model_dict_new
    elif use_legacy:
        registry = model_dict_legacy
    elif use_unified and use_lite:
        registry = model_dict_uni_lite
    elif use_lite:
        registry = model_dict_lite
    else:
        # Non-lite models share model_dict whether unified or not.
        registry = model_dict

    net = registry[position]()

    # Overlay only the checkpoint entries the network actually declares, so
    # extra keys in the checkpoint are ignored and missing ones keep their
    # freshly-initialised values.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = net.state_dict()
    checkpoint = torch.load(model_path, map_location='cpu')
    weights.update({name: tensor for name, tensor in checkpoint.items() if name in weights})
    net.load_state_dict(weights)
    net.eval()

    if not use_onnx or os.path.exists(onnx_file):
        return net

    # ONNX requested but not exported yet: write the file and hand back None
    # so the caller runs inference from the exported file.
    export_spec = net.get_onnx_params(torch.device('cpu'))
    torch.onnx.export(
        net,
        export_spec['args'],
        onnx_file,
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=export_spec['input_names'],
        output_names=export_spec['output_names'],
        dynamic_axes=export_spec['dynamic_axes'],
    )
    return None
|
|
|
|
class DeepAgent:
    """Agent that chooses a move by scoring every legal action with a value model.

    The model variant is inferred from substrings of ``model_path``:
    ``"legacy"``, ``"uni"``, ``"lite"`` and ``"resnet"`` (the latter selects
    the ``"general"`` model type, anything else is ``"old"``).

    Inference runs either through the PyTorch module returned by
    ``_load_model`` or, when ``use_onnx`` is true, through an ONNX Runtime
    session that is created lazily on the first ``act`` call.
    """

    def __init__(self, position, model_path, use_onnx=False):
        """Load the model for ``position`` from ``model_path``.

        Args:
            position: Player position key forwarded to ``_load_model``.
            model_path: Checkpoint path; also determines the model variant
                and the ONNX file location (``model_path + '.onnx'``).
            use_onnx: Run inference with ONNX Runtime instead of PyTorch.
        """
        self.use_legacy = "legacy" in model_path
        self.use_unified = "uni" in model_path
        self.lite_model = "lite" in model_path
        self.model_type = "general" if "resnet" in model_path else "old"
        # _load_model returns None when use_onnx is true.
        self.model = _load_model(position, model_path, self.model_type, self.use_legacy,
                                 self.lite_model, self.use_unified, use_onnx=use_onnx)
        self.onnx_model_path = os.path.abspath(model_path + '.onnx')
        self.use_onnx = use_onnx
        # Created on first use in act().
        self.onnx_model = None
        # Maps environment card codes to their one-character display form.
        self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                                 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                                 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}

    def _create_onnx_session(self):
        """Open an ONNX Runtime session for ``self.onnx_model_path``."""
        if torch.cuda.is_available():
            # Keep the CPU provider as a fallback in case the installed
            # onnxruntime build has no CUDA provider.
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        # NOTE(review): get_example is kept from the original code; since
        # onnx_model_path is already absolute it presumably only resolves /
        # validates the path -- confirm before replacing it.
        return onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=providers)

    def _predict_torch(self, obs):
        """Score the batch in ``obs`` with the PyTorch model; return a numpy array."""
        z_batch = torch.from_numpy(obs['z_batch']).float()
        x_batch = torch.from_numpy(obs['x_batch']).float()

        # Inputs must live on the same device as the weights. _load_model
        # loads the checkpoint with map_location='cpu', so unconditionally
        # calling .cuda() on the inputs would mismatch a CPU-resident model.
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            z_batch = z_batch.to(first_param.device)
            x_batch = x_batch.to(first_param.device)
        elif torch.cuda.is_available():
            # Parameter-less model: keep the original "use CUDA if present" rule.
            z_batch, x_batch = z_batch.cuda(), x_batch.cuda()

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            y_pred = self.model.forward(z_batch, x_batch)['values']
        return y_pred.detach().cpu().numpy()

    def act(self, infoset, with_confidence=False):
        """Choose an action for the current ``infoset``.

        Args:
            infoset: Game state exposing ``legal_actions`` and accepted by
                ``get_obs``.
            with_confidence: When true, return the top candidates together
                with their predicted values instead of a single action.

        Returns:
            If ``with_confidence`` is false: the legal action with the
            highest predicted value (or the only legal action, without
            running the model, when there is exactly one).
            If ``with_confidence`` is true: a tuple
            ``(best_actions, best_values)`` holding up to three legal actions
            and a numpy array of their predicted values, in the order
            produced by ``np.argpartition`` (not sorted by value).
        """
        if self.use_onnx and self.onnx_model is None:
            self.onnx_model = self._create_onnx_session()

        # Nothing to decide when only one move is legal and no scores are wanted.
        if not with_confidence and len(infoset.legal_actions) == 1:
            return infoset.legal_actions[0]

        obs = get_obs(infoset, self.model_type == "general", self.use_legacy,
                      self.lite_model, self.use_unified)

        if self.onnx_model is None:
            y_pred = self._predict_torch(obs)
        else:
            y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]

        if with_confidence:
            y_pred = y_pred.flatten()
            size = min(3, len(y_pred))
            # argpartition places the `size` largest values in the tail slice.
            best_action_index = np.argpartition(y_pred, -size)[-size:]
            best_action_confidence = y_pred[best_action_index]
            best_action = [infoset.legal_actions[index] for index in best_action_index]
            return best_action, best_action_confidence

        # y_pred holds one value row per legal action; take the best row.
        best_action_index = np.argmax(y_pred, axis=0)[0]
        return infoset.legal_actions[best_action_index]
|