Douzero_Resnet/douzero/evaluation/deep_agent.py

70 lines
2.8 KiB
Python
Raw Normal View History

2021-09-07 16:38:34 +08:00
import torch
import numpy as np
from douzero.env.env import get_obs
2021-12-12 14:01:40 +08:00
def _load_model(position, model_path, model_type, use_legacy):
    """Build the network for one position and load its pretrained weights.

    Args:
        position: Key into the ``douzero.dmc.models`` model dictionaries; it
            selects which network class is instantiated.
        model_path: Checkpoint file passed to ``torch.load``. It is expected
            to hold a mapping of parameter names to tensors (a state dict).
        model_type: ``"general"`` selects ``model_dict_new`` /
            ``model_dict_new_legacy``; any other value selects ``model_dict`` /
            ``model_dict_legacy``.
        use_legacy: When true, use the ``*_legacy`` dictionary of the chosen
            family.

    Returns:
        The model with the matching checkpoint tensors loaded, moved to CUDA
        when it is available, and switched to eval mode.

    Checkpoint entries whose names the model does not define are ignored.
    Model entries absent from the checkpoint keep their freshly initialized
    values. A ``UserWarning`` lists them, because an agent running on random
    weights most likely means the architecture chosen from
    ``model_type``/``use_legacy`` does not match the checkpoint.
    """
    import warnings

    from douzero.dmc.models import (model_dict, model_dict_legacy,
                                    model_dict_new, model_dict_new_legacy)

    if model_type == "general":
        registry = model_dict_new_legacy if use_legacy else model_dict_new
    else:
        registry = model_dict_legacy if use_legacy else model_dict
    model = registry[position]()

    use_cuda = torch.cuda.is_available()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    pretrained = torch.load(model_path, map_location='cuda:0' if use_cuda else 'cpu')

    # Keep only tensors this architecture defines. With strict=False, any
    # parameter absent from `matched` keeps its current value, which is the
    # same outcome as updating a copy of state_dict() and loading it strictly.
    model_keys = model.state_dict().keys()
    matched = {k: v for k, v in pretrained.items() if k in model_keys}
    missing = model.load_state_dict(matched, strict=False).missing_keys
    if missing:
        warnings.warn(
            f"{len(missing)} parameter(s) of {type(model).__name__} were not "
            f"found in {model_path!r} and keep their initial values: {missing}",
            stacklevel=2,
        )

    if use_cuda:
        model.cuda()
    model.eval()
    return model
class DeepAgent:
    """Agent that plays the legal action its value network scores highest.

    The network family is inferred from ``model_path``. A path containing
    ``"resnet"`` uses the "general" family (any other path uses "old"), and a
    path containing ``"legacy"`` selects the legacy variant of that family.
    """

    def __init__(self, position, model_path):
        """Load the model for ``position`` from the checkpoint at ``model_path``."""
        self.use_legacy = "legacy" in model_path
        self.model_type = "general" if "resnet" in model_path else "old"
        self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
        # Integer card codes used by the environment -> one-character card
        # names ('T' is ten; 'X'/'D' are presumably the two jokers, TODO confirm).
        self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                                 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                                 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}

    def act(self, infoset):
        """Return the element of ``infoset.legal_actions`` with the best value.

        When there is exactly one legal action, it is returned without
        building observations or running the model.
        """
        if len(infoset.legal_actions) == 1:
            return infoset.legal_actions[0]

        obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
        return infoset.legal_actions[self._best_action_index(obs)]

    def _best_action_index(self, obs):
        """Run the model on ``obs`` and return the index of the top-scored row.

        ``obs`` must provide numpy arrays under ``'z_batch'`` and
        ``'x_batch'``, with one row per legal action (the returned index is
        used to pick from ``legal_actions``). The model's ``'values'`` output
        is assumed to be shaped (num_actions, 1); TODO confirm against
        douzero.dmc.models.
        """
        z_batch = torch.from_numpy(obs['z_batch']).float()
        x_batch = torch.from_numpy(obs['x_batch']).float()
        if torch.cuda.is_available():
            z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
        # Pure inference: without no_grad, autograd would record a graph for
        # every move only for it to be discarded.
        with torch.no_grad():
            y_pred = self.model.forward(z_batch, x_batch)['values']
        y_pred = y_pred.detach().cpu().numpy()
        return int(np.argmax(y_pred, axis=0)[0])