45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
|
||
|
from douzero.env.env import get_obs
|
||
|
|
||
|
def _load_model(position, model_path):
    """Build the network for ``position`` and fill it with weights from ``model_path``.

    Only checkpoint entries whose names also appear in the freshly constructed
    model are copied in; any model parameter missing from the checkpoint keeps
    its initial value, and extra checkpoint entries are ignored.

    Args:
        position: key into ``douzero.dmc.models.model_dict`` selecting the
            network class to instantiate.
        model_path: path (or file-like object) accepted by ``torch.load``.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available.
    """
    from douzero.dmc.models import model_dict

    net = model_dict[position]()
    merged_state = net.state_dict()

    use_gpu = torch.cuda.is_available()
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cuda:0' if use_gpu else 'cpu')

    # Partial load: keep just the tensors the current architecture knows about.
    merged_state.update(
        {name: tensor for name, tensor in checkpoint.items() if name in merged_state}
    )
    net.load_state_dict(merged_state)

    if use_gpu:
        net.cuda()
    net.eval()
    return net
|
||
|
|
||
|
class DeepAgent:
    """Agent that scores every legal action with a trained network and picks the best."""

    def __init__(self, position, model_path):
        """Load the network for ``position`` from the checkpoint at ``model_path``."""
        self.model = _load_model(position, model_path)

    def act(self, infoset):
        """Return the legal action the model rates highest, plus its model output.

        Args:
            infoset: the acting player's information set; it must expose
                ``legal_actions`` and be accepted by ``get_obs``.

        Returns:
            ``(best_action, best_action_confidence)``. ``best_action`` is an
            element of ``infoset.legal_actions``. ``best_action_confidence`` is
            the row of the model's ``'values'`` output for that action, as a
            NumPy array (presumably shape ``(1,)`` — TODO confirm against the
            model's output shape).
        """
        # There is deliberately no shortcut for a single legal action. The model
        # always runs so that a confidence (win-rate) value is returned too.
        obs = get_obs(infoset)

        z_batch = torch.from_numpy(obs['z_batch']).float()
        x_batch = torch.from_numpy(obs['x_batch']).float()
        if torch.cuda.is_available():
            z_batch, x_batch = z_batch.cuda(), x_batch.cuda()

        # Pure inference: disable autograd so no computation graph is built.
        # model.eval() alone does not turn gradient tracking off.
        with torch.no_grad():
            y_pred = self.model.forward(z_batch, x_batch, return_value=True)['values']
        y_pred = y_pred.detach().cpu().numpy()

        # argmax over the action axis. Presumably y_pred is (num_actions, 1), so
        # [0] extracts the index from the single output column — TODO confirm.
        best_action_index = np.argmax(y_pred, axis=0)[0]
        best_action = infoset.legal_actions[best_action_index]
        best_action_confidence = y_pred[best_action_index]
        return best_action, best_action_confidence
|