2021-05-29 00:25:02 +08:00
|
|
|
import os
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
from collections import Counter
|
|
|
|
|
|
|
|
# Column of each card code in the 13-column rank matrix that
# DeepAgent.cards2array builds. Codes 3-14 and 17 presumably denote the ranks
# 3..A and 2 (TODO confirm against the game engine). Codes 20 and 30 also get
# columns here, but cards2array writes them to a separate 4-slot `jokers`
# vector and never looks them up in this table.
Card2Column = dict(zip(
    (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, 20, 30),
    range(15),
))

# NumOnes2Array[n] is a length-8 column whose first n entries are 1 and the
# rest 0: a thermometer code for holding n copies of one rank (n = 0..8).
NumOnes2Array = {count: np.array([1] * count + [0] * (8 - count))
                 for count in range(9)}
|
|
|
|
|
|
|
|
|
|
|
|
def _get_one_hot_bomb(bomb_num, use_legacy=False):
    """Encode bomb counts as a one-hot feature vector.

    Args:
        bomb_num: indexable sequence of bomb counts. The legacy encoding reads
            entries 0 and 1; the current encoding reads entries 0, 1 and 2.
        use_legacy: if True, return a 29-slot vector whose single hot slot is
            ``bomb_num[0] + bomb_num[1]``.

    Returns:
        A float64 numpy array. Without ``use_legacy`` it is 56 slots made of
        three concatenated one-hot segments of widths 14, 15 and 27, one per
        entry of ``bomb_num``.

    NOTE(review): counts are not range-checked, so a value past its segment's
    width silently lands in the next segment.
    """
    if use_legacy:
        legacy = np.zeros(29)
        legacy[bomb_num[0] + bomb_num[1]] = 1
        return legacy

    # (segment start offset, index into bomb_num) for the 14 + 15 + 27 layout.
    segments = ((0, 0), (14, 1), (29, 2))
    encoded = np.zeros(56)
    for start, which in segments:
        encoded[start + bomb_num[which]] = 1
    return encoded
|
|
|
|
|
|
|
|
def _load_model(position, model_dir, use_onnx, z_shape=(40, 108), x_dim=80):
    """Load the trained model for one seat.

    If ``use_onnx`` is False, or ``<model_dir>/<position>.onnx`` does not
    exist yet, a PyTorch model is built from
    ``models.model_dict_new[position]`` and its weights are read from
    ``<model_dir>/<position>.ckpt``. Only checkpoint entries whose keys the
    model defines are copied; the rest are ignored. The model is moved to
    CUDA when available and put in eval mode. When ``use_onnx`` is True, it is
    then exported to ``<position>.onnx``.

    Args:
        position: key into ``model_dict_new`` and stem of the model files.
        model_dir: directory containing the ``.ckpt`` / ``.onnx`` files.
        use_onnx: if True, return an ``onnxruntime.InferenceSession`` over the
            ONNX file (exporting it first if missing) instead of the PyTorch
            model.
        z_shape: per-sample shape of the ``z`` input used to trace the ONNX
            export. The default matches the 40 x 108 rows of ``z_batch`` that
            ``DeepAgent.get_obs_general`` builds.
        x_dim: per-sample width of the ``x`` input used to trace the ONNX
            export. The default, 80, is the non-legacy ``x_batch`` width
            (56 bomb + 20 bid + 4 multiply, per the width comments in
            ``get_obs_general``; TODO confirm against the model definitions).

    Returns:
        The PyTorch module, or an ``onnxruntime.InferenceSession`` when
        ``use_onnx`` is True.
    """
    onnx_path = os.path.join(model_dir, position + '.onnx')
    use_cuda = torch.cuda.is_available()

    if not use_onnx or not os.path.isfile(onnx_path):
        from models import model_dict_new
        model = model_dict_new[position]()
        model_state_dict = model.state_dict()
        model_path = os.path.join(model_dir, position + '.ckpt')
        # torch.load unpickles the file: only point model_dir at trusted
        # checkpoints.
        pretrained = torch.load(model_path,
                                map_location='cuda:0' if use_cuda else 'cpu')
        # Partial load: keep only the entries this model actually defines.
        pretrained = {k: v for k, v in pretrained.items()
                      if k in model_state_dict}
        model_state_dict.update(pretrained)
        model.load_state_dict(model_state_dict)
        if use_cuda:
            model.cuda()
        model.eval()

        if use_onnx:
            # The dummy inputs fix only the traced graph's per-sample shapes;
            # the batch axis is exported as dynamic. They must sit on the same
            # device as the model, and they must match the z/x arrays that are
            # fed to the ONNX session at inference time.
            device = torch.device('cuda' if use_cuda else 'cpu')
            z = torch.randn(1, *z_shape, device=device)
            x = torch.randn(1, x_dim, device=device)
            torch.onnx.export(model,
                              (z, x),
                              onnx_path,
                              export_params=True,
                              opset_version=10,
                              do_constant_folding=True,
                              input_names=['z', 'x'],
                              output_names=['y'],
                              dynamic_axes={'z': {0: 'batch_size'},
                                            'x': {0: 'batch_size'},
                                            'y': {0: 'batch_size'}})

    if use_onnx:
        import onnxruntime
        # GPU builds of onnxruntime (>= 1.9) require an explicit provider list.
        model = onnxruntime.InferenceSession(
            onnx_path, providers=onnxruntime.get_available_providers())
    return model
