2021-09-07 16:38:34 +08:00
import torch
import numpy as np
2021-12-21 10:46:24 +08:00
import os
import onnxruntime
from onnxruntime . datasets import get_example
2021-09-07 16:38:34 +08:00
from douzero . env . env import get_obs
2021-12-26 20:11:05 +08:00
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
    """Build the network for ``position`` and load its weights from ``model_path``.

    Args:
        position: Key into the model registries of ``douzero.dmc.models``.
        model_path: Checkpoint file readable by ``torch.load``. The ONNX
            export, when used, lives at ``model_path + '.onnx'``.
        model_type: ``"general"`` selects the ``model_dict_new*`` registries;
            any other value selects the ``model_dict*`` registries.
        use_legacy: For non-general models, use ``model_dict_legacy``.
        use_lite: Use the ``*_lite`` registry variant (ignored when a legacy
            non-general model is requested).
        use_onnx: Do not return a PyTorch model; instead make sure the ONNX
            export exists, exporting it from the checkpoint if it is missing.

    Returns:
        The loaded ``torch.nn.Module`` in eval mode (moved to CUDA when it is
        available), or ``None`` when ``use_onnx`` is set -- in that case the
        caller is expected to run inference on the ``.onnx`` file instead.
    """
    model_path_onnx = model_path + '.onnx'
    if use_onnx and os.path.exists(model_path_onnx):
        # Export already present: nothing to build.
        return None

    # Imported here (not at module level) and only after the early return, so
    # runs that reuse an existing ONNX export never need the model definitions.
    from douzero.dmc.models import (model_dict, model_dict_legacy, model_dict_lite,
                                    model_dict_new, model_dict_new_lite)

    if model_type == "general":
        registry = model_dict_new_lite if use_lite else model_dict_new
    elif use_legacy:
        registry = model_dict_legacy
    else:
        registry = model_dict_lite if use_lite else model_dict
    model = registry[position]()

    # Load tensors onto the CPU first; the ONNX export below is also done on
    # the CPU. NOTE(review): torch.load unpickles the file -- only load
    # checkpoints from trusted sources.
    pretrained = torch.load(model_path, map_location='cpu')
    model_state_dict = model.state_dict()
    # Keep only keys the current architecture knows; any key absent from the
    # checkpoint keeps its freshly initialised value.
    model_state_dict.update({k: v for k, v in pretrained.items() if k in model_state_dict})
    model.load_state_dict(model_state_dict)
    model.eval()

    if use_onnx:
        # Reaching here means the export does not exist yet (see early return).
        onnx_params = model.get_onnx_params(torch.device('cpu'))
        torch.onnx.export(
            model,
            onnx_params['args'],
            model_path_onnx,
            export_params=True,
            opset_version=10,
            do_constant_folding=True,
            input_names=onnx_params['input_names'],
            output_names=onnx_params['output_names'],
            dynamic_axes=onnx_params['dynamic_axes'],
        )
        return None

    # Inference inputs are placed on the GPU when one is present, so keep the
    # weights there too; otherwise forward() fails with a device mismatch.
    if torch.cuda.is_available():
        model = model.cuda()
    return model
class DeepAgent:
    """Doudizhu agent that scores every legal action with a trained network.

    The model flavour is inferred from substrings of ``model_path``:
    ``"legacy"`` sets ``use_legacy``, ``"lite"`` sets ``lite_model``, and
    ``"resnet"`` selects the ``"general"`` model type (``"old"`` otherwise).

    With ``use_onnx=True`` no PyTorch model is kept (``self.model`` is
    ``None``). Inference then goes through an ONNX Runtime session on
    ``model_path + '.onnx'``, which ``_load_model`` exports if it is missing.
    """

    def __init__(self, position, model_path, use_onnx=False):
        self.use_legacy = "legacy" in model_path
        self.lite_model = "lite" in model_path
        self.model_type = "general" if "resnet" in model_path else "old"
        self.model = _load_model(position, model_path, self.model_type,
                                 self.use_legacy, self.lite_model,
                                 use_onnx=use_onnx)
        self.onnx_model_path = os.path.abspath(model_path + '.onnx')
        self.use_onnx = use_onnx
        # ONNX Runtime session; created lazily on first inference.
        self.onnx_model = None
        # Environment integer card codes -> single-character card names.
        self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                                 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                                 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}

    def _ensure_onnx_session(self):
        """Create the ONNX Runtime session on first use and return it.

        Raises:
            FileNotFoundError: if the exported ``.onnx`` file does not exist.
        """
        if self.onnx_model is None:
            if not os.path.exists(self.onnx_model_path):
                raise FileNotFoundError(
                    "ONNX model not found: {}".format(self.onnx_model_path))
            providers = ['CPUExecutionProvider']
            if torch.cuda.is_available():
                # Providers are tried in order; CPU stays listed as a fallback.
                providers.insert(0, 'CUDAExecutionProvider')
            self.onnx_model = onnxruntime.InferenceSession(
                self.onnx_model_path, providers=providers)
        return self.onnx_model

    def _predict(self, obs):
        """Run the network on ``obs`` and return the scores as a flat numpy array."""
        session = self._ensure_onnx_session() if self.use_onnx else self.onnx_model
        if session is not None:
            # Cast like the PyTorch path's .float(); the exported graph
            # presumably takes float32 inputs -- confirm if the export changes.
            feeds = {name: np.asarray(obs[name], dtype=np.float32)
                     for name in ('z_batch', 'x_batch')}
            return np.asarray(session.run(None, feeds)[0]).reshape(-1)

        # Put inputs wherever the weights live instead of guessing from
        # torch.cuda.is_available(), which may disagree with the model's device.
        device = next(self.model.parameters()).device
        with torch.no_grad():
            z_batch = torch.from_numpy(obs['z_batch']).float().to(device)
            x_batch = torch.from_numpy(obs['x_batch']).float().to(device)
            y_pred = self.model.forward(z_batch, x_batch)['values']
        return y_pred.cpu().numpy().reshape(-1)

    @staticmethod
    def _select_actions(y_pred, legal_actions, with_confidence):
        """Choose action(s) from per-action scores.

        Args:
            y_pred: Scores, any shape that flattens to one score per entry of
                ``legal_actions``.
            legal_actions: Candidate actions, aligned with the scores.
            with_confidence: Return the top (up to) three instead of the best.

        Returns:
            The highest-scoring action (first one on ties), or, with
            ``with_confidence``, ``(actions, scores)`` for up to three actions
            ordered best-first.
        """
        scores = np.asarray(y_pred).reshape(-1)
        if not with_confidence:
            return legal_actions[int(np.argmax(scores))]
        size = min(3, len(scores))
        # argpartition only guarantees *which* indices are the top `size`,
        # not their order, so sort that small slice by descending score.
        top = np.argpartition(scores, -size)[-size:]
        top = top[np.argsort(-scores[top], kind='stable')]
        return [legal_actions[i] for i in top], scores[top]

    def act(self, infoset, with_confidence=False):
        """Choose what to play for ``infoset``.

        Args:
            infoset: Game state; must expose ``legal_actions`` and whatever
                ``get_obs`` reads.
            with_confidence: When True, return up to three candidate actions
                and their predicted values instead of a single action.

        Returns:
            The best legal action, or a tuple ``(actions, values)`` of up to
            three actions ordered best-first with their values (numpy array).
        """
        if not with_confidence and len(infoset.legal_actions) == 1:
            # Only one option: skip feature extraction and inference.
            return infoset.legal_actions[0]
        obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
        y_pred = self._predict(obs)
        return self._select_actions(y_pred, infoset.legal_actions, with_confidence)