新增压缩版模型
This commit is contained in:
parent
d5401cefc9
commit
b5982b7195
|
@ -29,6 +29,8 @@ parser.add_argument('--load_model', action='store_true',
|
|||
help='Load an existing model')
|
||||
parser.add_argument('--old_model', action='store_true',
|
||||
help='Use vanilla model')
|
||||
parser.add_argument('--lite_model', action='store_true',
|
||||
help='Use lite card model')
|
||||
parser.add_argument('--lagacy_model', action='store_true',
|
||||
help='Use lagacy bomb model')
|
||||
parser.add_argument('--disable_checkpoint', action='store_true',
|
||||
|
|
|
@ -87,7 +87,7 @@ def train(flags):
|
|||
if flags.actor_device_cpu:
|
||||
device_iterator = ['cpu']
|
||||
else:
|
||||
device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
|
||||
device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
|
||||
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
||||
|
||||
# Initialize actor models
|
||||
|
@ -96,7 +96,7 @@ def train(flags):
|
|||
if flags.old_model:
|
||||
model = OldModel(device="cpu", flags = flags)
|
||||
else:
|
||||
model = Model(device="cpu", flags = flags)
|
||||
model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
|
||||
model.share_memory()
|
||||
model.eval()
|
||||
models[device] = model
|
||||
|
@ -111,7 +111,7 @@ def train(flags):
|
|||
if flags.old_model:
|
||||
learner_model = OldModel(device=flags.training_device)
|
||||
else:
|
||||
learner_model = Model(device=flags.training_device)
|
||||
learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)
|
||||
|
||||
# Create optimizers
|
||||
optimizers = create_optimizers(flags, learner_model)
|
||||
|
|
|
@ -240,6 +240,69 @@ class BasicBlock(nn.Module):
|
|||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class GeneralModelLite(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_planes = 80
|
||||
#input 1*69*41
|
||||
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
|
||||
stride=(2,), padding=1, bias=False) #1*35*80
|
||||
|
||||
self.bn1 = nn.BatchNorm1d(80)
|
||||
|
||||
self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*18*80
|
||||
self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*9*160
|
||||
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*5*320
|
||||
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
|
||||
self.linear2 = nn.Linear(1536, 768)
|
||||
self.linear3 = nn.Linear(768, 384)
|
||||
self.linear4 = nn.Linear(384, 1)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_onnx_params(self, device=None):
|
||||
return {
|
||||
'args': (
|
||||
torch.randn(1, 40, 69, requires_grad=True, device=device),
|
||||
torch.randn(1, 56, requires_grad=True, device=device),
|
||||
),
|
||||
'input_names': ['z_batch','x_batch'],
|
||||
'output_names': ['values'],
|
||||
'dynamic_axes': {
|
||||
'z_batch': {
|
||||
0: "batch_size"
|
||||
},
|
||||
'x_batch': {
|
||||
0: "batch_size"
|
||||
},
|
||||
'values': {
|
||||
0: "batch_size"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def forward(self, z, x):
|
||||
out = F.relu(self.bn1(self.conv1(z)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = out.flatten(1,2)
|
||||
out = torch.cat([x,out], dim=-1)
|
||||
out = F.leaky_relu_(self.linear1(out))
|
||||
out = F.leaky_relu_(self.linear2(out))
|
||||
out = F.leaky_relu_(self.linear3(out))
|
||||
out = F.leaky_relu_(self.linear4(out))
|
||||
return dict(values=out)
|
||||
|
||||
|
||||
class GeneralModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -414,7 +477,7 @@ class Model:
|
|||
The wrapper for the three models. We also wrap several
|
||||
interfaces such as share_memory, eval, etc.
|
||||
"""
|
||||
def __init__(self, device=0, flags=None):
|
||||
def __init__(self, device=0, flags=None, lite_model = False):
|
||||
self.models = {}
|
||||
self.onnx_models = {}
|
||||
self.flags = flags
|
||||
|
@ -428,7 +491,10 @@ class Model:
|
|||
self.models[position] = None
|
||||
else:
|
||||
for position in positions:
|
||||
self.models[position] = GeneralModel().to(self.device)
|
||||
if lite_model:
|
||||
self.models[position] = GeneralModelLite().to(self.device)
|
||||
else:
|
||||
self.models[position] = GeneralModel().to(self.device)
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
|
|
|
@ -36,6 +36,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
|
|||
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
|
||||
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
|
||||
|
||||
NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||
1: np.array([1, 0, 0, 0, 0]),
|
||||
2: np.array([1, 1, 0, 0, 0]),
|
||||
3: np.array([1, 1, 1, 0, 0]),
|
||||
4: np.array([1, 1, 1, 1, 0]),
|
||||
5: np.array([1, 0, 0, 0, 1]),
|
||||
6: np.array([1, 1, 0, 0, 1]),
|
||||
7: np.array([1, 1, 1, 0, 1]),
|
||||
8: np.array([1, 1, 1, 1, 1])}
|
||||
|
||||
NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||
1: np.array([1, 0, 0, 0, 0]),
|
||||
3: np.array([1, 1, 0, 0, 0]),
|
||||
4: np.array([0, 0, 1, 0, 0]),
|
||||
5: np.array([1, 0, 1, 0, 0]),
|
||||
7: np.array([1, 1, 1, 0, 0]),
|
||||
12: np.array([0, 0, 1, 1, 0]),
|
||||
13: np.array([1, 0, 1, 1, 0]),
|
||||
15: np.array([1, 1, 1, 1, 0])}
|
||||
|
||||
shandle = logging.StreamHandler()
|
||||
shandle.setFormatter(
|
||||
logging.Formatter(
|
||||
|
@ -51,7 +71,7 @@ log.setLevel(logging.INFO)
|
|||
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
||||
|
||||
def create_env(flags):
|
||||
return Env(flags.objective, flags.old_model, flags.lagacy_model)
|
||||
return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model)
|
||||
|
||||
def get_batch(b_queues, position, flags, lock):
|
||||
"""
|
||||
|
@ -133,11 +153,11 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
|||
else:
|
||||
action = obs['legal_actions'][0]
|
||||
if flags.old_model:
|
||||
obs_action_buf[position].append(_cards2tensor(action))
|
||||
obs_action_buf[position].append(_cards2tensor(action, flags.lite_model))
|
||||
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
||||
obs_z_buf[position].append(env_output['obs_z'])
|
||||
else:
|
||||
obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float())
|
||||
obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float())
|
||||
# x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
|
||||
x_batch = env_output['obs_x_no_action'].float()
|
||||
obs_x_batch_buf[position].append(x_batch)
|
||||
|
@ -196,32 +216,57 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
|||
print()
|
||||
raise e
|
||||
|
||||
def _cards2tensor(list_cards):
|
||||
def _cards2tensor(list_cards, compress_form = False):
|
||||
"""
|
||||
Convert a list of integers to the tensor
|
||||
representation
|
||||
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
|
||||
"""
|
||||
if len(list_cards) == 0:
|
||||
return torch.zeros(108, dtype=torch.int8)
|
||||
if compress_form:
|
||||
if len(list_cards) == 0:
|
||||
return torch.zeros(69, dtype=torch.int8)
|
||||
|
||||
matrix = np.zeros([8, 14], dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
joker_cnt = 0
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||
elif card == 20:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b11
|
||||
else:
|
||||
joker_cnt |= 0b01
|
||||
elif card == 30:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b1100
|
||||
else:
|
||||
joker_cnt |= 0b0100
|
||||
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
||||
matrix = matrix.flatten('F')[:-4]
|
||||
matrix = torch.from_numpy(matrix)
|
||||
return matrix
|
||||
matrix = np.zeros([5, 14], dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
joker_cnt = 0
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
|
||||
elif card == 20:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b11
|
||||
else:
|
||||
joker_cnt |= 0b01
|
||||
elif card == 30:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b1100
|
||||
else:
|
||||
joker_cnt |= 0b0100
|
||||
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
|
||||
matrix = matrix.flatten('F')[:-1]
|
||||
matrix = torch.from_numpy(matrix)
|
||||
return matrix
|
||||
else:
|
||||
if len(list_cards) == 0:
|
||||
return torch.zeros(108, dtype=torch.int8)
|
||||
|
||||
matrix = np.zeros([8, 14], dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
joker_cnt = 0
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||
elif card == 20:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b11
|
||||
else:
|
||||
joker_cnt |= 0b01
|
||||
elif card == 30:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b1100
|
||||
else:
|
||||
joker_cnt |= 0b0100
|
||||
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
||||
matrix = matrix.flatten('F')[:-4]
|
||||
matrix = torch.from_numpy(matrix)
|
||||
return matrix
|
||||
|
|
|
@ -30,6 +30,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
|
|||
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
|
||||
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
|
||||
|
||||
NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||
1: np.array([1, 0, 0, 0, 0]),
|
||||
2: np.array([1, 1, 0, 0, 0]),
|
||||
3: np.array([1, 1, 1, 0, 0]),
|
||||
4: np.array([1, 1, 1, 1, 0]),
|
||||
5: np.array([1, 0, 0, 0, 1]),
|
||||
6: np.array([1, 1, 0, 0, 1]),
|
||||
7: np.array([1, 1, 1, 0, 1]),
|
||||
8: np.array([1, 1, 1, 1, 1])}
|
||||
|
||||
NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||
1: np.array([1, 0, 0, 0, 0]),
|
||||
3: np.array([1, 1, 0, 0, 0]),
|
||||
4: np.array([0, 0, 1, 0, 0]),
|
||||
5: np.array([1, 0, 1, 0, 0]),
|
||||
7: np.array([1, 1, 1, 0, 0]),
|
||||
12: np.array([0, 0, 1, 1, 0]),
|
||||
13: np.array([1, 0, 1, 1, 0]),
|
||||
15: np.array([1, 1, 1, 1, 0])}
|
||||
|
||||
|
||||
deck = []
|
||||
for i in range(3, 15):
|
||||
|
@ -43,7 +63,7 @@ class Env:
|
|||
Doudizhu multi-agent wrapper
|
||||
"""
|
||||
|
||||
def __init__(self, objective, old_model, legacy_model=False):
|
||||
def __init__(self, objective, old_model, legacy_model=False, lite_model = False):
|
||||
"""
|
||||
Objective is wp/adp/logadp. It indicates whether considers
|
||||
bomb in reward calculation. Here, we use dummy agents.
|
||||
|
@ -58,6 +78,7 @@ class Env:
|
|||
self.objective = objective
|
||||
self.use_legacy = legacy_model
|
||||
self.use_general = not old_model
|
||||
self.lite_model = lite_model
|
||||
|
||||
# Initialize players
|
||||
# We use three dummy player for the target position
|
||||
|
@ -92,7 +113,7 @@ class Env:
|
|||
card_play_data[key].sort()
|
||||
self._env.card_play_init(card_play_data)
|
||||
self.infoset = self._game_infoset
|
||||
return get_obs(self.infoset, self.use_general)
|
||||
return get_obs(self.infoset, self.use_general, self.lite_model)
|
||||
else:
|
||||
self.total_round += 1
|
||||
_deck = deck.copy()
|
||||
|
@ -118,7 +139,7 @@ class Env:
|
|||
self._env.info_sets[pos].player_id = pid
|
||||
self.infoset = self._game_infoset
|
||||
|
||||
return get_obs(self.infoset, self.use_general, self.use_legacy)
|
||||
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
|
@ -146,7 +167,7 @@ class Env:
|
|||
}
|
||||
obs = None
|
||||
else:
|
||||
obs = get_obs(self.infoset, self.use_general)
|
||||
obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
|
||||
return obs, reward, done, {}
|
||||
|
||||
def _get_reward(self, pos):
|
||||
|
@ -244,7 +265,7 @@ class DummyAgent(object):
|
|||
self.action = action
|
||||
|
||||
|
||||
def get_obs(infoset, use_general=True, use_legacy = False):
|
||||
def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
|
||||
"""
|
||||
This function obtains observations with imperfect information
|
||||
from the infoset. It has three branches since we encode
|
||||
|
@ -271,59 +292,92 @@ def get_obs(infoset, use_general=True, use_legacy = False):
|
|||
if use_general:
|
||||
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
|
||||
raise ValueError('')
|
||||
return _get_obs_general(infoset, infoset.player_position)
|
||||
return _get_obs_general(infoset, infoset.player_position, lite_model)
|
||||
else:
|
||||
if infoset.player_position == 'landlord':
|
||||
return _get_obs_landlord(infoset, use_legacy)
|
||||
return _get_obs_landlord(infoset, use_legacy, lite_model)
|
||||
elif infoset.player_position == 'landlord_up':
|
||||
return _get_obs_landlord_up(infoset, use_legacy)
|
||||
return _get_obs_landlord_up(infoset, use_legacy, lite_model)
|
||||
elif infoset.player_position == 'landlord_front':
|
||||
return _get_obs_landlord_front(infoset, use_legacy)
|
||||
return _get_obs_landlord_front(infoset, use_legacy, lite_model)
|
||||
elif infoset.player_position == 'landlord_down':
|
||||
return _get_obs_landlord_down(infoset, use_legacy)
|
||||
return _get_obs_landlord_down(infoset, use_legacy, lite_model)
|
||||
else:
|
||||
raise ValueError('')
|
||||
|
||||
|
||||
def _get_one_hot_array(num_left_cards, max_num_cards):
|
||||
def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
|
||||
"""
|
||||
A utility function to obtain one-hot endoding
|
||||
"""
|
||||
one_hot = np.zeros(max_num_cards)
|
||||
if num_left_cards > 0:
|
||||
one_hot[num_left_cards - 1] = 1
|
||||
if compress_size > 0:
|
||||
assert compress_size <= max_num_cards / 2
|
||||
array_size = max_num_cards - compress_size
|
||||
one_hot = np.zeros(array_size)
|
||||
if num_left_cards >= array_size:
|
||||
one_hot[-1] = 1
|
||||
num_left_cards -= array_size
|
||||
if num_left_cards > 0:
|
||||
one_hot[num_left_cards - 1] = 1
|
||||
else:
|
||||
one_hot = np.zeros(max_num_cards)
|
||||
if num_left_cards > 0:
|
||||
one_hot[num_left_cards - 1] = 1
|
||||
|
||||
return one_hot
|
||||
|
||||
|
||||
def _cards2array(list_cards):
|
||||
def _cards2array(list_cards, compressed_form = False):
|
||||
"""
|
||||
A utility function that transforms the actions, i.e.,
|
||||
A list of integers into card matrix. Here we remove
|
||||
the six entries that are always zero and flatten the
|
||||
the representations.
|
||||
"""
|
||||
if len(list_cards) == 0:
|
||||
return np.zeros(108, dtype=np.int8)
|
||||
if compressed_form:
|
||||
if len(list_cards) == 0:
|
||||
return np.zeros(69, dtype=np.int8)
|
||||
|
||||
matrix = np.zeros([8, 14], dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
joker_cnt = 0
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||
elif card == 20:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b11
|
||||
else:
|
||||
joker_cnt |= 0b01
|
||||
elif card == 30:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b1100
|
||||
else:
|
||||
joker_cnt |= 0b0100
|
||||
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
||||
return matrix.flatten('F')[:-4]
|
||||
matrix = np.zeros([5, 14], dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
joker_cnt = 0
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
|
||||
elif card == 20:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b11
|
||||
else:
|
||||
joker_cnt |= 0b01
|
||||
elif card == 30:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b1100
|
||||
else:
|
||||
joker_cnt |= 0b0100
|
||||
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
|
||||
return matrix.flatten('F')[:-1]
|
||||
else:
|
||||
if len(list_cards) == 0:
|
||||
return np.zeros(108, dtype=np.int8)
|
||||
|
||||
matrix = np.zeros([8, 14], dtype=np.int8)
|
||||
counter = Counter(list_cards)
|
||||
joker_cnt = 0
|
||||
for card, num_times in counter.items():
|
||||
if card < 20:
|
||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||
elif card == 20:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b11
|
||||
else:
|
||||
joker_cnt |= 0b01
|
||||
elif card == 30:
|
||||
if num_times == 2:
|
||||
joker_cnt |= 0b1100
|
||||
else:
|
||||
joker_cnt |= 0b0100
|
||||
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
||||
return matrix.flatten('F')[:-4]
|
||||
|
||||
|
||||
# def _action_seq_list2array(action_seq_list):
|
||||
|
@ -342,7 +396,7 @@ def _cards2array(list_cards):
|
|||
# # action_seq_array = action_seq_array.reshape(5, 162)
|
||||
# return action_seq_array
|
||||
|
||||
def _action_seq_list2array(action_seq_list, new_model=True):
|
||||
def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False):
|
||||
"""
|
||||
A utility function to encode the historical moves.
|
||||
We encode the historical 20 actions. If there is
|
||||
|
@ -355,16 +409,29 @@ def _action_seq_list2array(action_seq_list, new_model=True):
|
|||
|
||||
if new_model:
|
||||
# position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
|
||||
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
|
||||
for row, list_cards in enumerate(action_seq_list):
|
||||
if list_cards != []:
|
||||
action_seq_array[row, :108] = _cards2array(list_cards[1])
|
||||
if compressed_form:
|
||||
action_seq_array = np.full((len(action_seq_list), 69), -1) # Default Value -1 for not using area
|
||||
for row, list_cards in enumerate(action_seq_list):
|
||||
if list_cards != []:
|
||||
action_seq_array[row, :69] = _cards2array(list_cards[1], compressed_form)
|
||||
else:
|
||||
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
|
||||
for row, list_cards in enumerate(action_seq_list):
|
||||
if list_cards != []:
|
||||
action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
|
||||
else:
|
||||
action_seq_array = np.zeros((len(action_seq_list), 108))
|
||||
for row, list_cards in enumerate(action_seq_list):
|
||||
if list_cards != []:
|
||||
action_seq_array[row, :] = _cards2array(list_cards[1])
|
||||
action_seq_array = action_seq_array.reshape(5, 432)
|
||||
if compressed_form:
|
||||
action_seq_array = np.zeros((len(action_seq_list), 69))
|
||||
for row, list_cards in enumerate(action_seq_list):
|
||||
if list_cards != []:
|
||||
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
|
||||
action_seq_array = action_seq_array.reshape(5, 276)
|
||||
else:
|
||||
action_seq_array = np.zeros((len(action_seq_list), 108))
|
||||
for row, list_cards in enumerate(action_seq_list):
|
||||
if list_cards != []:
|
||||
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
|
||||
action_seq_array = action_seq_array.reshape(5, 432)
|
||||
return action_seq_array
|
||||
|
||||
# action_seq_array = np.zeros((len(action_seq_list), 54))
|
||||
|
@ -406,60 +473,60 @@ def _get_one_hot_bomb(bomb_num, use_legacy = False):
|
|||
return one_hot
|
||||
|
||||
|
||||
def _get_obs_landlord(infoset, use_legacy = False):
|
||||
def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
|
||||
"""
|
||||
Obttain the landlord features. See Table 4 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
"""
|
||||
num_legal_actions = len(infoset.legal_actions)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
||||
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_action = _cards2array(infoset.last_move)
|
||||
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||
for j, action in enumerate(infoset.legal_actions):
|
||||
my_action_batch[j, :] = _cards2array(action)
|
||||
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||
|
||||
landlord_up_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_up'], 25)
|
||||
infoset.num_cards_left_dict['landlord_up'], 25, 15)
|
||||
landlord_up_num_cards_left_batch = np.repeat(
|
||||
landlord_up_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_front_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
||||
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||
landlord_front_num_cards_left_batch = np.repeat(
|
||||
landlord_front_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_down_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
||||
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||
landlord_down_num_cards_left_batch = np.repeat(
|
||||
landlord_down_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_up_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_up'])
|
||||
infoset.played_cards['landlord_up'], 8)
|
||||
landlord_up_played_cards_batch = np.repeat(
|
||||
landlord_up_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_front_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_front'])
|
||||
infoset.played_cards['landlord_front'], compressed_form)
|
||||
landlord_front_played_cards_batch = np.repeat(
|
||||
landlord_front_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_down_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_down'])
|
||||
infoset.played_cards['landlord_down'], compressed_form)
|
||||
landlord_down_played_cards_batch = np.repeat(
|
||||
landlord_down_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -492,7 +559,7 @@ def _get_obs_landlord(infoset, use_legacy = False):
|
|||
landlord_down_num_cards_left,
|
||||
bomb_num))
|
||||
z = _action_seq_list2array(_process_action_seq(
|
||||
infoset.card_play_action_seq, 20, False), False)
|
||||
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||
z_batch = np.repeat(
|
||||
z[np.newaxis, :, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -506,75 +573,75 @@ def _get_obs_landlord(infoset, use_legacy = False):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_up(infoset, use_legacy = False):
|
||||
def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False):
|
||||
"""
|
||||
Obttain the landlord_up features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
"""
|
||||
num_legal_actions = len(infoset.legal_actions)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
||||
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_action = _cards2array(infoset.last_move)
|
||||
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||
for j, action in enumerate(infoset.legal_actions):
|
||||
my_action_batch[j, :] = _cards2array(action)
|
||||
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||
|
||||
last_landlord_action = _cards2array(
|
||||
infoset.last_move_dict['landlord'])
|
||||
infoset.last_move_dict['landlord'], compressed_form)
|
||||
last_landlord_action_batch = np.repeat(
|
||||
last_landlord_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
landlord_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord'], 33)
|
||||
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||
landlord_num_cards_left_batch = np.repeat(
|
||||
landlord_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord'])
|
||||
infoset.played_cards['landlord'], compressed_form)
|
||||
landlord_played_cards_batch = np.repeat(
|
||||
landlord_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_teammate_action = _cards2array(
|
||||
infoset.last_move_dict['landlord_down'])
|
||||
infoset.last_move_dict['landlord_down'], compressed_form)
|
||||
last_teammate_action_batch = np.repeat(
|
||||
last_teammate_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
teammate_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
||||
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||
teammate_num_cards_left_batch = np.repeat(
|
||||
teammate_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
teammate_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_down'])
|
||||
infoset.played_cards['landlord_down'], compressed_form)
|
||||
teammate_played_cards_batch = np.repeat(
|
||||
teammate_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_teammate_front_action = _cards2array(
|
||||
infoset.last_move_dict['landlord_front'])
|
||||
infoset.last_move_dict['landlord_front'], compressed_form)
|
||||
last_teammate_front_action_batch = np.repeat(
|
||||
last_teammate_front_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
teammate_front_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
||||
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||
teammate_front_num_cards_left_batch = np.repeat(
|
||||
teammate_front_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
teammate_front_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_front'])
|
||||
infoset.played_cards['landlord_front'], compressed_form)
|
||||
teammate_front_played_cards_batch = np.repeat(
|
||||
teammate_front_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -613,7 +680,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False):
|
|||
teammate_front_num_cards_left,
|
||||
bomb_num))
|
||||
z = _action_seq_list2array(_process_action_seq(
|
||||
infoset.card_play_action_seq, 20, False), False)
|
||||
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||
z_batch = np.repeat(
|
||||
z[np.newaxis, :, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -627,75 +694,75 @@ def _get_obs_landlord_up(infoset, use_legacy = False):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_front(infoset, use_legacy = False):
|
||||
def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False):
|
||||
"""
|
||||
Obttain the landlord_front features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
"""
|
||||
num_legal_actions = len(infoset.legal_actions)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
||||
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_action = _cards2array(infoset.last_move)
|
||||
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||
for j, action in enumerate(infoset.legal_actions):
|
||||
my_action_batch[j, :] = _cards2array(action)
|
||||
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||
|
||||
last_landlord_action = _cards2array(
|
||||
infoset.last_move_dict['landlord'])
|
||||
infoset.last_move_dict['landlord'], compressed_form)
|
||||
last_landlord_action_batch = np.repeat(
|
||||
last_landlord_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
landlord_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord'], 33)
|
||||
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||
landlord_num_cards_left_batch = np.repeat(
|
||||
landlord_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord'])
|
||||
infoset.played_cards['landlord'], compressed_form)
|
||||
landlord_played_cards_batch = np.repeat(
|
||||
landlord_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_teammate_action = _cards2array(
|
||||
infoset.last_move_dict['landlord_down'])
|
||||
infoset.last_move_dict['landlord_down'], compressed_form)
|
||||
last_teammate_action_batch = np.repeat(
|
||||
last_teammate_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
teammate_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
||||
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||
teammate_num_cards_left_batch = np.repeat(
|
||||
teammate_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
teammate_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_down'])
|
||||
infoset.played_cards['landlord_down'], compressed_form)
|
||||
teammate_played_cards_batch = np.repeat(
|
||||
teammate_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_teammate_front_action = _cards2array(
|
||||
infoset.last_move_dict['landlord_front'])
|
||||
infoset.last_move_dict['landlord_front'], compressed_form)
|
||||
last_teammate_front_action_batch = np.repeat(
|
||||
last_teammate_front_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
teammate_front_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
||||
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||
teammate_front_num_cards_left_batch = np.repeat(
|
||||
teammate_front_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
teammate_front_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_front'])
|
||||
infoset.played_cards['landlord_front'], compressed_form)
|
||||
teammate_front_played_cards_batch = np.repeat(
|
||||
teammate_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -734,7 +801,7 @@ def _get_obs_landlord_front(infoset, use_legacy = False):
|
|||
teammate_front_num_cards_left,
|
||||
bomb_num))
|
||||
z = _action_seq_list2array(_process_action_seq(
|
||||
infoset.card_play_action_seq, 20, False), False)
|
||||
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||
z_batch = np.repeat(
|
||||
z[np.newaxis, :, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -748,75 +815,75 @@ def _get_obs_landlord_front(infoset, use_legacy = False):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_down(infoset, use_legacy = False):
|
||||
def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False):
|
||||
"""
|
||||
Obttain the landlord_down features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
"""
|
||||
num_legal_actions = len(infoset.legal_actions)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
||||
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_action = _cards2array(infoset.last_move)
|
||||
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||
for j, action in enumerate(infoset.legal_actions):
|
||||
my_action_batch[j, :] = _cards2array(action)
|
||||
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||
|
||||
last_landlord_action = _cards2array(
|
||||
infoset.last_move_dict['landlord'])
|
||||
infoset.last_move_dict['landlord'], compressed_form)
|
||||
last_landlord_action_batch = np.repeat(
|
||||
last_landlord_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
landlord_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord'], 33)
|
||||
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||
landlord_num_cards_left_batch = np.repeat(
|
||||
landlord_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
landlord_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord'])
|
||||
infoset.played_cards['landlord'], compressed_form)
|
||||
landlord_played_cards_batch = np.repeat(
|
||||
landlord_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_teammate_action = _cards2array(
|
||||
infoset.last_move_dict['landlord_up'])
|
||||
infoset.last_move_dict['landlord_up'], compressed_form)
|
||||
last_teammate_action_batch = np.repeat(
|
||||
last_teammate_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
teammate_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_up'], 25)
|
||||
infoset.num_cards_left_dict['landlord_up'], 25, 8)
|
||||
teammate_num_cards_left_batch = np.repeat(
|
||||
teammate_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
teammate_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_up'])
|
||||
infoset.played_cards['landlord_up'], compressed_form)
|
||||
teammate_played_cards_batch = np.repeat(
|
||||
teammate_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
last_teammate_front_action = _cards2array(
|
||||
infoset.last_move_dict['landlord_front'])
|
||||
infoset.last_move_dict['landlord_front'], compressed_form)
|
||||
last_teammate_front_action_batch = np.repeat(
|
||||
last_teammate_front_action[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
teammate_front_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
||||
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||
teammate_front_num_cards_left_batch = np.repeat(
|
||||
teammate_front_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
teammate_front_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_front'])
|
||||
infoset.played_cards['landlord_front'], compressed_form)
|
||||
teammate_front_played_cards_batch = np.repeat(
|
||||
teammate_front_played_cards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -855,7 +922,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False):
|
|||
teammate_front_num_cards_left,
|
||||
bomb_num))
|
||||
z = _action_seq_list2array(_process_action_seq(
|
||||
infoset.card_play_action_seq, 20, False), False)
|
||||
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||
z_batch = np.repeat(
|
||||
z[np.newaxis, :, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -869,41 +936,41 @@ def _get_obs_landlord_down(infoset, use_legacy = False):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_general(infoset, position):
|
||||
def _get_obs_general(infoset, position, compressed_form = False):
|
||||
num_legal_actions = len(infoset.legal_actions)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
||||
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||
|
||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||
for j, action in enumerate(infoset.legal_actions):
|
||||
my_action_batch[j, :] = _cards2array(action)
|
||||
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||
|
||||
landlord_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord'], 33)
|
||||
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||
|
||||
landlord_up_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_up'], 25)
|
||||
infoset.num_cards_left_dict['landlord_up'], 25, 8)
|
||||
|
||||
landlord_front_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
||||
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||
|
||||
landlord_down_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
||||
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||
|
||||
landlord_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord'])
|
||||
infoset.played_cards['landlord'], compressed_form)
|
||||
|
||||
landlord_up_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_up'])
|
||||
infoset.played_cards['landlord_up'], compressed_form)
|
||||
|
||||
landlord_front_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_front'])
|
||||
infoset.played_cards['landlord_front'], compressed_form)
|
||||
|
||||
landlord_down_played_cards = _cards2array(
|
||||
infoset.played_cards['landlord_down'])
|
||||
infoset.played_cards['landlord_down'], compressed_form)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
|
@ -911,10 +978,10 @@ def _get_obs_general(infoset, position):
|
|||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
num_cards_left = np.hstack((
|
||||
landlord_num_cards_left, # 33
|
||||
landlord_up_num_cards_left, # 25
|
||||
landlord_front_num_cards_left, # 25
|
||||
landlord_down_num_cards_left))
|
||||
landlord_num_cards_left, # 33/18
|
||||
landlord_up_num_cards_left, # 25/17
|
||||
landlord_front_num_cards_left, # 25/17
|
||||
landlord_down_num_cards_left)) # 25/17
|
||||
|
||||
x_batch = np.hstack((
|
||||
bomb_num_batch, # 56
|
||||
|
@ -924,21 +991,24 @@ def _get_obs_general(infoset, position):
|
|||
))
|
||||
|
||||
z =np.vstack((
|
||||
num_cards_left,
|
||||
my_handcards, # 108
|
||||
other_handcards, # 108
|
||||
landlord_played_cards, # 108
|
||||
landlord_up_played_cards, # 108
|
||||
landlord_front_played_cards, # 108
|
||||
landlord_down_played_cards, # 108
|
||||
_action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32))
|
||||
num_cards_left, # 108 / 18+17*3=69
|
||||
my_handcards, # 108/69
|
||||
other_handcards, # 108/69
|
||||
landlord_played_cards, # 108/69
|
||||
landlord_up_played_cards, # 108/69
|
||||
landlord_front_played_cards, # 108/69
|
||||
landlord_down_played_cards, # 108/69
|
||||
_action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32), True, compressed_form)
|
||||
))
|
||||
|
||||
_z_batch = np.repeat(
|
||||
z[np.newaxis, :, :],
|
||||
num_legal_actions, axis=0)
|
||||
my_action_batch = my_action_batch[:,np.newaxis,:]
|
||||
z_batch = np.zeros([len(_z_batch),40,108],int)
|
||||
if compressed_form:
|
||||
z_batch = np.zeros([len(_z_batch), 40, 69], int)
|
||||
else:
|
||||
z_batch = np.zeros([len(_z_batch),40,108],int)
|
||||
for i in range(0,len(_z_batch)):
|
||||
z_batch[i] = np.vstack((my_action_batch[i],_z_batch[i]))
|
||||
obs = {
|
||||
|
|
|
@ -49,6 +49,7 @@ class DeepAgent:
|
|||
|
||||
def __init__(self, position, model_path):
|
||||
self.use_legacy = True if "legacy" in model_path else False
|
||||
self.lite_model = True if "lite" in model_path else False
|
||||
self.model_type = "general" if "resnet" in model_path else "old"
|
||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
||||
|
@ -59,7 +60,7 @@ class DeepAgent:
|
|||
if len(infoset.legal_actions) == 1:
|
||||
return infoset.legal_actions[0]
|
||||
|
||||
obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
|
||||
obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
|
||||
|
||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
|
|
Loading…
Reference in New Issue