新增压缩版模型

This commit is contained in:
ZaneYork 2021-12-22 21:19:10 +08:00
parent d5401cefc9
commit b5982b7195
6 changed files with 343 additions and 159 deletions

View File

@ -29,6 +29,8 @@ parser.add_argument('--load_model', action='store_true',
help='Load an existing model') help='Load an existing model')
parser.add_argument('--old_model', action='store_true', parser.add_argument('--old_model', action='store_true',
help='Use vanilla model') help='Use vanilla model')
parser.add_argument('--lite_model', action='store_true',
help='Use lite card model')
parser.add_argument('--lagacy_model', action='store_true', parser.add_argument('--lagacy_model', action='store_true',
help='Use lagacy bomb model') help='Use lagacy bomb model')
parser.add_argument('--disable_checkpoint', action='store_true', parser.add_argument('--disable_checkpoint', action='store_true',

View File

@ -87,7 +87,7 @@ def train(flags):
if flags.actor_device_cpu: if flags.actor_device_cpu:
device_iterator = ['cpu'] device_iterator = ['cpu']
else: else:
device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu'] device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices' assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
# Initialize actor models # Initialize actor models
@ -96,7 +96,7 @@ def train(flags):
if flags.old_model: if flags.old_model:
model = OldModel(device="cpu", flags = flags) model = OldModel(device="cpu", flags = flags)
else: else:
model = Model(device="cpu", flags = flags) model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
model.share_memory() model.share_memory()
model.eval() model.eval()
models[device] = model models[device] = model
@ -111,7 +111,7 @@ def train(flags):
if flags.old_model: if flags.old_model:
learner_model = OldModel(device=flags.training_device) learner_model = OldModel(device=flags.training_device)
else: else:
learner_model = Model(device=flags.training_device) learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)
# Create optimizers # Create optimizers
optimizers = create_optimizers(flags, learner_model) optimizers = create_optimizers(flags, learner_model)

View File

@ -240,6 +240,69 @@ class BasicBlock(nn.Module):
out = F.relu(out) out = F.relu(out)
return out return out
class GeneralModelLite(nn.Module):
    """
    Lighter 1-D convolutional value network.

    ``forward(z, x)`` runs ``z`` through a strided convolution and three
    stacks of ``BasicBlock``s, flattens the result, prepends ``x`` and
    feeds the concatenation through four linear layers ending in a single
    output per sample.

    The layer sizes are fixed for ``z`` of shape ``(batch, 40, 69)`` and
    ``x`` of shape ``(batch, 56)`` (the shapes used by
    ``get_onnx_params``).
    """

    def __init__(self):
        super().__init__()
        # Channel count fed to the next block; advanced by _make_layer.
        self.in_planes = 80
        # Input z: (batch, 40, 69).
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)  # -> (batch, 80, 35)
        self.bn1 = nn.BatchNorm1d(80)
        # Lengths below follow the original author's notes and the "* 5"
        # in linear1; they assume each stride-2 BasicBlock stack halves
        # (rounding up) the sequence length -- TODO confirm against
        # BasicBlock.
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # -> 80 channels, length 18
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # -> 160 channels, length 9
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # -> 320 channels, length 5
        # Flattened conv features (320 * expansion * 5) plus the 56 entries of x.
        self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
        self.linear2 = nn.Linear(1536, 768)
        self.linear3 = nn.Linear(768, 384)
        self.linear4 = nn.Linear(384, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Stack ``num_blocks`` instances of ``block`` into one Sequential.

        Only the first block uses ``stride``; the remaining ones use
        stride 1. ``self.in_planes`` is updated after every block so the
        next block (and the next call) sees the right input width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def get_onnx_params(self, device=None):
        """
        Return the keyword arguments describing an ONNX export.

        The dict holds dummy ``z``/``x`` inputs with batch size 1 created
        on ``device``, the input/output names, and ``dynamic_axes``
        marking dimension 0 of every tensor as ``"batch_size"``.
        """
        batch_axis = {0: "batch_size"}
        return {
            'args': (
                torch.randn(1, 40, 69, requires_grad=True, device=device),
                torch.randn(1, 56, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': dict(batch_axis),
                'x_batch': dict(batch_axis),
                'values': dict(batch_axis),
            },
        }

    def forward(self, z, x):
        """
        Compute one value per sample.

        Returns a dict with the single key ``values`` holding the output
        of the last linear layer (one column per sample) after a leaky
        ReLU.
        """
        out = F.relu(self.bn1(self.conv1(z)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Merge the channel and length dimensions into one feature axis.
        out = out.flatten(1, 2)
        # x goes first, then the convolutional features.
        out = torch.cat([x, out], dim=-1)
        out = F.leaky_relu_(self.linear1(out))
        out = F.leaky_relu_(self.linear2(out))
        out = F.leaky_relu_(self.linear3(out))
        out = F.leaky_relu_(self.linear4(out))
        return dict(values=out)
class GeneralModel(nn.Module): class GeneralModel(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -414,7 +477,7 @@ class Model:
The wrapper for the three models. We also wrap several The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc. interfaces such as share_memory, eval, etc.
""" """
def __init__(self, device=0, flags=None): def __init__(self, device=0, flags=None, lite_model = False):
self.models = {} self.models = {}
self.onnx_models = {} self.onnx_models = {}
self.flags = flags self.flags = flags
@ -428,6 +491,9 @@ class Model:
self.models[position] = None self.models[position] = None
else: else:
for position in positions: for position in positions:
if lite_model:
self.models[position] = GeneralModelLite().to(self.device)
else:
self.models[position] = GeneralModel().to(self.device) self.models[position] = GeneralModel().to(self.device)
self.onnx_models = { self.onnx_models = {
'landlord': None, 'landlord': None,

View File

@ -36,6 +36,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]), 13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])} 15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
# Compressed 5-slot column for a card count of 0-8: counts 1-4 fill that
# many of the first four slots; counts 5-8 set the fifth slot and fill
# (count - 4) of the first four.
NumOnes2ArrayCompressed = {
    copies: np.array(column)
    for copies, column in (
        (0, [0, 0, 0, 0, 0]),
        (1, [1, 0, 0, 0, 0]),
        (2, [1, 1, 0, 0, 0]),
        (3, [1, 1, 1, 0, 0]),
        (4, [1, 1, 1, 1, 0]),
        (5, [1, 0, 0, 0, 1]),
        (6, [1, 1, 0, 0, 1]),
        (7, [1, 1, 1, 0, 1]),
        (8, [1, 1, 1, 1, 1]),
    )
}

# Compressed 5-slot joker column keyed by the 4-bit mask the compressed
# card encoders build (0b01/0b11 for card 20, 0b0100/0b1100 for card 30).
# The first four slots are the mask bits, lowest bit first; the fifth
# slot is always 0.
NumOnesJoker2ArrayCompressed = {
    mask: np.array(column)
    for mask, column in (
        (0, [0, 0, 0, 0, 0]),
        (1, [1, 0, 0, 0, 0]),
        (3, [1, 1, 0, 0, 0]),
        (4, [0, 0, 1, 0, 0]),
        (5, [1, 0, 1, 0, 0]),
        (7, [1, 1, 1, 0, 0]),
        (12, [0, 0, 1, 1, 0]),
        (13, [1, 0, 1, 1, 0]),
        (15, [1, 1, 1, 1, 0]),
    )
}
shandle = logging.StreamHandler() shandle = logging.StreamHandler()
shandle.setFormatter( shandle.setFormatter(
logging.Formatter( logging.Formatter(
@ -51,7 +71,7 @@ log.setLevel(logging.INFO)
Buffers = typing.Dict[str, typing.List[torch.Tensor]] Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags):
    """
    Build the game environment from the parsed command-line flags
    (objective, old/legacy model switches and the lite-model switch).
    """
    env_args = (
        flags.objective,
        flags.old_model,
        flags.lagacy_model,
        flags.lite_model,
    )
    return Env(*env_args)
def get_batch(b_queues, position, flags, lock): def get_batch(b_queues, position, flags, lock):
""" """
@ -133,11 +153,11 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
else: else:
action = obs['legal_actions'][0] action = obs['legal_actions'][0]
if flags.old_model: if flags.old_model:
obs_action_buf[position].append(_cards2tensor(action)) obs_action_buf[position].append(_cards2tensor(action, flags.lite_model))
obs_x_no_action[position].append(env_output['obs_x_no_action']) obs_x_no_action[position].append(env_output['obs_x_no_action'])
obs_z_buf[position].append(env_output['obs_z']) obs_z_buf[position].append(env_output['obs_z'])
else: else:
obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float()) obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float())
# x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float() # x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
x_batch = env_output['obs_x_no_action'].float() x_batch = env_output['obs_x_no_action'].float()
obs_x_batch_buf[position].append(x_batch) obs_x_batch_buf[position].append(x_batch)
@ -196,12 +216,37 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
print() print()
raise e raise e
def _cards2tensor(list_cards): def _cards2tensor(list_cards, compress_form = False):
""" """
Convert a list of integers to the tensor Convert a list of integers to the tensor
representation representation
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
""" """
if compress_form:
if len(list_cards) == 0:
return torch.zeros(69, dtype=torch.int8)
matrix = np.zeros([5, 14], dtype=np.int8)
counter = Counter(list_cards)
joker_cnt = 0
for card, num_times in counter.items():
if card < 20:
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
elif card == 20:
if num_times == 2:
joker_cnt |= 0b11
else:
joker_cnt |= 0b01
elif card == 30:
if num_times == 2:
joker_cnt |= 0b1100
else:
joker_cnt |= 0b0100
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
matrix = matrix.flatten('F')[:-1]
matrix = torch.from_numpy(matrix)
return matrix
else:
if len(list_cards) == 0: if len(list_cards) == 0:
return torch.zeros(108, dtype=torch.int8) return torch.zeros(108, dtype=torch.int8)

262
douzero/env/env.py vendored
View File

@ -30,6 +30,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]), 13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])} 15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
# Count of one rank (0-8) -> compressed 5-slot 0/1 column. The first four
# slots hold a unary count; the fifth slot marks counts above four.
NumOnes2ArrayCompressed = {
    count: np.array(slots)
    for count, slots in {
        0: [0, 0, 0, 0, 0],
        1: [1, 0, 0, 0, 0],
        2: [1, 1, 0, 0, 0],
        3: [1, 1, 1, 0, 0],
        4: [1, 1, 1, 1, 0],
        5: [1, 0, 0, 0, 1],
        6: [1, 1, 0, 0, 1],
        7: [1, 1, 1, 0, 1],
        8: [1, 1, 1, 1, 1],
    }.items()
}

