新增压缩版模型
This commit is contained in:
parent
d5401cefc9
commit
b5982b7195
|
@ -29,6 +29,8 @@ parser.add_argument('--load_model', action='store_true',
|
||||||
help='Load an existing model')
|
help='Load an existing model')
|
||||||
parser.add_argument('--old_model', action='store_true',
|
parser.add_argument('--old_model', action='store_true',
|
||||||
help='Use vanilla model')
|
help='Use vanilla model')
|
||||||
|
parser.add_argument('--lite_model', action='store_true',
|
||||||
|
help='Use lite card model')
|
||||||
parser.add_argument('--lagacy_model', action='store_true',
|
parser.add_argument('--lagacy_model', action='store_true',
|
||||||
help='Use lagacy bomb model')
|
help='Use lagacy bomb model')
|
||||||
parser.add_argument('--disable_checkpoint', action='store_true',
|
parser.add_argument('--disable_checkpoint', action='store_true',
|
||||||
|
|
|
@ -87,7 +87,7 @@ def train(flags):
|
||||||
if flags.actor_device_cpu:
|
if flags.actor_device_cpu:
|
||||||
device_iterator = ['cpu']
|
device_iterator = ['cpu']
|
||||||
else:
|
else:
|
||||||
device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
|
device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
|
||||||
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
|
||||||
|
|
||||||
# Initialize actor models
|
# Initialize actor models
|
||||||
|
@ -96,7 +96,7 @@ def train(flags):
|
||||||
if flags.old_model:
|
if flags.old_model:
|
||||||
model = OldModel(device="cpu", flags = flags)
|
model = OldModel(device="cpu", flags = flags)
|
||||||
else:
|
else:
|
||||||
model = Model(device="cpu", flags = flags)
|
model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
|
||||||
model.share_memory()
|
model.share_memory()
|
||||||
model.eval()
|
model.eval()
|
||||||
models[device] = model
|
models[device] = model
|
||||||
|
@ -111,7 +111,7 @@ def train(flags):
|
||||||
if flags.old_model:
|
if flags.old_model:
|
||||||
learner_model = OldModel(device=flags.training_device)
|
learner_model = OldModel(device=flags.training_device)
|
||||||
else:
|
else:
|
||||||
learner_model = Model(device=flags.training_device)
|
learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)
|
||||||
|
|
||||||
# Create optimizers
|
# Create optimizers
|
||||||
optimizers = create_optimizers(flags, learner_model)
|
optimizers = create_optimizers(flags, learner_model)
|
||||||
|
|
|
@ -240,6 +240,69 @@ class BasicBlock(nn.Module):
|
||||||
out = F.relu(out)
|
out = F.relu(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralModelLite(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.in_planes = 80
|
||||||
|
#input 1*69*41
|
||||||
|
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
|
||||||
|
stride=(2,), padding=1, bias=False) #1*35*80
|
||||||
|
|
||||||
|
self.bn1 = nn.BatchNorm1d(80)
|
||||||
|
|
||||||
|
self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*18*80
|
||||||
|
self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*9*160
|
||||||
|
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*5*320
|
||||||
|
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||||
|
self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
|
||||||
|
self.linear2 = nn.Linear(1536, 768)
|
||||||
|
self.linear3 = nn.Linear(768, 384)
|
||||||
|
self.linear4 = nn.Linear(384, 1)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
layers = []
|
||||||
|
for stride in strides:
|
||||||
|
layers.append(block(self.in_planes, planes, stride))
|
||||||
|
self.in_planes = planes * block.expansion
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def get_onnx_params(self, device=None):
|
||||||
|
return {
|
||||||
|
'args': (
|
||||||
|
torch.randn(1, 40, 69, requires_grad=True, device=device),
|
||||||
|
torch.randn(1, 56, requires_grad=True, device=device),
|
||||||
|
),
|
||||||
|
'input_names': ['z_batch','x_batch'],
|
||||||
|
'output_names': ['values'],
|
||||||
|
'dynamic_axes': {
|
||||||
|
'z_batch': {
|
||||||
|
0: "batch_size"
|
||||||
|
},
|
||||||
|
'x_batch': {
|
||||||
|
0: "batch_size"
|
||||||
|
},
|
||||||
|
'values': {
|
||||||
|
0: "batch_size"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, z, x):
|
||||||
|
out = F.relu(self.bn1(self.conv1(z)))
|
||||||
|
out = self.layer1(out)
|
||||||
|
out = self.layer2(out)
|
||||||
|
out = self.layer3(out)
|
||||||
|
out = out.flatten(1,2)
|
||||||
|
out = torch.cat([x,out], dim=-1)
|
||||||
|
out = F.leaky_relu_(self.linear1(out))
|
||||||
|
out = F.leaky_relu_(self.linear2(out))
|
||||||
|
out = F.leaky_relu_(self.linear3(out))
|
||||||
|
out = F.leaky_relu_(self.linear4(out))
|
||||||
|
return dict(values=out)
|
||||||
|
|
||||||
|
|
||||||
class GeneralModel(nn.Module):
|
class GeneralModel(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -414,7 +477,7 @@ class Model:
|
||||||
The wrapper for the three models. We also wrap several
|
The wrapper for the three models. We also wrap several
|
||||||
interfaces such as share_memory, eval, etc.
|
interfaces such as share_memory, eval, etc.
|
||||||
"""
|
"""
|
||||||
def __init__(self, device=0, flags=None):
|
def __init__(self, device=0, flags=None, lite_model = False):
|
||||||
self.models = {}
|
self.models = {}
|
||||||
self.onnx_models = {}
|
self.onnx_models = {}
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
|
@ -428,7 +491,10 @@ class Model:
|
||||||
self.models[position] = None
|
self.models[position] = None
|
||||||
else:
|
else:
|
||||||
for position in positions:
|
for position in positions:
|
||||||
self.models[position] = GeneralModel().to(self.device)
|
if lite_model:
|
||||||
|
self.models[position] = GeneralModelLite().to(self.device)
|
||||||
|
else:
|
||||||
|
self.models[position] = GeneralModel().to(self.device)
|
||||||
self.onnx_models = {
|
self.onnx_models = {
|
||||||
'landlord': None,
|
'landlord': None,
|
||||||
'landlord_up': None,
|
'landlord_up': None,
|
||||||
|
|
|
@ -36,6 +36,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
|
||||||
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
|
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
|
||||||
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
|
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
|
||||||
|
|
||||||
|
NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||||
|
1: np.array([1, 0, 0, 0, 0]),
|
||||||
|
2: np.array([1, 1, 0, 0, 0]),
|
||||||
|
3: np.array([1, 1, 1, 0, 0]),
|
||||||
|
4: np.array([1, 1, 1, 1, 0]),
|
||||||
|
5: np.array([1, 0, 0, 0, 1]),
|
||||||
|
6: np.array([1, 1, 0, 0, 1]),
|
||||||
|
7: np.array([1, 1, 1, 0, 1]),
|
||||||
|
8: np.array([1, 1, 1, 1, 1])}
|
||||||
|
|
||||||
|
NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||||
|
1: np.array([1, 0, 0, 0, 0]),
|
||||||
|
3: np.array([1, 1, 0, 0, 0]),
|
||||||
|
4: np.array([0, 0, 1, 0, 0]),
|
||||||
|
5: np.array([1, 0, 1, 0, 0]),
|
||||||
|
7: np.array([1, 1, 1, 0, 0]),
|
||||||
|
12: np.array([0, 0, 1, 1, 0]),
|
||||||
|
13: np.array([1, 0, 1, 1, 0]),
|
||||||
|
15: np.array([1, 1, 1, 1, 0])}
|
||||||
|
|
||||||
shandle = logging.StreamHandler()
|
shandle = logging.StreamHandler()
|
||||||
shandle.setFormatter(
|
shandle.setFormatter(
|
||||||
logging.Formatter(
|
logging.Formatter(
|
||||||
|
@ -51,7 +71,7 @@ log.setLevel(logging.INFO)
|
||||||
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
||||||
|
|
||||||
def create_env(flags):
|
def create_env(flags):
|
||||||
return Env(flags.objective, flags.old_model, flags.lagacy_model)
|
return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model)
|
||||||
|
|
||||||
def get_batch(b_queues, position, flags, lock):
|
def get_batch(b_queues, position, flags, lock):
|
||||||
"""
|
"""
|
||||||
|
@ -133,11 +153,11 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
||||||
else:
|
else:
|
||||||
action = obs['legal_actions'][0]
|
action = obs['legal_actions'][0]
|
||||||
if flags.old_model:
|
if flags.old_model:
|
||||||
obs_action_buf[position].append(_cards2tensor(action))
|
obs_action_buf[position].append(_cards2tensor(action, flags.lite_model))
|
||||||
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
||||||
obs_z_buf[position].append(env_output['obs_z'])
|
obs_z_buf[position].append(env_output['obs_z'])
|
||||||
else:
|
else:
|
||||||
obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float())
|
obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float())
|
||||||
# x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
|
# x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
|
||||||
x_batch = env_output['obs_x_no_action'].float()
|
x_batch = env_output['obs_x_no_action'].float()
|
||||||
obs_x_batch_buf[position].append(x_batch)
|
obs_x_batch_buf[position].append(x_batch)
|
||||||
|
@ -196,32 +216,57 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
|
||||||
print()
|
print()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _cards2tensor(list_cards):
|
def _cards2tensor(list_cards, compress_form = False):
|
||||||
"""
|
"""
|
||||||
Convert a list of integers to the tensor
|
Convert a list of integers to the tensor
|
||||||
representation
|
representation
|
||||||
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
|
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
|
||||||
"""
|
"""
|
||||||
if len(list_cards) == 0:
|
if compress_form:
|
||||||
return torch.zeros(108, dtype=torch.int8)
|
if len(list_cards) == 0:
|
||||||
|
return torch.zeros(69, dtype=torch.int8)
|
||||||
|
|
||||||
matrix = np.zeros([8, 14], dtype=np.int8)
|
matrix = np.zeros([5, 14], dtype=np.int8)
|
||||||
counter = Counter(list_cards)
|
counter = Counter(list_cards)
|
||||||
joker_cnt = 0
|
joker_cnt = 0
|
||||||
for card, num_times in counter.items():
|
for card, num_times in counter.items():
|
||||||
if card < 20:
|
if card < 20:
|
||||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
|
||||||
elif card == 20:
|
elif card == 20:
|
||||||
if num_times == 2:
|
if num_times == 2:
|
||||||
joker_cnt |= 0b11
|
joker_cnt |= 0b11
|
||||||
else:
|
else:
|
||||||
joker_cnt |= 0b01
|
joker_cnt |= 0b01
|
||||||
elif card == 30:
|
elif card == 30:
|
||||||
if num_times == 2:
|
if num_times == 2:
|
||||||
joker_cnt |= 0b1100
|
joker_cnt |= 0b1100
|
||||||
else:
|
else:
|
||||||
joker_cnt |= 0b0100
|
joker_cnt |= 0b0100
|
||||||
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
|
||||||
matrix = matrix.flatten('F')[:-4]
|
matrix = matrix.flatten('F')[:-1]
|
||||||
matrix = torch.from_numpy(matrix)
|
matrix = torch.from_numpy(matrix)
|
||||||
return matrix
|
return matrix
|
||||||
|
else:
|
||||||
|
if len(list_cards) == 0:
|
||||||
|
return torch.zeros(108, dtype=torch.int8)
|
||||||
|
|
||||||
|
matrix = np.zeros([8, 14], dtype=np.int8)
|
||||||
|
counter = Counter(list_cards)
|
||||||
|
joker_cnt = 0
|
||||||
|
for card, num_times in counter.items():
|
||||||
|
if card < 20:
|
||||||
|
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||||
|
elif card == 20:
|
||||||
|
if num_times == 2:
|
||||||
|
joker_cnt |= 0b11
|
||||||
|
else:
|
||||||
|
joker_cnt |= 0b01
|
||||||
|
elif card == 30:
|
||||||
|
if num_times == 2:
|
||||||
|
joker_cnt |= 0b1100
|
||||||
|
else:
|
||||||
|
joker_cnt |= 0b0100
|
||||||
|
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
||||||
|
matrix = matrix.flatten('F')[:-4]
|
||||||
|
matrix = torch.from_numpy(matrix)
|
||||||
|
return matrix
|
||||||
|
|
|
@ -30,6 +30,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
|
||||||
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
|
13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
|
||||||
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
|
15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}
|
||||||
|
|
||||||
|
NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||||
|
1: np.array([1, 0, 0, 0, 0]),
|
||||||
|
2: np.array([1, 1, 0, 0, 0]),
|
||||||
|
3: np.array([1, 1, 1, 0, 0]),
|
||||||
|
4: np.array([1, 1, 1, 1, 0]),
|
||||||
|
5: np.array([1, 0, 0, 0, 1]),
|
||||||
|
6: np.array([1, 1, 0, 0, 1]),
|
||||||
|
7: np.array([1, 1, 1, 0, 1]),
|
||||||
|
8: np.array([1, 1, 1, 1, 1])}
|
||||||
|
|
||||||
|
NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
|
||||||
|
1: np.array([1, 0, 0, 0, 0]),
|
||||||
|
3: np.array([1, 1, 0, 0, 0]),
|
||||||
|
4: np.array([0, 0, 1, 0, 0]),
|
||||||
|
5: np.array([1, 0, 1, 0, 0]),
|
||||||
|
7: np.array([1, 1, 1, 0, 0]),
|
||||||
|
12: np.array([0, 0, 1, 1, 0]),
|
||||||
|
13: np.array([1, 0, 1, 1, 0]),
|
||||||
|
15: np.array([1, 1, 1, 1, 0])}
|
||||||
|
|
||||||
|
|
||||||
deck = []
|
deck = []
|
||||||
for i in range(3, 15):
|
for i in range(3, 15):
|
||||||
|
@ -43,7 +63,7 @@ class Env:
|
||||||
Doudizhu multi-agent wrapper
|
Doudizhu multi-agent wrapper
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, objective, old_model, legacy_model=False):
|
def __init__(self, objective, old_model, legacy_model=False, lite_model = False):
|
||||||
"""
|
"""
|
||||||
Objective is wp/adp/logadp. It indicates whether considers
|
Objective is wp/adp/logadp. It indicates whether considers
|
||||||
bomb in reward calculation. Here, we use dummy agents.
|
bomb in reward calculation. Here, we use dummy agents.
|
||||||
|
@ -58,6 +78,7 @@ class Env:
|
||||||
self.objective = objective
|
self.objective = objective
|
||||||
self.use_legacy = legacy_model
|
self.use_legacy = legacy_model
|
||||||
self.use_general = not old_model
|
self.use_general = not old_model
|
||||||
|
self.lite_model = lite_model
|
||||||
|
|
||||||
# Initialize players
|
# Initialize players
|
||||||
# We use three dummy player for the target position
|
# We use three dummy player for the target position
|
||||||
|
@ -92,7 +113,7 @@ class Env:
|
||||||
card_play_data[key].sort()
|
card_play_data[key].sort()
|
||||||
self._env.card_play_init(card_play_data)
|
self._env.card_play_init(card_play_data)
|
||||||
self.infoset = self._game_infoset
|
self.infoset = self._game_infoset
|
||||||
return get_obs(self.infoset, self.use_general)
|
return get_obs(self.infoset, self.use_general, self.lite_model)
|
||||||
else:
|
else:
|
||||||
self.total_round += 1
|
self.total_round += 1
|
||||||
_deck = deck.copy()
|
_deck = deck.copy()
|
||||||
|
@ -118,7 +139,7 @@ class Env:
|
||||||
self._env.info_sets[pos].player_id = pid
|
self._env.info_sets[pos].player_id = pid
|
||||||
self.infoset = self._game_infoset
|
self.infoset = self._game_infoset
|
||||||
|
|
||||||
return get_obs(self.infoset, self.use_general, self.use_legacy)
|
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""
|
"""
|
||||||
|
@ -146,7 +167,7 @@ class Env:
|
||||||
}
|
}
|
||||||
obs = None
|
obs = None
|
||||||
else:
|
else:
|
||||||
obs = get_obs(self.infoset, self.use_general)
|
obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
|
||||||
return obs, reward, done, {}
|
return obs, reward, done, {}
|
||||||
|
|
||||||
def _get_reward(self, pos):
|
def _get_reward(self, pos):
|
||||||
|
@ -244,7 +265,7 @@ class DummyAgent(object):
|
||||||
self.action = action
|
self.action = action
|
||||||
|
|
||||||
|
|
||||||
def get_obs(infoset, use_general=True, use_legacy = False):
|
def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
|
||||||
"""
|
"""
|
||||||
This function obtains observations with imperfect information
|
This function obtains observations with imperfect information
|
||||||
from the infoset. It has three branches since we encode
|
from the infoset. It has three branches since we encode
|
||||||
|
@ -271,59 +292,92 @@ def get_obs(infoset, use_general=True, use_legacy = False):
|
||||||
if use_general:
|
if use_general:
|
||||||
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
|
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
|
||||||
raise ValueError('')
|
raise ValueError('')
|
||||||
return _get_obs_general(infoset, infoset.player_position)
|
return _get_obs_general(infoset, infoset.player_position, lite_model)
|
||||||
else:
|
else:
|
||||||
if infoset.player_position == 'landlord':
|
if infoset.player_position == 'landlord':
|
||||||
return _get_obs_landlord(infoset, use_legacy)
|
return _get_obs_landlord(infoset, use_legacy, lite_model)
|
||||||
elif infoset.player_position == 'landlord_up':
|
elif infoset.player_position == 'landlord_up':
|
||||||
return _get_obs_landlord_up(infoset, use_legacy)
|
return _get_obs_landlord_up(infoset, use_legacy, lite_model)
|
||||||
elif infoset.player_position == 'landlord_front':
|
elif infoset.player_position == 'landlord_front':
|
||||||
return _get_obs_landlord_front(infoset, use_legacy)
|
return _get_obs_landlord_front(infoset, use_legacy, lite_model)
|
||||||
elif infoset.player_position == 'landlord_down':
|
elif infoset.player_position == 'landlord_down':
|
||||||
return _get_obs_landlord_down(infoset, use_legacy)
|
return _get_obs_landlord_down(infoset, use_legacy, lite_model)
|
||||||
else:
|
else:
|
||||||
raise ValueError('')
|
raise ValueError('')
|
||||||
|
|
||||||
|
|
||||||
def _get_one_hot_array(num_left_cards, max_num_cards):
|
def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
|
||||||
"""
|
"""
|
||||||
A utility function to obtain one-hot endoding
|
A utility function to obtain one-hot endoding
|
||||||
"""
|
"""
|
||||||
one_hot = np.zeros(max_num_cards)
|
if compress_size > 0:
|
||||||
if num_left_cards > 0:
|
assert compress_size <= max_num_cards / 2
|
||||||
one_hot[num_left_cards - 1] = 1
|
array_size = max_num_cards - compress_size
|
||||||
|
one_hot = np.zeros(array_size)
|
||||||
|
if num_left_cards >= array_size:
|
||||||
|
one_hot[-1] = 1
|
||||||
|
num_left_cards -= array_size
|
||||||
|
if num_left_cards > 0:
|
||||||
|
one_hot[num_left_cards - 1] = 1
|
||||||
|
else:
|
||||||
|
one_hot = np.zeros(max_num_cards)
|
||||||
|
if num_left_cards > 0:
|
||||||
|
one_hot[num_left_cards - 1] = 1
|
||||||
|
|
||||||
return one_hot
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def _cards2array(list_cards):
|
def _cards2array(list_cards, compressed_form = False):
|
||||||
"""
|
"""
|
||||||
A utility function that transforms the actions, i.e.,
|
A utility function that transforms the actions, i.e.,
|
||||||
A list of integers into card matrix. Here we remove
|
A list of integers into card matrix. Here we remove
|
||||||
the six entries that are always zero and flatten the
|
the six entries that are always zero and flatten the
|
||||||
the representations.
|
the representations.
|
||||||
"""
|
"""
|
||||||
if len(list_cards) == 0:
|
if compressed_form:
|
||||||
return np.zeros(108, dtype=np.int8)
|
if len(list_cards) == 0:
|
||||||
|
return np.zeros(69, dtype=np.int8)
|
||||||
|
|
||||||
matrix = np.zeros([8, 14], dtype=np.int8)
|
matrix = np.zeros([5, 14], dtype=np.int8)
|
||||||
counter = Counter(list_cards)
|
counter = Counter(list_cards)
|
||||||
joker_cnt = 0
|
joker_cnt = 0
|
||||||
for card, num_times in counter.items():
|
for card, num_times in counter.items():
|
||||||
if card < 20:
|
if card < 20:
|
||||||
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
|
||||||
elif card == 20:
|
elif card == 20:
|
||||||
if num_times == 2:
|
if num_times == 2:
|
||||||
joker_cnt |= 0b11
|
joker_cnt |= 0b11
|
||||||
else:
|
else:
|
||||||
joker_cnt |= 0b01
|
joker_cnt |= 0b01
|
||||||
elif card == 30:
|
elif card == 30:
|
||||||
if num_times == 2:
|
if num_times == 2:
|
||||||
joker_cnt |= 0b1100
|
joker_cnt |= 0b1100
|
||||||
else:
|
else:
|
||||||
joker_cnt |= 0b0100
|
joker_cnt |= 0b0100
|
||||||
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
|
||||||
return matrix.flatten('F')[:-4]
|
return matrix.flatten('F')[:-1]
|
||||||
|
else:
|
||||||
|
if len(list_cards) == 0:
|
||||||
|
return np.zeros(108, dtype=np.int8)
|
||||||
|
|
||||||
|
matrix = np.zeros([8, 14], dtype=np.int8)
|
||||||
|
counter = Counter(list_cards)
|
||||||
|
joker_cnt = 0
|
||||||
|
for card, num_times in counter.items():
|
||||||
|
if card < 20:
|
||||||
|
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
||||||
|
elif card == 20:
|
||||||
|
if num_times == 2:
|
||||||
|
joker_cnt |= 0b11
|
||||||
|
else:
|
||||||
|
joker_cnt |= 0b01
|
||||||
|
elif card == 30:
|
||||||
|
if num_times == 2:
|
||||||
|
joker_cnt |= 0b1100
|
||||||
|
else:
|
||||||
|
joker_cnt |= 0b0100
|
||||||
|
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
||||||
|
return matrix.flatten('F')[:-4]
|
||||||
|
|
||||||
|
|
||||||
# def _action_seq_list2array(action_seq_list):
|
# def _action_seq_list2array(action_seq_list):
|
||||||
|
@ -342,7 +396,7 @@ def _cards2array(list_cards):
|
||||||
# # action_seq_array = action_seq_array.reshape(5, 162)
|
# # action_seq_array = action_seq_array.reshape(5, 162)
|
||||||
# return action_seq_array
|
# return action_seq_array
|
||||||
|
|
||||||
def _action_seq_list2array(action_seq_list, new_model=True):
|
def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False):
|
||||||
"""
|
"""
|
||||||
A utility function to encode the historical moves.
|
A utility function to encode the historical moves.
|
||||||
We encode the historical 20 actions. If there is
|
We encode the historical 20 actions. If there is
|
||||||
|
@ -355,16 +409,29 @@ def _action_seq_list2array(action_seq_list, new_model=True):
|
||||||
|
|
||||||
if new_model:
|
if new_model:
|
||||||
# position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
|
# position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
|
||||||
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
|
if compressed_form:
|
||||||
for row, list_cards in enumerate(action_seq_list):
|
action_seq_array = np.full((len(action_seq_list), 69), -1) # Default Value -1 for not using area
|
||||||
if list_cards != []:
|
for row, list_cards in enumerate(action_seq_list):
|
||||||
action_seq_array[row, :108] = _cards2array(list_cards[1])
|
if list_cards != []:
|
||||||
|
action_seq_array[row, :69] = _cards2array(list_cards[1], compressed_form)
|
||||||
|
else:
|
||||||
|
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
|
||||||
|
for row, list_cards in enumerate(action_seq_list):
|
||||||
|
if list_cards != []:
|
||||||
|
action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
|
||||||
else:
|
else:
|
||||||
action_seq_array = np.zeros((len(action_seq_list), 108))
|
if compressed_form:
|
||||||
for row, list_cards in enumerate(action_seq_list):
|
action_seq_array = np.zeros((len(action_seq_list), 69))
|
||||||
if list_cards != []:
|
for row, list_cards in enumerate(action_seq_list):
|
||||||
action_seq_array[row, :] = _cards2array(list_cards[1])
|
if list_cards != []:
|
||||||
action_seq_array = action_seq_array.reshape(5, 432)
|
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
|
||||||
|
action_seq_array = action_seq_array.reshape(5, 276)
|
||||||
|
else:
|
||||||
|
action_seq_array = np.zeros((len(action_seq_list), 108))
|
||||||
|
for row, list_cards in enumerate(action_seq_list):
|
||||||
|
if list_cards != []:
|
||||||
|
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
|
||||||
|
action_seq_array = action_seq_array.reshape(5, 432)
|
||||||
return action_seq_array
|
return action_seq_array
|
||||||
|
|
||||||
# action_seq_array = np.zeros((len(action_seq_list), 54))
|
# action_seq_array = np.zeros((len(action_seq_list), 54))
|
||||||
|
@ -406,60 +473,60 @@ def _get_one_hot_bomb(bomb_num, use_legacy = False):
|
||||||
return one_hot
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def _get_obs_landlord(infoset, use_legacy = False):
|
def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord features. See Table 4 in
|
Obttain the landlord features. See Table 4 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
"""
|
"""
|
||||||
num_legal_actions = len(infoset.legal_actions)
|
num_legal_actions = len(infoset.legal_actions)
|
||||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_action = _cards2array(infoset.last_move)
|
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||||
for j, action in enumerate(infoset.legal_actions):
|
for j, action in enumerate(infoset.legal_actions):
|
||||||
my_action_batch[j, :] = _cards2array(action)
|
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||||
|
|
||||||
landlord_up_num_cards_left = _get_one_hot_array(
|
landlord_up_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_up'], 25)
|
infoset.num_cards_left_dict['landlord_up'], 25, 15)
|
||||||
landlord_up_num_cards_left_batch = np.repeat(
|
landlord_up_num_cards_left_batch = np.repeat(
|
||||||
landlord_up_num_cards_left[np.newaxis, :],
|
landlord_up_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_front_num_cards_left = _get_one_hot_array(
|
landlord_front_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||||
landlord_front_num_cards_left_batch = np.repeat(
|
landlord_front_num_cards_left_batch = np.repeat(
|
||||||
landlord_front_num_cards_left[np.newaxis, :],
|
landlord_front_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_down_num_cards_left = _get_one_hot_array(
|
landlord_down_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||||
landlord_down_num_cards_left_batch = np.repeat(
|
landlord_down_num_cards_left_batch = np.repeat(
|
||||||
landlord_down_num_cards_left[np.newaxis, :],
|
landlord_down_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_up_played_cards = _cards2array(
|
landlord_up_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_up'])
|
infoset.played_cards['landlord_up'], 8)
|
||||||
landlord_up_played_cards_batch = np.repeat(
|
landlord_up_played_cards_batch = np.repeat(
|
||||||
landlord_up_played_cards[np.newaxis, :],
|
landlord_up_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_front_played_cards = _cards2array(
|
landlord_front_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_front'])
|
infoset.played_cards['landlord_front'], compressed_form)
|
||||||
landlord_front_played_cards_batch = np.repeat(
|
landlord_front_played_cards_batch = np.repeat(
|
||||||
landlord_front_played_cards[np.newaxis, :],
|
landlord_front_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_down_played_cards = _cards2array(
|
landlord_down_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_down'])
|
infoset.played_cards['landlord_down'], compressed_form)
|
||||||
landlord_down_played_cards_batch = np.repeat(
|
landlord_down_played_cards_batch = np.repeat(
|
||||||
landlord_down_played_cards[np.newaxis, :],
|
landlord_down_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -492,7 +559,7 @@ def _get_obs_landlord(infoset, use_legacy = False):
|
||||||
landlord_down_num_cards_left,
|
landlord_down_num_cards_left,
|
||||||
bomb_num))
|
bomb_num))
|
||||||
z = _action_seq_list2array(_process_action_seq(
|
z = _action_seq_list2array(_process_action_seq(
|
||||||
infoset.card_play_action_seq, 20, False), False)
|
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||||
z_batch = np.repeat(
|
z_batch = np.repeat(
|
||||||
z[np.newaxis, :, :],
|
z[np.newaxis, :, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -506,75 +573,75 @@ def _get_obs_landlord(infoset, use_legacy = False):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_landlord_up(infoset, use_legacy = False):
|
def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord_up features. See Table 5 in
|
Obttain the landlord_up features. See Table 5 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
"""
|
"""
|
||||||
num_legal_actions = len(infoset.legal_actions)
|
num_legal_actions = len(infoset.legal_actions)
|
||||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_action = _cards2array(infoset.last_move)
|
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||||
for j, action in enumerate(infoset.legal_actions):
|
for j, action in enumerate(infoset.legal_actions):
|
||||||
my_action_batch[j, :] = _cards2array(action)
|
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||||
|
|
||||||
last_landlord_action = _cards2array(
|
last_landlord_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord'])
|
infoset.last_move_dict['landlord'], compressed_form)
|
||||||
last_landlord_action_batch = np.repeat(
|
last_landlord_action_batch = np.repeat(
|
||||||
last_landlord_action[np.newaxis, :],
|
last_landlord_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
landlord_num_cards_left = _get_one_hot_array(
|
landlord_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord'], 33)
|
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||||
landlord_num_cards_left_batch = np.repeat(
|
landlord_num_cards_left_batch = np.repeat(
|
||||||
landlord_num_cards_left[np.newaxis, :],
|
landlord_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_played_cards = _cards2array(
|
landlord_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord'])
|
infoset.played_cards['landlord'], compressed_form)
|
||||||
landlord_played_cards_batch = np.repeat(
|
landlord_played_cards_batch = np.repeat(
|
||||||
landlord_played_cards[np.newaxis, :],
|
landlord_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_teammate_action = _cards2array(
|
last_teammate_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord_down'])
|
infoset.last_move_dict['landlord_down'], compressed_form)
|
||||||
last_teammate_action_batch = np.repeat(
|
last_teammate_action_batch = np.repeat(
|
||||||
last_teammate_action[np.newaxis, :],
|
last_teammate_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
teammate_num_cards_left = _get_one_hot_array(
|
teammate_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||||
teammate_num_cards_left_batch = np.repeat(
|
teammate_num_cards_left_batch = np.repeat(
|
||||||
teammate_num_cards_left[np.newaxis, :],
|
teammate_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
teammate_played_cards = _cards2array(
|
teammate_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_down'])
|
infoset.played_cards['landlord_down'], compressed_form)
|
||||||
teammate_played_cards_batch = np.repeat(
|
teammate_played_cards_batch = np.repeat(
|
||||||
teammate_played_cards[np.newaxis, :],
|
teammate_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_teammate_front_action = _cards2array(
|
last_teammate_front_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord_front'])
|
infoset.last_move_dict['landlord_front'], compressed_form)
|
||||||
last_teammate_front_action_batch = np.repeat(
|
last_teammate_front_action_batch = np.repeat(
|
||||||
last_teammate_front_action[np.newaxis, :],
|
last_teammate_front_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
teammate_front_num_cards_left = _get_one_hot_array(
|
teammate_front_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||||
teammate_front_num_cards_left_batch = np.repeat(
|
teammate_front_num_cards_left_batch = np.repeat(
|
||||||
teammate_front_num_cards_left[np.newaxis, :],
|
teammate_front_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
teammate_front_played_cards = _cards2array(
|
teammate_front_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_front'])
|
infoset.played_cards['landlord_front'], compressed_form)
|
||||||
teammate_front_played_cards_batch = np.repeat(
|
teammate_front_played_cards_batch = np.repeat(
|
||||||
teammate_front_played_cards[np.newaxis, :],
|
teammate_front_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -613,7 +680,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False):
|
||||||
teammate_front_num_cards_left,
|
teammate_front_num_cards_left,
|
||||||
bomb_num))
|
bomb_num))
|
||||||
z = _action_seq_list2array(_process_action_seq(
|
z = _action_seq_list2array(_process_action_seq(
|
||||||
infoset.card_play_action_seq, 20, False), False)
|
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||||
z_batch = np.repeat(
|
z_batch = np.repeat(
|
||||||
z[np.newaxis, :, :],
|
z[np.newaxis, :, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -627,75 +694,75 @@ def _get_obs_landlord_up(infoset, use_legacy = False):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_landlord_front(infoset, use_legacy = False):
|
def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord_front features. See Table 5 in
|
Obttain the landlord_front features. See Table 5 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
"""
|
"""
|
||||||
num_legal_actions = len(infoset.legal_actions)
|
num_legal_actions = len(infoset.legal_actions)
|
||||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_action = _cards2array(infoset.last_move)
|
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||||
for j, action in enumerate(infoset.legal_actions):
|
for j, action in enumerate(infoset.legal_actions):
|
||||||
my_action_batch[j, :] = _cards2array(action)
|
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||||
|
|
||||||
last_landlord_action = _cards2array(
|
last_landlord_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord'])
|
infoset.last_move_dict['landlord'], compressed_form)
|
||||||
last_landlord_action_batch = np.repeat(
|
last_landlord_action_batch = np.repeat(
|
||||||
last_landlord_action[np.newaxis, :],
|
last_landlord_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
landlord_num_cards_left = _get_one_hot_array(
|
landlord_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord'], 33)
|
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||||
landlord_num_cards_left_batch = np.repeat(
|
landlord_num_cards_left_batch = np.repeat(
|
||||||
landlord_num_cards_left[np.newaxis, :],
|
landlord_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_played_cards = _cards2array(
|
landlord_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord'])
|
infoset.played_cards['landlord'], compressed_form)
|
||||||
landlord_played_cards_batch = np.repeat(
|
landlord_played_cards_batch = np.repeat(
|
||||||
landlord_played_cards[np.newaxis, :],
|
landlord_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_teammate_action = _cards2array(
|
last_teammate_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord_down'])
|
infoset.last_move_dict['landlord_down'], compressed_form)
|
||||||
last_teammate_action_batch = np.repeat(
|
last_teammate_action_batch = np.repeat(
|
||||||
last_teammate_action[np.newaxis, :],
|
last_teammate_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
teammate_num_cards_left = _get_one_hot_array(
|
teammate_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||||
teammate_num_cards_left_batch = np.repeat(
|
teammate_num_cards_left_batch = np.repeat(
|
||||||
teammate_num_cards_left[np.newaxis, :],
|
teammate_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
teammate_played_cards = _cards2array(
|
teammate_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_down'])
|
infoset.played_cards['landlord_down'], compressed_form)
|
||||||
teammate_played_cards_batch = np.repeat(
|
teammate_played_cards_batch = np.repeat(
|
||||||
teammate_played_cards[np.newaxis, :],
|
teammate_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_teammate_front_action = _cards2array(
|
last_teammate_front_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord_front'])
|
infoset.last_move_dict['landlord_front'], compressed_form)
|
||||||
last_teammate_front_action_batch = np.repeat(
|
last_teammate_front_action_batch = np.repeat(
|
||||||
last_teammate_front_action[np.newaxis, :],
|
last_teammate_front_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
teammate_front_num_cards_left = _get_one_hot_array(
|
teammate_front_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||||
teammate_front_num_cards_left_batch = np.repeat(
|
teammate_front_num_cards_left_batch = np.repeat(
|
||||||
teammate_front_num_cards_left[np.newaxis, :],
|
teammate_front_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
teammate_front_played_cards = _cards2array(
|
teammate_front_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_front'])
|
infoset.played_cards['landlord_front'], compressed_form)
|
||||||
teammate_front_played_cards_batch = np.repeat(
|
teammate_front_played_cards_batch = np.repeat(
|
||||||
teammate_played_cards[np.newaxis, :],
|
teammate_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -734,7 +801,7 @@ def _get_obs_landlord_front(infoset, use_legacy = False):
|
||||||
teammate_front_num_cards_left,
|
teammate_front_num_cards_left,
|
||||||
bomb_num))
|
bomb_num))
|
||||||
z = _action_seq_list2array(_process_action_seq(
|
z = _action_seq_list2array(_process_action_seq(
|
||||||
infoset.card_play_action_seq, 20, False), False)
|
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||||
z_batch = np.repeat(
|
z_batch = np.repeat(
|
||||||
z[np.newaxis, :, :],
|
z[np.newaxis, :, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -748,75 +815,75 @@ def _get_obs_landlord_front(infoset, use_legacy = False):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_landlord_down(infoset, use_legacy = False):
|
def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord_down features. See Table 5 in
|
Obttain the landlord_down features. See Table 5 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
"""
|
"""
|
||||||
num_legal_actions = len(infoset.legal_actions)
|
num_legal_actions = len(infoset.legal_actions)
|
||||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||||
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_action = _cards2array(infoset.last_move)
|
last_action = _cards2array(infoset.last_move, compressed_form)
|
||||||
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
last_action_batch = np.repeat(last_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||||
for j, action in enumerate(infoset.legal_actions):
|
for j, action in enumerate(infoset.legal_actions):
|
||||||
my_action_batch[j, :] = _cards2array(action)
|
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||||
|
|
||||||
last_landlord_action = _cards2array(
|
last_landlord_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord'])
|
infoset.last_move_dict['landlord'], compressed_form)
|
||||||
last_landlord_action_batch = np.repeat(
|
last_landlord_action_batch = np.repeat(
|
||||||
last_landlord_action[np.newaxis, :],
|
last_landlord_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
landlord_num_cards_left = _get_one_hot_array(
|
landlord_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord'], 33)
|
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||||
landlord_num_cards_left_batch = np.repeat(
|
landlord_num_cards_left_batch = np.repeat(
|
||||||
landlord_num_cards_left[np.newaxis, :],
|
landlord_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
landlord_played_cards = _cards2array(
|
landlord_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord'])
|
infoset.played_cards['landlord'], compressed_form)
|
||||||
landlord_played_cards_batch = np.repeat(
|
landlord_played_cards_batch = np.repeat(
|
||||||
landlord_played_cards[np.newaxis, :],
|
landlord_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_teammate_action = _cards2array(
|
last_teammate_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord_up'])
|
infoset.last_move_dict['landlord_up'], compressed_form)
|
||||||
last_teammate_action_batch = np.repeat(
|
last_teammate_action_batch = np.repeat(
|
||||||
last_teammate_action[np.newaxis, :],
|
last_teammate_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
teammate_num_cards_left = _get_one_hot_array(
|
teammate_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_up'], 25)
|
infoset.num_cards_left_dict['landlord_up'], 25, 8)
|
||||||
teammate_num_cards_left_batch = np.repeat(
|
teammate_num_cards_left_batch = np.repeat(
|
||||||
teammate_num_cards_left[np.newaxis, :],
|
teammate_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
teammate_played_cards = _cards2array(
|
teammate_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_up'])
|
infoset.played_cards['landlord_up'], compressed_form)
|
||||||
teammate_played_cards_batch = np.repeat(
|
teammate_played_cards_batch = np.repeat(
|
||||||
teammate_played_cards[np.newaxis, :],
|
teammate_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
last_teammate_front_action = _cards2array(
|
last_teammate_front_action = _cards2array(
|
||||||
infoset.last_move_dict['landlord_front'])
|
infoset.last_move_dict['landlord_front'], compressed_form)
|
||||||
last_teammate_front_action_batch = np.repeat(
|
last_teammate_front_action_batch = np.repeat(
|
||||||
last_teammate_front_action[np.newaxis, :],
|
last_teammate_front_action[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
teammate_front_num_cards_left = _get_one_hot_array(
|
teammate_front_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||||
teammate_front_num_cards_left_batch = np.repeat(
|
teammate_front_num_cards_left_batch = np.repeat(
|
||||||
teammate_front_num_cards_left[np.newaxis, :],
|
teammate_front_num_cards_left[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
teammate_front_played_cards = _cards2array(
|
teammate_front_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_front'])
|
infoset.played_cards['landlord_front'], compressed_form)
|
||||||
teammate_front_played_cards_batch = np.repeat(
|
teammate_front_played_cards_batch = np.repeat(
|
||||||
teammate_front_played_cards[np.newaxis, :],
|
teammate_front_played_cards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -855,7 +922,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False):
|
||||||
teammate_front_num_cards_left,
|
teammate_front_num_cards_left,
|
||||||
bomb_num))
|
bomb_num))
|
||||||
z = _action_seq_list2array(_process_action_seq(
|
z = _action_seq_list2array(_process_action_seq(
|
||||||
infoset.card_play_action_seq, 20, False), False)
|
infoset.card_play_action_seq, 20, False), False, compressed_form)
|
||||||
z_batch = np.repeat(
|
z_batch = np.repeat(
|
||||||
z[np.newaxis, :, :],
|
z[np.newaxis, :, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -869,41 +936,41 @@ def _get_obs_landlord_down(infoset, use_legacy = False):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_general(infoset, position):
|
def _get_obs_general(infoset, position, compressed_form = False):
|
||||||
num_legal_actions = len(infoset.legal_actions)
|
num_legal_actions = len(infoset.legal_actions)
|
||||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
|
||||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
other_handcards = _cards2array(infoset.other_hand_cards)
|
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
|
||||||
|
|
||||||
my_action_batch = np.zeros(my_handcards_batch.shape)
|
my_action_batch = np.zeros(my_handcards_batch.shape)
|
||||||
for j, action in enumerate(infoset.legal_actions):
|
for j, action in enumerate(infoset.legal_actions):
|
||||||
my_action_batch[j, :] = _cards2array(action)
|
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||||
|
|
||||||
landlord_num_cards_left = _get_one_hot_array(
|
landlord_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord'], 33)
|
infoset.num_cards_left_dict['landlord'], 33, 15)
|
||||||
|
|
||||||
landlord_up_num_cards_left = _get_one_hot_array(
|
landlord_up_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_up'], 25)
|
infoset.num_cards_left_dict['landlord_up'], 25, 8)
|
||||||
|
|
||||||
landlord_front_num_cards_left = _get_one_hot_array(
|
landlord_front_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_front'], 25)
|
infoset.num_cards_left_dict['landlord_front'], 25, 8)
|
||||||
|
|
||||||
landlord_down_num_cards_left = _get_one_hot_array(
|
landlord_down_num_cards_left = _get_one_hot_array(
|
||||||
infoset.num_cards_left_dict['landlord_down'], 25)
|
infoset.num_cards_left_dict['landlord_down'], 25, 8)
|
||||||
|
|
||||||
landlord_played_cards = _cards2array(
|
landlord_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord'])
|
infoset.played_cards['landlord'], compressed_form)
|
||||||
|
|
||||||
landlord_up_played_cards = _cards2array(
|
landlord_up_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_up'])
|
infoset.played_cards['landlord_up'], compressed_form)
|
||||||
|
|
||||||
landlord_front_played_cards = _cards2array(
|
landlord_front_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_front'])
|
infoset.played_cards['landlord_front'], compressed_form)
|
||||||
|
|
||||||
landlord_down_played_cards = _cards2array(
|
landlord_down_played_cards = _cards2array(
|
||||||
infoset.played_cards['landlord_down'])
|
infoset.played_cards['landlord_down'], compressed_form)
|
||||||
|
|
||||||
bomb_num = _get_one_hot_bomb(
|
bomb_num = _get_one_hot_bomb(
|
||||||
infoset.bomb_num)
|
infoset.bomb_num)
|
||||||
|
@ -911,10 +978,10 @@ def _get_obs_general(infoset, position):
|
||||||
bomb_num[np.newaxis, :],
|
bomb_num[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
num_cards_left = np.hstack((
|
num_cards_left = np.hstack((
|
||||||
landlord_num_cards_left, # 33
|
landlord_num_cards_left, # 33/18
|
||||||
landlord_up_num_cards_left, # 25
|
landlord_up_num_cards_left, # 25/17
|
||||||
landlord_front_num_cards_left, # 25
|
landlord_front_num_cards_left, # 25/17
|
||||||
landlord_down_num_cards_left))
|
landlord_down_num_cards_left)) # 25/17
|
||||||
|
|
||||||
x_batch = np.hstack((
|
x_batch = np.hstack((
|
||||||
bomb_num_batch, # 56
|
bomb_num_batch, # 56
|
||||||
|
@ -924,21 +991,24 @@ def _get_obs_general(infoset, position):
|
||||||
))
|
))
|
||||||
|
|
||||||
z =np.vstack((
|
z =np.vstack((
|
||||||
num_cards_left,
|
num_cards_left, # 108 / 18+17*3=69
|
||||||
my_handcards, # 108
|
my_handcards, # 108/69
|
||||||
other_handcards, # 108
|
other_handcards, # 108/69
|
||||||
landlord_played_cards, # 108
|
landlord_played_cards, # 108/69
|
||||||
landlord_up_played_cards, # 108
|
landlord_up_played_cards, # 108/69
|
||||||
landlord_front_played_cards, # 108
|
landlord_front_played_cards, # 108/69
|
||||||
landlord_down_played_cards, # 108
|
landlord_down_played_cards, # 108/69
|
||||||
_action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32))
|
_action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32), True, compressed_form)
|
||||||
))
|
))
|
||||||
|
|
||||||
_z_batch = np.repeat(
|
_z_batch = np.repeat(
|
||||||
z[np.newaxis, :, :],
|
z[np.newaxis, :, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
my_action_batch = my_action_batch[:,np.newaxis,:]
|
my_action_batch = my_action_batch[:,np.newaxis,:]
|
||||||
z_batch = np.zeros([len(_z_batch),40,108],int)
|
if compressed_form:
|
||||||
|
z_batch = np.zeros([len(_z_batch), 40, 69], int)
|
||||||
|
else:
|
||||||
|
z_batch = np.zeros([len(_z_batch),40,108],int)
|
||||||
for i in range(0,len(_z_batch)):
|
for i in range(0,len(_z_batch)):
|
||||||
z_batch[i] = np.vstack((my_action_batch[i],_z_batch[i]))
|
z_batch[i] = np.vstack((my_action_batch[i],_z_batch[i]))
|
||||||
obs = {
|
obs = {
|
||||||
|
|
|
@ -49,6 +49,7 @@ class DeepAgent:
|
||||||
|
|
||||||
def __init__(self, position, model_path):
|
def __init__(self, position, model_path):
|
||||||
self.use_legacy = True if "legacy" in model_path else False
|
self.use_legacy = True if "legacy" in model_path else False
|
||||||
|
self.lite_model = True if "lite" in model_path else False
|
||||||
self.model_type = "general" if "resnet" in model_path else "old"
|
self.model_type = "general" if "resnet" in model_path else "old"
|
||||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
||||||
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
||||||
|
@ -59,7 +60,7 @@ class DeepAgent:
|
||||||
if len(infoset.legal_actions) == 1:
|
if len(infoset.legal_actions) == 1:
|
||||||
return infoset.legal_actions[0]
|
return infoset.legal_actions[0]
|
||||||
|
|
||||||
obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
|
obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
|
||||||
|
|
||||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||||
|
|
Loading…
Reference in New Issue