# 2021-09-07 16:38:34 +08:00  (stray revision timestamp, kept as a comment so the module parses)
import torch
import numpy as np
from douzero . env . env import get_obs
# 2021-12-12 14:01:40 +08:00  (stray revision timestamp, kept as a comment so the module parses)
def _load_model(position, model_path, model_type, use_legacy):
    """Build the network for ``position`` and load its weights from ``model_path``.

    Args:
        position: Key used to look up the model class in the selected
            registry imported from ``douzero.dmc.models``.
        model_path: Checkpoint path handed to ``torch.load``.
        model_type: ``"general"`` selects the ``model_dict_new*`` registries;
            any other value selects the ``model_dict*`` registries.
        use_legacy: When truthy, the ``*_legacy`` variant of the registry
            is used.

    Returns:
        The instantiated model, switched to eval mode and moved to CUDA when
        a CUDA device is available.

    Checkpoint entries whose keys are absent from the model are dropped, and
    model entries absent from the checkpoint keep their initial values.
    """
    from douzero.dmc.models import (
        model_dict,
        model_dict_legacy,
        model_dict_new,
        model_dict_new_legacy,
    )

    # Tolerate surrounding whitespace so a padded tag still selects the
    # "general" registries instead of silently falling through.
    is_general = isinstance(model_type, str) and model_type.strip() == "general"
    registries = {
        (True, True): model_dict_new_legacy,
        (True, False): model_dict_new,
        (False, True): model_dict_legacy,
        (False, False): model_dict,
    }
    model = registries[(is_general, bool(use_legacy))][position]()

    use_cuda = torch.cuda.is_available()
    # map_location must be an exact device tag ('cuda:0' / 'cpu'); a padded
    # string is not recognised when the stored tensors are restored.
    checkpoint = torch.load(
        model_path, map_location='cuda:0' if use_cuda else 'cpu'
    )

    # Merge only the checkpoint entries the model knows about into the
    # model's own state, then load the merged result back.
    state = model.state_dict()
    state.update({k: v for k, v in checkpoint.items() if k in state})
    model.load_state_dict(state)

    if use_cuda:
        model.cuda()
    model.eval()
    return model
class DeepAgent:
    """Agent that plays the legal action with the highest predicted value.

    The checkpoint path decides how the model is built: a path containing
    ``"legacy"`` enables the legacy variant, and a path containing
    ``"resnet"`` selects the ``"general"`` model type (otherwise ``"old"``).
    """

    def __init__(self, position, model_path):
        """Load the model for ``position`` from ``model_path``.

        Args:
            position: Passed through to ``_load_model`` to pick the model.
            model_path: Checkpoint path; also inspected for the ``"legacy"``
                and ``"resnet"`` markers described on the class.
        """
        self.use_legacy = "legacy" in model_path
        self.model_type = "general" if "resnet" in model_path else "old"
        self.model = _load_model(
            position, model_path, self.model_type, self.use_legacy
        )
        # Environment card codes -> one-character labels. Not read by act();
        # kept as a public attribute for printing actions when debugging.
        self.EnvCard2RealCard = {
            3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
            8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
            13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D',
        }

    def act(self, infoset):
        """Return the action to play for ``infoset``.

        Args:
            infoset: Object exposing ``legal_actions``; it is also handed to
                ``get_obs`` to build the model inputs.

        Returns:
            The sole legal action when there is exactly one (no model call);
            otherwise the entry of ``infoset.legal_actions`` whose predicted
            value is largest.
        """
        if len(infoset.legal_actions) == 1:
            return infoset.legal_actions[0]

        obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
        z_batch = torch.from_numpy(obs['z_batch']).float()
        x_batch = torch.from_numpy(obs['x_batch']).float()
        if torch.cuda.is_available():
            z_batch, x_batch = z_batch.cuda(), x_batch.cuda()

        # Pure inference: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            y_pred = self.model.forward(z_batch, x_batch)['values']
        y_pred = y_pred.detach().cpu().numpy()

        # argmax along axis 0, first column -- assumes one row of values per
        # legal action (TODO confirm against the model's output shape).
        best_action_index = np.argmax(y_pred, axis=0)[0]
        return infoset.legal_actions[best_action_index]