# Joker bitmask (as assembled by the compressed branch of _cards2array)
# -> compressed 5-slot column; slot i (i < 4) mirrors bit i of the mask
# and the last slot stays 0.
NumOnesJoker2ArrayCompressed = {
    joker_bits: np.array(slots)
    for joker_bits, slots in {
        0: [0, 0, 0, 0, 0],
        1: [1, 0, 0, 0, 0],
        3: [1, 1, 0, 0, 0],
        4: [0, 0, 1, 0, 0],
        5: [1, 0, 1, 0, 0],
        7: [1, 1, 1, 0, 0],
        12: [0, 0, 1, 1, 0],
        13: [1, 0, 1, 1, 0],
        15: [1, 1, 1, 1, 0],
    }.items()
}
deck = [] deck = []
for i in range(3, 15): for i in range(3, 15):
@ -43,7 +63,7 @@ class Env:
Doudizhu multi-agent wrapper Doudizhu multi-agent wrapper
""" """
def __init__(self, objective, old_model, legacy_model=False): def __init__(self, objective, old_model, legacy_model=False, lite_model = False):
""" """
Objective is wp/adp/logadp. It indicates whether considers Objective is wp/adp/logadp. It indicates whether considers
bomb in reward calculation. Here, we use dummy agents. bomb in reward calculation. Here, we use dummy agents.
@ -58,6 +78,7 @@ class Env:
self.objective = objective self.objective = objective
self.use_legacy = legacy_model self.use_legacy = legacy_model
self.use_general = not old_model self.use_general = not old_model
self.lite_model = lite_model
# Initialize players # Initialize players
# We use three dummy player for the target position # We use three dummy player for the target position
@ -92,7 +113,7 @@ class Env:
card_play_data[key].sort() card_play_data[key].sort()
self._env.card_play_init(card_play_data) self._env.card_play_init(card_play_data)
self.infoset = self._game_infoset self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general) return get_obs(self.infoset, self.use_general, self.lite_model)
else: else:
self.total_round += 1 self.total_round += 1
_deck = deck.copy() _deck = deck.copy()
@ -118,7 +139,7 @@ class Env:
self._env.info_sets[pos].player_id = pid self._env.info_sets[pos].player_id = pid
self.infoset = self._game_infoset self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general, self.use_legacy) return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
def step(self, action): def step(self, action):
""" """
@ -146,7 +167,7 @@ class Env:
} }
obs = None obs = None
else: else:
obs = get_obs(self.infoset, self.use_general) obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
return obs, reward, done, {} return obs, reward, done, {}
def _get_reward(self, pos): def _get_reward(self, pos):
@ -244,7 +265,7 @@ class DummyAgent(object):
self.action = action self.action = action
def get_obs(infoset, use_general=True, use_legacy = False): def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
""" """
This function obtains observations with imperfect information This function obtains observations with imperfect information
from the infoset. It has three branches since we encode from the infoset. It has three branches since we encode
@ -271,24 +292,34 @@ def get_obs(infoset, use_general=True, use_legacy = False):
if use_general: if use_general:
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]: if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
raise ValueError('') raise ValueError('')
return _get_obs_general(infoset, infoset.player_position) return _get_obs_general(infoset, infoset.player_position, lite_model)
else: else:
if infoset.player_position == 'landlord': if infoset.player_position == 'landlord':
return _get_obs_landlord(infoset, use_legacy) return _get_obs_landlord(infoset, use_legacy, lite_model)
elif infoset.player_position == 'landlord_up': elif infoset.player_position == 'landlord_up':
return _get_obs_landlord_up(infoset, use_legacy) return _get_obs_landlord_up(infoset, use_legacy, lite_model)
elif infoset.player_position == 'landlord_front': elif infoset.player_position == 'landlord_front':
return _get_obs_landlord_front(infoset, use_legacy) return _get_obs_landlord_front(infoset, use_legacy, lite_model)
elif infoset.player_position == 'landlord_down': elif infoset.player_position == 'landlord_down':
return _get_obs_landlord_down(infoset, use_legacy) return _get_obs_landlord_down(infoset, use_legacy, lite_model)
else: else:
raise ValueError('') raise ValueError('')
def _get_one_hot_array(num_left_cards, max_num_cards): def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
""" """
A utility function to obtain one-hot endoding A utility function to obtain one-hot endoding
""" """
if compress_size > 0:
assert compress_size <= max_num_cards / 2
array_size = max_num_cards - compress_size
one_hot = np.zeros(array_size)
if num_left_cards >= array_size:
one_hot[-1] = 1
num_left_cards -= array_size
if num_left_cards > 0:
one_hot[num_left_cards - 1] = 1
else:
one_hot = np.zeros(max_num_cards) one_hot = np.zeros(max_num_cards)
if num_left_cards > 0: if num_left_cards > 0:
one_hot[num_left_cards - 1] = 1 one_hot[num_left_cards - 1] = 1
@ -296,13 +327,36 @@ def _get_one_hot_array(num_left_cards, max_num_cards):
return one_hot return one_hot
def _cards2array(list_cards): def _cards2array(list_cards, compressed_form = False):
""" """
A utility function that transforms the actions, i.e., A utility function that transforms the actions, i.e.,
A list of integers into card matrix. Here we remove A list of integers into card matrix. Here we remove
the six entries that are always zero and flatten the the six entries that are always zero and flatten the
the representations. the representations.
""" """
if compressed_form:
if len(list_cards) == 0:
return np.zeros(69, dtype=np.int8)
matrix = np.zeros([5, 14], dtype=np.int8)
counter = Counter(list_cards)
joker_cnt = 0
for card, num_times in counter.items():
if card < 20:
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
elif card == 20:
if num_times == 2:
joker_cnt |= 0b11
else:
joker_cnt |= 0b01
elif card == 30:
if num_times == 2:
joker_cnt |= 0b1100
else:
joker_cnt |= 0b0100
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
return matrix.flatten('F')[:-1]
else:
if len(list_cards) == 0: if len(list_cards) == 0:
return np.zeros(108, dtype=np.int8) return np.zeros(108, dtype=np.int8)
@ -342,7 +396,7 @@ def _cards2array(list_cards):
# # action_seq_array = action_seq_array.reshape(5, 162) # # action_seq_array = action_seq_array.reshape(5, 162)
# return action_seq_array # return action_seq_array
def _action_seq_list2array(action_seq_list, new_model=True): def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False):
""" """
A utility function to encode the historical moves. A utility function to encode the historical moves.
We encode the historical 20 actions. If there is We encode the historical 20 actions. If there is
@ -355,15 +409,28 @@ def _action_seq_list2array(action_seq_list, new_model=True):
if new_model: if new_model:
# position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3} # position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
if compressed_form:
action_seq_array = np.full((len(action_seq_list), 69), -1) # Default Value -1 for not using area
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :69] = _cards2array(list_cards[1], compressed_form)
else:
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
for row, list_cards in enumerate(action_seq_list): for row, list_cards in enumerate(action_seq_list):
if list_cards != []: if list_cards != []:
action_seq_array[row, :108] = _cards2array(list_cards[1]) action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
else:
if compressed_form:
action_seq_array = np.zeros((len(action_seq_list), 69))
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
action_seq_array = action_seq_array.reshape(5, 276)
else: else:
action_seq_array = np.zeros((len(action_seq_list), 108)) action_seq_array = np.zeros((len(action_seq_list), 108))
for row, list_cards in enumerate(action_seq_list): for row, list_cards in enumerate(action_seq_list):
if list_cards != []: if list_cards != []:
action_seq_array[row, :] = _cards2array(list_cards[1]) action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
action_seq_array = action_seq_array.reshape(5, 432) action_seq_array = action_seq_array.reshape(5, 432)
return action_seq_array return action_seq_array
@ -406,60 +473,60 @@ def _get_one_hot_bomb(bomb_num, use_legacy = False):
return one_hot return one_hot
def _get_obs_landlord(infoset, use_legacy = False): def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
""" """
Obttain the landlord features. See Table 4 in Obttain the landlord features. See Table 4 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
""" """
num_legal_actions = len(infoset.legal_actions) num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards) my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards) other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :], other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_action = _cards2array(infoset.last_move) last_action = _cards2array(infoset.last_move, compressed_form)
last_action_batch = np.repeat(last_action[np.newaxis, :], last_action_batch = np.repeat(last_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
my_action_batch = np.zeros(my_handcards_batch.shape) my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions): for j, action in enumerate(infoset.legal_actions):
my_action_batch[j, :] = _cards2array(action) my_action_batch[j, :] = _cards2array(action, compressed_form)
landlord_up_num_cards_left = _get_one_hot_array( landlord_up_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_up'], 25) infoset.num_cards_left_dict['landlord_up'], 25, 15)
landlord_up_num_cards_left_batch = np.repeat( landlord_up_num_cards_left_batch = np.repeat(
landlord_up_num_cards_left[np.newaxis, :], landlord_up_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_front_num_cards_left = _get_one_hot_array( landlord_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25) infoset.num_cards_left_dict['landlord_front'], 25, 8)
landlord_front_num_cards_left_batch = np.repeat( landlord_front_num_cards_left_batch = np.repeat(
landlord_front_num_cards_left[np.newaxis, :], landlord_front_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_down_num_cards_left = _get_one_hot_array( landlord_down_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_down'], 25) infoset.num_cards_left_dict['landlord_down'], 25, 8)
landlord_down_num_cards_left_batch = np.repeat( landlord_down_num_cards_left_batch = np.repeat(
landlord_down_num_cards_left[np.newaxis, :], landlord_down_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_up_played_cards = _cards2array( landlord_up_played_cards = _cards2array(
infoset.played_cards['landlord_up']) infoset.played_cards['landlord_up'], 8)
landlord_up_played_cards_batch = np.repeat( landlord_up_played_cards_batch = np.repeat(
landlord_up_played_cards[np.newaxis, :], landlord_up_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_front_played_cards = _cards2array( landlord_front_played_cards = _cards2array(
infoset.played_cards['landlord_front']) infoset.played_cards['landlord_front'], compressed_form)
landlord_front_played_cards_batch = np.repeat( landlord_front_played_cards_batch = np.repeat(
landlord_front_played_cards[np.newaxis, :], landlord_front_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_down_played_cards = _cards2array( landlord_down_played_cards = _cards2array(
infoset.played_cards['landlord_down']) infoset.played_cards['landlord_down'], compressed_form)
landlord_down_played_cards_batch = np.repeat( landlord_down_played_cards_batch = np.repeat(
landlord_down_played_cards[np.newaxis, :], landlord_down_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -492,7 +559,7 @@ def _get_obs_landlord(infoset, use_legacy = False):
landlord_down_num_cards_left, landlord_down_num_cards_left,
bomb_num)) bomb_num))
z = _action_seq_list2array(_process_action_seq( z = _action_seq_list2array(_process_action_seq(
infoset.card_play_action_seq, 20, False), False) infoset.card_play_action_seq, 20, False), False, compressed_form)
z_batch = np.repeat( z_batch = np.repeat(
z[np.newaxis, :, :], z[np.newaxis, :, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -506,75 +573,75 @@ def _get_obs_landlord(infoset, use_legacy = False):
} }
return obs return obs
def _get_obs_landlord_up(infoset, use_legacy = False): def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False):
""" """
Obttain the landlord_up features. See Table 5 in Obttain the landlord_up features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
""" """
num_legal_actions = len(infoset.legal_actions) num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards) my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards) other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :], other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_action = _cards2array(infoset.last_move) last_action = _cards2array(infoset.last_move, compressed_form)
last_action_batch = np.repeat(last_action[np.newaxis, :], last_action_batch = np.repeat(last_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
my_action_batch = np.zeros(my_handcards_batch.shape) my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions): for j, action in enumerate(infoset.legal_actions):
my_action_batch[j, :] = _cards2array(action) my_action_batch[j, :] = _cards2array(action, compressed_form)
last_landlord_action = _cards2array( last_landlord_action = _cards2array(
infoset.last_move_dict['landlord']) infoset.last_move_dict['landlord'], compressed_form)
last_landlord_action_batch = np.repeat( last_landlord_action_batch = np.repeat(
last_landlord_action[np.newaxis, :], last_landlord_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_num_cards_left = _get_one_hot_array( landlord_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord'], 33) infoset.num_cards_left_dict['landlord'], 33, 15)
landlord_num_cards_left_batch = np.repeat( landlord_num_cards_left_batch = np.repeat(
landlord_num_cards_left[np.newaxis, :], landlord_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_played_cards = _cards2array( landlord_played_cards = _cards2array(
infoset.played_cards['landlord']) infoset.played_cards['landlord'], compressed_form)
landlord_played_cards_batch = np.repeat( landlord_played_cards_batch = np.repeat(
landlord_played_cards[np.newaxis, :], landlord_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_teammate_action = _cards2array( last_teammate_action = _cards2array(
infoset.last_move_dict['landlord_down']) infoset.last_move_dict['landlord_down'], compressed_form)
last_teammate_action_batch = np.repeat( last_teammate_action_batch = np.repeat(
last_teammate_action[np.newaxis, :], last_teammate_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_num_cards_left = _get_one_hot_array( teammate_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_down'], 25) infoset.num_cards_left_dict['landlord_down'], 25, 8)
teammate_num_cards_left_batch = np.repeat( teammate_num_cards_left_batch = np.repeat(
teammate_num_cards_left[np.newaxis, :], teammate_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_played_cards = _cards2array( teammate_played_cards = _cards2array(
infoset.played_cards['landlord_down']) infoset.played_cards['landlord_down'], compressed_form)
teammate_played_cards_batch = np.repeat( teammate_played_cards_batch = np.repeat(
teammate_played_cards[np.newaxis, :], teammate_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_teammate_front_action = _cards2array( last_teammate_front_action = _cards2array(
infoset.last_move_dict['landlord_front']) infoset.last_move_dict['landlord_front'], compressed_form)
last_teammate_front_action_batch = np.repeat( last_teammate_front_action_batch = np.repeat(
last_teammate_front_action[np.newaxis, :], last_teammate_front_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_front_num_cards_left = _get_one_hot_array( teammate_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25) infoset.num_cards_left_dict['landlord_front'], 25, 8)
teammate_front_num_cards_left_batch = np.repeat( teammate_front_num_cards_left_batch = np.repeat(
teammate_front_num_cards_left[np.newaxis, :], teammate_front_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_front_played_cards = _cards2array( teammate_front_played_cards = _cards2array(
infoset.played_cards['landlord_front']) infoset.played_cards['landlord_front'], compressed_form)
teammate_front_played_cards_batch = np.repeat( teammate_front_played_cards_batch = np.repeat(
teammate_front_played_cards[np.newaxis, :], teammate_front_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -613,7 +680,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False):
teammate_front_num_cards_left, teammate_front_num_cards_left,
bomb_num)) bomb_num))
z = _action_seq_list2array(_process_action_seq( z = _action_seq_list2array(_process_action_seq(
infoset.card_play_action_seq, 20, False), False) infoset.card_play_action_seq, 20, False), False, compressed_form)
z_batch = np.repeat( z_batch = np.repeat(
z[np.newaxis, :, :], z[np.newaxis, :, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -627,75 +694,75 @@ def _get_obs_landlord_up(infoset, use_legacy = False):
} }
return obs return obs
def _get_obs_landlord_front(infoset, use_legacy = False): def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False):
""" """
Obttain the landlord_front features. See Table 5 in Obttain the landlord_front features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
""" """
num_legal_actions = len(infoset.legal_actions) num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards) my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards) other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :], other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_action = _cards2array(infoset.last_move) last_action = _cards2array(infoset.last_move, compressed_form)
last_action_batch = np.repeat(last_action[np.newaxis, :], last_action_batch = np.repeat(last_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
my_action_batch = np.zeros(my_handcards_batch.shape) my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions): for j, action in enumerate(infoset.legal_actions):
my_action_batch[j, :] = _cards2array(action) my_action_batch[j, :] = _cards2array(action, compressed_form)
last_landlord_action = _cards2array( last_landlord_action = _cards2array(
infoset.last_move_dict['landlord']) infoset.last_move_dict['landlord'], compressed_form)
last_landlord_action_batch = np.repeat( last_landlord_action_batch = np.repeat(
last_landlord_action[np.newaxis, :], last_landlord_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_num_cards_left = _get_one_hot_array( landlord_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord'], 33) infoset.num_cards_left_dict['landlord'], 33, 15)
landlord_num_cards_left_batch = np.repeat( landlord_num_cards_left_batch = np.repeat(
landlord_num_cards_left[np.newaxis, :], landlord_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_played_cards = _cards2array( landlord_played_cards = _cards2array(
infoset.played_cards['landlord']) infoset.played_cards['landlord'], compressed_form)
landlord_played_cards_batch = np.repeat( landlord_played_cards_batch = np.repeat(
landlord_played_cards[np.newaxis, :], landlord_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_teammate_action = _cards2array( last_teammate_action = _cards2array(
infoset.last_move_dict['landlord_down']) infoset.last_move_dict['landlord_down'], compressed_form)
last_teammate_action_batch = np.repeat( last_teammate_action_batch = np.repeat(
last_teammate_action[np.newaxis, :], last_teammate_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_num_cards_left = _get_one_hot_array( teammate_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_down'], 25) infoset.num_cards_left_dict['landlord_down'], 25, 8)
teammate_num_cards_left_batch = np.repeat( teammate_num_cards_left_batch = np.repeat(
teammate_num_cards_left[np.newaxis, :], teammate_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_played_cards = _cards2array( teammate_played_cards = _cards2array(
infoset.played_cards['landlord_down']) infoset.played_cards['landlord_down'], compressed_form)
teammate_played_cards_batch = np.repeat( teammate_played_cards_batch = np.repeat(
teammate_played_cards[np.newaxis, :], teammate_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_teammate_front_action = _cards2array( last_teammate_front_action = _cards2array(
infoset.last_move_dict['landlord_front']) infoset.last_move_dict['landlord_front'], compressed_form)
last_teammate_front_action_batch = np.repeat( last_teammate_front_action_batch = np.repeat(
last_teammate_front_action[np.newaxis, :], last_teammate_front_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_front_num_cards_left = _get_one_hot_array( teammate_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25) infoset.num_cards_left_dict['landlord_front'], 25, 8)
teammate_front_num_cards_left_batch = np.repeat( teammate_front_num_cards_left_batch = np.repeat(
teammate_front_num_cards_left[np.newaxis, :], teammate_front_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_front_played_cards = _cards2array( teammate_front_played_cards = _cards2array(
infoset.played_cards['landlord_front']) infoset.played_cards['landlord_front'], compressed_form)
teammate_front_played_cards_batch = np.repeat( teammate_front_played_cards_batch = np.repeat(
teammate_played_cards[np.newaxis, :], teammate_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -734,7 +801,7 @@ def _get_obs_landlord_front(infoset, use_legacy = False):
teammate_front_num_cards_left, teammate_front_num_cards_left,
bomb_num)) bomb_num))
z = _action_seq_list2array(_process_action_seq( z = _action_seq_list2array(_process_action_seq(
infoset.card_play_action_seq, 20, False), False) infoset.card_play_action_seq, 20, False), False, compressed_form)
z_batch = np.repeat( z_batch = np.repeat(
z[np.newaxis, :, :], z[np.newaxis, :, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -748,75 +815,75 @@ def _get_obs_landlord_front(infoset, use_legacy = False):
} }
return obs return obs
def _get_obs_landlord_down(infoset, use_legacy = False): def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False):
""" """
Obtain the landlord_down features. See Table 5 in Obtain the landlord_down features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
""" """
num_legal_actions = len(infoset.legal_actions) num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards) my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards) other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :], other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_action = _cards2array(infoset.last_move) last_action = _cards2array(infoset.last_move, compressed_form)
last_action_batch = np.repeat(last_action[np.newaxis, :], last_action_batch = np.repeat(last_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
my_action_batch = np.zeros(my_handcards_batch.shape) my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions): for j, action in enumerate(infoset.legal_actions):
my_action_batch[j, :] = _cards2array(action) my_action_batch[j, :] = _cards2array(action, compressed_form)
last_landlord_action = _cards2array( last_landlord_action = _cards2array(
infoset.last_move_dict['landlord']) infoset.last_move_dict['landlord'], compressed_form)
last_landlord_action_batch = np.repeat( last_landlord_action_batch = np.repeat(
last_landlord_action[np.newaxis, :], last_landlord_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_num_cards_left = _get_one_hot_array( landlord_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord'], 33) infoset.num_cards_left_dict['landlord'], 33, 15)
landlord_num_cards_left_batch = np.repeat( landlord_num_cards_left_batch = np.repeat(
landlord_num_cards_left[np.newaxis, :], landlord_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
landlord_played_cards = _cards2array( landlord_played_cards = _cards2array(
infoset.played_cards['landlord']) infoset.played_cards['landlord'], compressed_form)
landlord_played_cards_batch = np.repeat( landlord_played_cards_batch = np.repeat(
landlord_played_cards[np.newaxis, :], landlord_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_teammate_action = _cards2array( last_teammate_action = _cards2array(
infoset.last_move_dict['landlord_up']) infoset.last_move_dict['landlord_up'], compressed_form)
last_teammate_action_batch = np.repeat( last_teammate_action_batch = np.repeat(
last_teammate_action[np.newaxis, :], last_teammate_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_num_cards_left = _get_one_hot_array( teammate_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_up'], 25) infoset.num_cards_left_dict['landlord_up'], 25, 8)
teammate_num_cards_left_batch = np.repeat( teammate_num_cards_left_batch = np.repeat(
teammate_num_cards_left[np.newaxis, :], teammate_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_played_cards = _cards2array( teammate_played_cards = _cards2array(
infoset.played_cards['landlord_up']) infoset.played_cards['landlord_up'], compressed_form)
teammate_played_cards_batch = np.repeat( teammate_played_cards_batch = np.repeat(
teammate_played_cards[np.newaxis, :], teammate_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
last_teammate_front_action = _cards2array( last_teammate_front_action = _cards2array(
infoset.last_move_dict['landlord_front']) infoset.last_move_dict['landlord_front'], compressed_form)
last_teammate_front_action_batch = np.repeat( last_teammate_front_action_batch = np.repeat(
last_teammate_front_action[np.newaxis, :], last_teammate_front_action[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_front_num_cards_left = _get_one_hot_array( teammate_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25) infoset.num_cards_left_dict['landlord_front'], 25, 8)
teammate_front_num_cards_left_batch = np.repeat( teammate_front_num_cards_left_batch = np.repeat(
teammate_front_num_cards_left[np.newaxis, :], teammate_front_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
teammate_front_played_cards = _cards2array( teammate_front_played_cards = _cards2array(
infoset.played_cards['landlord_front']) infoset.played_cards['landlord_front'], compressed_form)
teammate_front_played_cards_batch = np.repeat( teammate_front_played_cards_batch = np.repeat(
teammate_front_played_cards[np.newaxis, :], teammate_front_played_cards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -855,7 +922,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False):
teammate_front_num_cards_left, teammate_front_num_cards_left,
bomb_num)) bomb_num))
z = _action_seq_list2array(_process_action_seq( z = _action_seq_list2array(_process_action_seq(
infoset.card_play_action_seq, 20, False), False) infoset.card_play_action_seq, 20, False), False, compressed_form)
z_batch = np.repeat( z_batch = np.repeat(
z[np.newaxis, :, :], z[np.newaxis, :, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -869,41 +936,41 @@ def _get_obs_landlord_down(infoset, use_legacy = False):
} }
return obs return obs
def _get_obs_general(infoset, position): def _get_obs_general(infoset, position, compressed_form = False):
num_legal_actions = len(infoset.legal_actions) num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards) my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards) other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
my_action_batch = np.zeros(my_handcards_batch.shape) my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions): for j, action in enumerate(infoset.legal_actions):
my_action_batch[j, :] = _cards2array(action) my_action_batch[j, :] = _cards2array(action, compressed_form)
landlord_num_cards_left = _get_one_hot_array( landlord_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord'], 33) infoset.num_cards_left_dict['landlord'], 33, 15)
landlord_up_num_cards_left = _get_one_hot_array( landlord_up_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_up'], 25) infoset.num_cards_left_dict['landlord_up'], 25, 8)
landlord_front_num_cards_left = _get_one_hot_array( landlord_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25) infoset.num_cards_left_dict['landlord_front'], 25, 8)
landlord_down_num_cards_left = _get_one_hot_array( landlord_down_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_down'], 25) infoset.num_cards_left_dict['landlord_down'], 25, 8)
landlord_played_cards = _cards2array( landlord_played_cards = _cards2array(
infoset.played_cards['landlord']) infoset.played_cards['landlord'], compressed_form)
landlord_up_played_cards = _cards2array( landlord_up_played_cards = _cards2array(
infoset.played_cards['landlord_up']) infoset.played_cards['landlord_up'], compressed_form)
landlord_front_played_cards = _cards2array( landlord_front_played_cards = _cards2array(
infoset.played_cards['landlord_front']) infoset.played_cards['landlord_front'], compressed_form)
landlord_down_played_cards = _cards2array( landlord_down_played_cards = _cards2array(
infoset.played_cards['landlord_down']) infoset.played_cards['landlord_down'], compressed_form)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num)
@ -911,10 +978,10 @@ def _get_obs_general(infoset, position):
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
num_cards_left = np.hstack(( num_cards_left = np.hstack((
landlord_num_cards_left, # 33 landlord_num_cards_left, # 33/18
landlord_up_num_cards_left, # 25 landlord_up_num_cards_left, # 25/17
landlord_front_num_cards_left, # 25 landlord_front_num_cards_left, # 25/17
landlord_down_num_cards_left)) landlord_down_num_cards_left)) # 25/17
x_batch = np.hstack(( x_batch = np.hstack((
bomb_num_batch, # 56 bomb_num_batch, # 56
@ -924,20 +991,23 @@ def _get_obs_general(infoset, position):
)) ))
z =np.vstack(( z =np.vstack((
num_cards_left, num_cards_left, # 108 / 18+17*3=69
my_handcards, # 108 my_handcards, # 108/69
other_handcards, # 108 other_handcards, # 108/69
landlord_played_cards, # 108 landlord_played_cards, # 108/69
landlord_up_played_cards, # 108 landlord_up_played_cards, # 108/69
landlord_front_played_cards, # 108 landlord_front_played_cards, # 108/69
landlord_down_played_cards, # 108 landlord_down_played_cards, # 108/69
_action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32)) _action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32), True, compressed_form)
)) ))
_z_batch = np.repeat( _z_batch = np.repeat(
z[np.newaxis, :, :], z[np.newaxis, :, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
my_action_batch = my_action_batch[:,np.newaxis,:] my_action_batch = my_action_batch[:,np.newaxis,:]
if compressed_form:
z_batch = np.zeros([len(_z_batch), 40, 69], int)
else:
z_batch = np.zeros([len(_z_batch),40,108],int) z_batch = np.zeros([len(_z_batch),40,108],int)
for i in range(0,len(_z_batch)): for i in range(0,len(_z_batch)):
z_batch[i] = np.vstack((my_action_batch[i],_z_batch[i])) z_batch[i] = np.vstack((my_action_batch[i],_z_batch[i]))

View File

@ -49,6 +49,7 @@ class DeepAgent:
def __init__(self, position, model_path): def __init__(self, position, model_path):
self.use_legacy = True if "legacy" in model_path else False self.use_legacy = True if "legacy" in model_path else False
self.lite_model = True if "lite" in model_path else False
self.model_type = "general" if "resnet" in model_path else "old" self.model_type = "general" if "resnet" in model_path else "old"
self.model = _load_model(position, model_path, self.model_type, self.use_legacy) self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider']) self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
@ -59,7 +60,7 @@ class DeepAgent:
if len(infoset.legal_actions) == 1: if len(infoset.legal_actions) == 1:
return infoset.legal_actions[0] return infoset.legal_actions[0]
obs = get_obs(infoset, self.model_type == "general", self.use_legacy) obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
z_batch = torch.from_numpy(obs['z_batch']).float() z_batch = torch.from_numpy(obs['z_batch']).float()
x_batch = torch.from_numpy(obs['x_batch']).float() x_batch = torch.from_numpy(obs['x_batch']).float()