2021-09-07 16:38:34 +08:00
import torch
import numpy as np
2021-12-21 10:46:24 +08:00
import os
import onnxruntime
from onnxruntime . datasets import get_example
2021-09-07 16:38:34 +08:00
from douzero . env . env import get_obs
2021-12-23 09:22:38 +08:00
def _load_model(position, model_path, model_type, use_legacy, use_lite):
    """Build the network for ``position``, load its weights, and make sure an
    ONNX export of it exists next to the checkpoint.

    Args:
        position: Key used to look up the model constructor in the selected
            ``douzero.dmc.models`` table.
        model_path: Path of the PyTorch checkpoint. The ONNX file is written
            to ``model_path + ".onnx"`` if that file does not exist yet.
        model_type: ``"general"`` selects the ``model_dict_new*`` tables; any
            other value selects the older tables.
        use_legacy: For non-general models, select ``model_dict_legacy``
            (takes precedence over ``use_lite``).
        use_lite: Select the ``*_lite`` table where one applies.

    Returns:
        The model in eval mode, moved to CUDA when CUDA is available.
    """
    # Imported lazily, as in the original, so importing this module does not
    # pull in the model definitions.
    from douzero.dmc.models import (
        model_dict,
        model_dict_legacy,
        model_dict_lite,
        model_dict_new,
        model_dict_new_lite,
    )

    # Choose the constructor table. Precedence: "general" first, then legacy,
    # then lite, then the default table.
    if model_type == "general":
        factories = model_dict_new_lite if use_lite else model_dict_new
    elif use_legacy:
        factories = model_dict_legacy
    elif use_lite:
        factories = model_dict_lite
    else:
        factories = model_dict
    model = factories[position]()

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # The weights are staged on the CPU: load_state_dict copies them into the
    # (CPU-resident) model anyway, and the model is moved to CUDA afterwards.
    checkpoint = torch.load(model_path, map_location="cpu")

    # Only keys the model actually has are taken from the checkpoint; extra
    # checkpoint entries are ignored and missing ones keep their initial
    # values.
    state = model.state_dict()
    state.update({key: value for key, value in checkpoint.items() if key in state})
    model.load_state_dict(state)
    model.eval()

    # Export while the model is still on the CPU so that the model and the
    # CPU dummy inputs agree on device (presumably get_onnx_params places its
    # tensors on the device it is given -- confirm against the model classes).
    # An existing .onnx file is reused as-is and never refreshed.
    onnx_path = model_path + ".onnx"
    if not os.path.exists(onnx_path):
        onnx_params = model.get_onnx_params(torch.device("cpu"))
        torch.onnx.export(
            model,
            onnx_params["args"],
            onnx_path,
            export_params=True,
            opset_version=10,
            do_constant_folding=True,
            input_names=onnx_params["input_names"],
            output_names=onnx_params["output_names"],
            dynamic_axes=onnx_params["dynamic_axes"],
        )

    if torch.cuda.is_available():
        model.cuda()
    return model
class DeepAgent:
    """Agent that scores every legal action with an ONNX model and plays the
    highest-scoring one.

    Attributes set by ``__init__``:
        use_legacy, lite_model, model_type: Model variant flags, inferred
            from substrings of the checkpoint path.
        model: The PyTorch model returned by ``_load_model``.
        onnx_model: ``onnxruntime.InferenceSession`` over the exported model,
            running on the CPU execution provider.
        EnvCard2RealCard: Mapping from integer card codes to one-character
            card names.
    """

    def __init__(self, position, model_path):
        """Load the model for ``position`` from ``model_path``.

        Raises:
            FileNotFoundError: If ``model_path + ".onnx"`` does not exist
                after ``_load_model`` has run.
        """
        # The variant is encoded in the checkpoint path itself.
        self.use_legacy = "legacy" in model_path
        self.lite_model = "lite" in model_path
        self.model_type = "general" if "resnet" in model_path else "old"

        # _load_model also writes model_path + ".onnx" when it is missing.
        self.model = _load_model(
            position, model_path, self.model_type, self.use_legacy, self.lite_model
        )

        # Open the exported file directly. The explicit check keeps the
        # FileNotFoundError that the previous get_example() lookup raised for
        # a missing file.
        onnx_path = os.path.abspath(model_path + ".onnx")
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX model not found: {onnx_path!r}")
        self.onnx_model = onnxruntime.InferenceSession(
            onnx_path, providers=["CPUExecutionProvider"]
        )

        self.EnvCard2RealCard = {
            3: "3", 4: "4", 5: "5", 6: "6", 7: "7",
            8: "8", 9: "9", 10: "T", 11: "J", 12: "Q",
            13: "K", 14: "A", 17: "2", 20: "X", 30: "D",
        }

    def act(self, infoset):
        """Return the action to play for ``infoset``.

        When exactly one action is legal it is returned without running the
        model. Otherwise the observation built by ``get_obs`` is fed to the
        ONNX session and the legal action at the arg-max of the first output
        (taken along axis 0, first column) is returned.
        """
        if len(infoset.legal_actions) == 1:
            return infoset.legal_actions[0]

        obs = get_obs(
            infoset, self.model_type == "general", self.use_legacy, self.lite_model
        )
        feeds = {"z_batch": obs["z_batch"], "x_batch": obs["x_batch"]}
        # run(None, ...) returns every model output; only the first is used.
        y_pred = self.onnx_model.run(None, feeds)[0]

        # Convert the NumPy integer to a plain int before indexing the list.
        best_action_index = int(np.argmax(y_pred, axis=0)[0])
        return infoset.legal_actions[best_action_index]