|
|
|
|
|
2021-12-18 15:49:37 +08:00
|
|
|
def _process_action_seq(sequence, length=20, new_model=True):
    """Return a fixed-length window over the last ``length`` moves.

    The window holds the last ``length`` entries of ``sequence``. When
    ``new_model`` is True, the window is reversed so the final entry of
    ``sequence`` comes first. If fewer than ``length`` entries exist, empty
    lists are prepended. In both modes the padding goes at the front, so the
    result always has ``length`` entries.

    Args:
        sequence: list of moves (each presumably a list of card codes).
        length: number of moves to keep (callers in this file pass 32).
            NOTE(review): ``length=0`` keeps the whole sequence because
            ``sequence[-0:]`` is a full slice.
        new_model: reverse the window when True.

    Returns:
        A new list; ``sequence`` itself is not modified.
    """
    window = sequence[-length:].copy()
    if new_model:
        window = window[::-1]

    shortfall = length - len(window)
    if shortfall <= 0:
        return window

    padded = [[] for _ in range(shortfall)]
    padded.extend(window)
    return padded
|
|
|
|
|
|
|
|
class DeepAgent:
    """Choose moves for one seat by scoring every legal action with a model.

    ``self.model`` comes from ``_load_model``. It is a PyTorch module, or an
    ``onnxruntime.InferenceSession`` when ``use_onnx`` is True.
    """

    def __init__(self, position, model_dir, use_onnx=False):
        self.model = _load_model(position, model_dir, use_onnx)
        self.use_onnx = use_onnx

    def cards2array(self, list_cards):
        """Encode a collection of card codes as a 108-slot int8 vector.

        Layout: an 8 x 13 matrix flattened column-major, followed by 4 joker
        slots. In the matrix, each of the 13 rank columns from
        ``Card2Column`` holds the ``NumOnes2Array`` thermometer code for that
        rank's count. The joker slots are ``[20 present, 20 twice,
        30 present, 30 twice]``.
        Codes above 20 other than 30 are ignored.
        """
        if len(list_cards) == 0:
            return np.zeros(108, dtype=np.int8)

        matrix = np.zeros([8, 13], dtype=np.int8)
        jokers = np.zeros(4, dtype=np.int8)
        for card, num_times in Counter(list_cards).items():
            if card < 20:
                matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
            elif card == 20:
                jokers[0] = 1
                if num_times == 2:
                    jokers[1] = 1
            elif card == 30:
                jokers[2] = 1
                if num_times == 2:
                    jokers[3] = 1
        return np.concatenate((matrix.flatten('F'), jokers))

    def get_one_hot_array(self, num_left_cards, max_num_cards):
        """Return a float32 one-hot of length ``max_num_cards``.

        The hot slot is ``num_left_cards - 1``.

        NOTE(review): a count of 0 indexes slot -1, which sets the last slot
        (the same encoding as ``max_num_cards``). Callers presumably never
        pass 0; confirm.
        """
        one_hot = np.zeros(max_num_cards, dtype=np.float32)
        one_hot[num_left_cards - 1] = 1
        return one_hot

    def action_seq_list2array(self, action_seq_list, new_model=True):
        """Encode a sequence of moves as a matrix with one 108-slot row per move.

        Each non-empty move is encoded with ``cards2array``.

        * ``new_model=True``: returns a ``(len(action_seq_list), 108)`` float
          array. Empty moves keep the value -1 so they differ from real moves.
        * ``new_model=False`` (legacy): empty moves are all zeros, and the
          result is reshaped to ``(5, 432)``. Each row then concatenates four
          consecutive moves (432 = 4 x 108), so exactly 20 moves are required.
        """
        fill_value = -1.0 if new_model else 0.0
        action_seq_array = np.full((len(action_seq_list), 108), fill_value)
        for row, list_cards in enumerate(action_seq_list):
            if list_cards != []:
                action_seq_array[row, :] = self.cards2array(list_cards)
        if not new_model:
            action_seq_array = action_seq_array.reshape(5, 432)
        return action_seq_array

    def get_obs_general(self, infoset, position, use_legacy=False):
        """Build the model inputs for every legal action in ``infoset``.

        Returns a dict with:
            'position': the ``position`` argument, unchanged.
            'x_batch': float32 ``(num_legal_actions, F)``. Each row holds the
                bomb one-hot (56), bid info (20) and multiply info (4); with
                ``use_legacy`` the bomb features are omitted. The widths come
                from the original comments; TODO confirm against the infoset
                producer.
            'z_batch': float32 ``(num_legal_actions, 1 + R, 108)``, where R is
                the number of rows of 'z'. Row 0 is the candidate action and
                the remaining rows are a copy of 'z'.
            'legal_actions': ``infoset.legal_actions``.
            'x_no_action': int8 'x' features without the batch axis.
            'z': int8 ``(R, 108)`` context with R = 7 + 32 = 39: cards-left
                one-hots, own hand, other hands, the four seats' played
                cards, then the last 32 moves.
        """
        legal_actions = infoset.legal_actions
        num_legal_actions = len(legal_actions)

        def batched(row):
            # Repeat a 1-D feature row once per legal action.
            return np.repeat(row[np.newaxis, :], num_legal_actions, axis=0)

        my_handcards = self.cards2array(infoset.player_hand_cards)
        other_handcards = self.cards2array(infoset.other_hand_cards)
        bid_info = np.array(infoset.bid_info).flatten()
        multiply_info = np.array(infoset.multiply_info)
        bomb_num = _get_one_hot_bomb(infoset.bomb_num, use_legacy)

        # One 108-slot card encoding per legal action.
        my_action_batch = np.zeros((num_legal_actions, 108))
        for j, action in enumerate(legal_actions):
            my_action_batch[j, :] = self.cards2array(action)

        cards_left = infoset.num_cards_left_dict
        num_cards_left = np.hstack((
            self.get_one_hot_array(cards_left['landlord'], 33),
            self.get_one_hot_array(cards_left['landlord_up'], 25),
            self.get_one_hot_array(cards_left['landlord_front'], 25),
            self.get_one_hot_array(cards_left['landlord_down'], 25),
        ))  # 33 + 25 + 25 + 25 = 108

        if use_legacy:
            x_batch = np.hstack((batched(bid_info),        # 20
                                 batched(multiply_info)))  # 4
            x_no_action = np.hstack((bid_info, multiply_info))
        else:
            x_batch = np.hstack((batched(bomb_num),        # 56
                                 batched(bid_info),        # 20
                                 batched(multiply_info)))  # 4
            x_no_action = np.hstack((bomb_num, bid_info, multiply_info))

        played = infoset.played_cards
        z = np.vstack((
            num_cards_left,                    # 108
            my_handcards,                      # 108
            other_handcards,                   # 108
            self.cards2array(played[0]),       # landlord
            self.cards2array(played[1]),       # landlord_up
            self.cards2array(played[2]),       # landlord_front
            self.cards2array(played[3]),       # landlord_down
            self.action_seq_list2array(
                _process_action_seq(infoset.card_play_action_seq, 32)),
        ))

        # Prepend each candidate action as row 0 of its own copy of z. Every
        # entry is -1, 0 or 1, so a single float concatenation equals the
        # previous per-row integer vstack.
        z_batch = np.concatenate(
            (my_action_batch[:, np.newaxis, :],
             np.broadcast_to(z, (num_legal_actions,) + z.shape)),
            axis=1)

        return {
            'position': position,
            'x_batch': x_batch.astype(np.float32),
            'z_batch': z_batch.astype(np.float32),
            'legal_actions': legal_actions,
            'x_no_action': x_no_action.astype(np.int8),
            'z': z.astype(np.int8),
        }

    def act(self, infoset):
        """Score every legal action and return up to three of the best.

        Returns:
            ``(best_action, best_action_confidence)``. The first item is a
            list of up to three entries of ``infoset.legal_actions``, ordered
            from highest to lowest model score (ties keep legal-action order).
            The second is a numpy array of the matching scores.
        """
        obs = self.get_obs_general(infoset, infoset.player_position)
        z_batch = obs['z_batch']
        x_batch = obs['x_batch']

        if self.use_onnx:
            y_pred = self.model.run(None, {'z': z_batch, 'x': x_batch})[0]
        else:
            z_tensor = torch.from_numpy(z_batch).float()
            x_tensor = torch.from_numpy(x_batch).float()
            if torch.cuda.is_available():
                z_tensor = z_tensor.cuda()
                x_tensor = x_tensor.cuda()
            # Inference only: do not build an autograd graph.
            with torch.no_grad():
                y_pred = self.model(z_tensor, x_tensor)
            # The model may return the score tensor directly or a dict with a
            # 'values' entry; accept both on every device.
            if isinstance(y_pred, dict):
                y_pred = y_pred['values']
            y_pred = y_pred.detach().cpu().numpy()

        y_pred = np.asarray(y_pred).flatten()

        # Highest score first. A stable sort keeps the earlier legal action on
        # ties, unlike np.argpartition, which leaves the top-k unordered.
        best_action_index = np.argsort(-y_pred, kind='stable')[:3]
        best_action_confidence = y_pred[best_action_index]
        best_action = [legal for legal in
                       (infoset.legal_actions[index] for index in best_action_index)]
        return best_action, best_action_confidence
|