From b5982b719569533a3e37e2b18079172185d53507 Mon Sep 17 00:00:00 2001
From: ZaneYork
Date: Wed, 22 Dec 2021 21:19:10 +0800
Subject: [PATCH] Add compressed (lite) model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 douzero/dmc/arguments.py         |   2 +
 douzero/dmc/dmc.py               |   6 +-
 douzero/dmc/models.py            |  70 ++++++-
 douzero/dmc/utils.py             |  97 ++++++---
 douzero/env/env.py               | 324 +++++++++++++++++++------
 douzero/evaluation/deep_agent.py |   3 +-
 6 files changed, 343 insertions(+), 159 deletions(-)

diff --git a/douzero/dmc/arguments.py b/douzero/dmc/arguments.py
index ccddcdf..b7b94e6 100644
--- a/douzero/dmc/arguments.py
+++ b/douzero/dmc/arguments.py
@@ -29,6 +29,8 @@
 parser.add_argument('--load_model', action='store_true',
                     help='Load an existing model')
 parser.add_argument('--old_model', action='store_true',
                     help='Use vanilla model')
+parser.add_argument('--lite_model', action='store_true',
+                    help='Use lite card model')
 parser.add_argument('--lagacy_model', action='store_true',
                     help='Use legacy bomb model')
 parser.add_argument('--disable_checkpoint', action='store_true',
diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index 858f622..9212eee 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -87,7 +87,7 @@ def train(flags):
     if flags.actor_device_cpu:
         device_iterator = ['cpu']
     else:
-        device_iterator = [0, 'cpu'] #range(flags.num_actor_devices) #[0, 'cpu']
+        device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
         assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'

     # Initialize actor models
@@ -96,7 +96,7 @@
         if flags.old_model:
             model = OldModel(device="cpu", flags = flags)
         else:
-            model = Model(device="cpu", flags = flags)
+            model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
         model.share_memory()
         model.eval()
         models[device] = model
@@ -111,7 +111,7 @@
     if flags.old_model:
         learner_model = OldModel(device=flags.training_device)
     else:
-        learner_model = Model(device=flags.training_device)
+        learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)

     # Create optimizers
     optimizers = create_optimizers(flags, learner_model)
diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py
index 8707bd2..77bdd82 100644
--- a/douzero/dmc/models.py
+++ b/douzero/dmc/models.py
@@ -240,6 +240,69 @@ class BasicBlock(nn.Module):
         out = F.relu(out)
         return out

+
+class GeneralModelLite(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.in_planes = 80
+        # input 1*69*40
+        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
+                               stride=(2,), padding=1, bias=False)  #1*35*80
+
+        self.bn1 = nn.BatchNorm1d(80)
+
+        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)  #1*18*80
+        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  #1*9*160
+        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  #1*5*320
+        # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
+        self.linear2 = nn.Linear(1536, 768)
+        self.linear3 = nn.Linear(768, 384)
+        self.linear4 = nn.Linear(384, 1)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return 
nn.Sequential(*layers) + + def get_onnx_params(self, device=None): + return { + 'args': ( + torch.randn(1, 40, 69, requires_grad=True, device=device), + torch.randn(1, 56, requires_grad=True, device=device), + ), + 'input_names': ['z_batch','x_batch'], + 'output_names': ['values'], + 'dynamic_axes': { + 'z_batch': { + 0: "batch_size" + }, + 'x_batch': { + 0: "batch_size" + }, + 'values': { + 0: "batch_size" + } + } + } + + def forward(self, z, x): + out = F.relu(self.bn1(self.conv1(z))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = out.flatten(1,2) + out = torch.cat([x,out], dim=-1) + out = F.leaky_relu_(self.linear1(out)) + out = F.leaky_relu_(self.linear2(out)) + out = F.leaky_relu_(self.linear3(out)) + out = F.leaky_relu_(self.linear4(out)) + return dict(values=out) + + class GeneralModel(nn.Module): def __init__(self): super().__init__() @@ -414,7 +477,7 @@ class Model: The wrapper for the three models. We also wrap several interfaces such as share_memory, eval, etc. """ - def __init__(self, device=0, flags=None): + def __init__(self, device=0, flags=None, lite_model = False): self.models = {} self.onnx_models = {} self.flags = flags @@ -428,7 +491,10 @@ class Model: self.models[position] = None else: for position in positions: - self.models[position] = GeneralModel().to(self.device) + if lite_model: + self.models[position] = GeneralModelLite().to(self.device) + else: + self.models[position] = GeneralModel().to(self.device) self.onnx_models = { 'landlord': None, 'landlord_up': None, diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index 7df335d..36204b7 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -36,6 +36,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]), 13: np.array([1, 0, 1, 1, 0, 0, 0, 0]), 15: np.array([1, 1, 1, 1, 0, 0, 0, 0])} +NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]), + 1: np.array([1, 0, 0, 0, 0]), + 2: np.array([1, 1, 0, 0, 0]), + 3: np.array([1, 1, 1, 0, 0]), + 4: np.array([1, 1, 1, 1, 0]), + 5: np.array([1, 0, 0, 0, 1]), + 6: np.array([1, 1, 0, 0, 1]), + 7: np.array([1, 1, 1, 0, 1]), + 8: np.array([1, 1, 1, 1, 1])} + +NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]), + 1: np.array([1, 0, 0, 0, 0]), + 3: np.array([1, 1, 0, 0, 0]), + 4: np.array([0, 0, 1, 0, 0]), + 5: np.array([1, 0, 1, 0, 0]), + 7: np.array([1, 1, 1, 0, 0]), + 12: np.array([0, 0, 1, 1, 0]), + 13: np.array([1, 0, 1, 1, 0]), + 15: np.array([1, 1, 1, 1, 0])} + shandle = logging.StreamHandler() shandle.setFormatter( logging.Formatter( @@ -51,7 +71,7 @@ log.setLevel(logging.INFO) Buffers = typing.Dict[str, typing.List[torch.Tensor]] def create_env(flags): - return Env(flags.objective, flags.old_model, flags.lagacy_model) + return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model) def get_batch(b_queues, position, flags, lock): """ @@ -133,11 +153,11 @@ def act(i, device, batch_queues, model, flags, onnx_frame): else: action = obs['legal_actions'][0] if flags.old_model: - obs_action_buf[position].append(_cards2tensor(action)) + obs_action_buf[position].append(_cards2tensor(action, flags.lite_model)) obs_x_no_action[position].append(env_output['obs_x_no_action']) obs_z_buf[position].append(env_output['obs_z']) else: - obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float()) + obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float()) # x_batch = 
torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float() x_batch = env_output['obs_x_no_action'].float() obs_x_batch_buf[position].append(x_batch) @@ -196,32 +216,57 @@ def act(i, device, batch_queues, model, flags, onnx_frame): print() raise e -def _cards2tensor(list_cards): +def _cards2tensor(list_cards, compress_form = False): """ Convert a list of integers to the tensor representation See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf """ - if len(list_cards) == 0: - return torch.zeros(108, dtype=torch.int8) + if compress_form: + if len(list_cards) == 0: + return torch.zeros(69, dtype=torch.int8) - matrix = np.zeros([8, 14], dtype=np.int8) - counter = Counter(list_cards) - joker_cnt = 0 - for card, num_times in counter.items(): - if card < 20: - matrix[:, Card2Column[card]] = NumOnes2Array[num_times] - elif card == 20: - if num_times == 2: - joker_cnt |= 0b11 - else: - joker_cnt |= 0b01 - elif card == 30: - if num_times == 2: - joker_cnt |= 0b1100 - else: - joker_cnt |= 0b0100 - matrix[:, 13] = NumOnesJoker2Array[joker_cnt] - matrix = matrix.flatten('F')[:-4] - matrix = torch.from_numpy(matrix) - return matrix + matrix = np.zeros([5, 14], dtype=np.int8) + counter = Counter(list_cards) + joker_cnt = 0 + for card, num_times in counter.items(): + if card < 20: + matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times] + elif card == 20: + if num_times == 2: + joker_cnt |= 0b11 + else: + joker_cnt |= 0b01 + elif card == 30: + if num_times == 2: + joker_cnt |= 0b1100 + else: + joker_cnt |= 0b0100 + matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt] + matrix = matrix.flatten('F')[:-1] + matrix = torch.from_numpy(matrix) + return matrix + else: + if len(list_cards) == 0: + return torch.zeros(108, dtype=torch.int8) + + matrix = np.zeros([8, 14], dtype=np.int8) + counter = Counter(list_cards) + joker_cnt = 0 + for card, num_times in counter.items(): + if card < 20: + matrix[:, Card2Column[card]] = NumOnes2Array[num_times] + elif card == 20: + if num_times == 2: + joker_cnt |= 0b11 + else: + joker_cnt |= 0b01 + elif card == 30: + if num_times == 2: + joker_cnt |= 0b1100 + else: + joker_cnt |= 0b0100 + matrix[:, 13] = NumOnesJoker2Array[joker_cnt] + matrix = matrix.flatten('F')[:-4] + matrix = torch.from_numpy(matrix) + return matrix diff --git a/douzero/env/env.py b/douzero/env/env.py index 30255dc..ae5ab30 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -30,6 +30,26 @@ NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]), 13: np.array([1, 0, 1, 1, 0, 0, 0, 0]), 15: np.array([1, 1, 1, 1, 0, 0, 0, 0])} +NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]), + 1: np.array([1, 0, 0, 0, 0]), + 2: np.array([1, 1, 0, 0, 0]), + 3: np.array([1, 1, 1, 0, 0]), + 4: np.array([1, 1, 1, 1, 0]), + 5: np.array([1, 0, 0, 0, 1]), + 6: np.array([1, 1, 0, 0, 1]), + 7: np.array([1, 1, 1, 0, 1]), + 8: np.array([1, 1, 1, 1, 1])} + +NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]), + 1: np.array([1, 0, 0, 0, 0]), + 3: np.array([1, 1, 0, 0, 0]), + 4: np.array([0, 0, 1, 0, 0]), + 5: np.array([1, 0, 1, 0, 0]), + 7: np.array([1, 1, 1, 0, 0]), + 12: np.array([0, 0, 1, 1, 0]), + 13: np.array([1, 0, 1, 1, 0]), + 15: np.array([1, 1, 1, 1, 0])} + deck = [] for i in range(3, 15): @@ -43,7 +63,7 @@ class Env: Doudizhu multi-agent wrapper """ - def __init__(self, objective, old_model, legacy_model=False): + def __init__(self, objective, old_model, legacy_model=False, lite_model = False): """ Objective is wp/adp/logadp. 
It indicates whether bombs are
        considered in the reward calculation. Here, we use dummy agents.
@@ -58,6 +78,7 @@ class Env:
         self.objective = objective
         self.use_legacy = legacy_model
         self.use_general = not old_model
+        self.lite_model = lite_model

         # Initialize players
         # We use three dummy player for the target position
@@ -92,7 +113,7 @@
                 card_play_data[key].sort()
             self._env.card_play_init(card_play_data)
             self.infoset = self._game_infoset
-            return get_obs(self.infoset, self.use_general)
+            return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
         else:
             self.total_round += 1
             _deck = deck.copy()
@@ -118,7 +139,7 @@
                 self._env.info_sets[pos].player_id = pid

         self.infoset = self._game_infoset
-        return get_obs(self.infoset, self.use_general, self.use_legacy)
+        return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)

     def step(self, action):
         """
@@ -146,7 +167,7 @@
             }
             obs = None
         else:
-            obs = get_obs(self.infoset, self.use_general)
+            obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
         return obs, reward, done, {}

     def _get_reward(self, pos):
@@ -244,7 +265,7 @@ class DummyAgent(object):
         self.action = action


-def get_obs(infoset, use_general=True, use_legacy = False):
+def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
     """
     This function obtains observations with imperfect information
     from the infoset. It has three branches since we encode
@@ -271,59 +292,92 @@
     if use_general:
         if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
             raise ValueError('')
-        return _get_obs_general(infoset, infoset.player_position)
+        return _get_obs_general(infoset, infoset.player_position, lite_model)
     else:
         if infoset.player_position == 'landlord':
-            return _get_obs_landlord(infoset, use_legacy)
+            return _get_obs_landlord(infoset, use_legacy, lite_model)
         elif infoset.player_position == 'landlord_up':
-            return _get_obs_landlord_up(infoset, use_legacy)
+            return _get_obs_landlord_up(infoset, use_legacy, lite_model)
         elif infoset.player_position == 'landlord_front':
-            return _get_obs_landlord_front(infoset, use_legacy)
+            return _get_obs_landlord_front(infoset, use_legacy, lite_model)
         elif infoset.player_position == 'landlord_down':
-            return _get_obs_landlord_down(infoset, use_legacy)
+            return _get_obs_landlord_down(infoset, use_legacy, lite_model)
         else:
             raise ValueError('')

-def _get_one_hot_array(num_left_cards, max_num_cards):
+def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
     """
     A utility function to obtain one-hot encoding
     """
-    one_hot = np.zeros(max_num_cards)
-    if num_left_cards > 0:
-        one_hot[num_left_cards - 1] = 1
+    if compress_size > 0:
+        assert compress_size <= max_num_cards / 2
+        array_size = max_num_cards - compress_size
+        one_hot = np.zeros(array_size)
+        if num_left_cards >= array_size:
+            one_hot[-1] = 1
+            num_left_cards -= array_size
+        if num_left_cards > 0:
+            one_hot[num_left_cards - 1] = 1
+    else:
+        one_hot = np.zeros(max_num_cards)
+        if num_left_cards > 0:
+            one_hot[num_left_cards - 1] = 1

     return one_hot

-def _cards2array(list_cards):
+def _cards2array(list_cards, compressed_form = False):
     """
     A utility function that transforms the actions, i.e.,
     a list of integers, into a card matrix. Here we remove
     the six entries that are always zero and flatten
     the representations.
""" - if len(list_cards) == 0: - return np.zeros(108, dtype=np.int8) + if compressed_form: + if len(list_cards) == 0: + return np.zeros(69, dtype=np.int8) - matrix = np.zeros([8, 14], dtype=np.int8) - counter = Counter(list_cards) - joker_cnt = 0 - for card, num_times in counter.items(): - if card < 20: - matrix[:, Card2Column[card]] = NumOnes2Array[num_times] - elif card == 20: - if num_times == 2: - joker_cnt |= 0b11 - else: - joker_cnt |= 0b01 - elif card == 30: - if num_times == 2: - joker_cnt |= 0b1100 - else: - joker_cnt |= 0b0100 - matrix[:, 13] = NumOnesJoker2Array[joker_cnt] - return matrix.flatten('F')[:-4] + matrix = np.zeros([5, 14], dtype=np.int8) + counter = Counter(list_cards) + joker_cnt = 0 + for card, num_times in counter.items(): + if card < 20: + matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times] + elif card == 20: + if num_times == 2: + joker_cnt |= 0b11 + else: + joker_cnt |= 0b01 + elif card == 30: + if num_times == 2: + joker_cnt |= 0b1100 + else: + joker_cnt |= 0b0100 + matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt] + return matrix.flatten('F')[:-1] + else: + if len(list_cards) == 0: + return np.zeros(108, dtype=np.int8) + + matrix = np.zeros([8, 14], dtype=np.int8) + counter = Counter(list_cards) + joker_cnt = 0 + for card, num_times in counter.items(): + if card < 20: + matrix[:, Card2Column[card]] = NumOnes2Array[num_times] + elif card == 20: + if num_times == 2: + joker_cnt |= 0b11 + else: + joker_cnt |= 0b01 + elif card == 30: + if num_times == 2: + joker_cnt |= 0b1100 + else: + joker_cnt |= 0b0100 + matrix[:, 13] = NumOnesJoker2Array[joker_cnt] + return matrix.flatten('F')[:-4] # def _action_seq_list2array(action_seq_list): @@ -342,7 +396,7 @@ def _cards2array(list_cards): # # action_seq_array = action_seq_array.reshape(5, 162) # return action_seq_array -def _action_seq_list2array(action_seq_list, new_model=True): +def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False): """ A utility function to encode the historical moves. We encode the historical 20 actions. 
If there is
@@ -355,16 +409,29 @@

     if new_model:
         # position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
-        action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
-        for row, list_cards in enumerate(action_seq_list):
-            if list_cards != []:
-                action_seq_array[row, :108] = _cards2array(list_cards[1])
+        if compressed_form:
+            action_seq_array = np.full((len(action_seq_list), 69), -1) # Default Value -1 for not using area
+            for row, list_cards in enumerate(action_seq_list):
+                if list_cards != []:
+                    action_seq_array[row, :69] = _cards2array(list_cards[1], compressed_form)
+        else:
+            action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
+            for row, list_cards in enumerate(action_seq_list):
+                if list_cards != []:
+                    action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
     else:
-        action_seq_array = np.zeros((len(action_seq_list), 108))
-        for row, list_cards in enumerate(action_seq_list):
-            if list_cards != []:
-                action_seq_array[row, :] = _cards2array(list_cards[1])
-        action_seq_array = action_seq_array.reshape(5, 432)
+        if compressed_form:
+            action_seq_array = np.zeros((len(action_seq_list), 69))
+            for row, list_cards in enumerate(action_seq_list):
+                if list_cards != []:
+                    action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
+            action_seq_array = action_seq_array.reshape(5, 276)
+        else:
+            action_seq_array = np.zeros((len(action_seq_list), 108))
+            for row, list_cards in enumerate(action_seq_list):
+                if list_cards != []:
+                    action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
+            action_seq_array = action_seq_array.reshape(5, 432)
     return action_seq_array

 #     action_seq_array = np.zeros((len(action_seq_list), 54))
@@ -406,60 +473,60 @@ def _get_one_hot_bomb(bomb_num, use_legacy = False):
     return one_hot


-def _get_obs_landlord(infoset, use_legacy = False):
+def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
     """
     Obtain the landlord features.
     See Table 4 in https://arxiv.org/pdf/2106.06135.pdf
     """
     num_legal_actions = len(infoset.legal_actions)
-    my_handcards = _cards2array(infoset.player_hand_cards)
+    my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
     my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
                                    num_legal_actions, axis=0)

-    other_handcards = _cards2array(infoset.other_hand_cards)
+    other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
     other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
                                       num_legal_actions, axis=0)

-    last_action = _cards2array(infoset.last_move)
+    last_action = _cards2array(infoset.last_move, compressed_form)
     last_action_batch = np.repeat(last_action[np.newaxis, :],
                                   num_legal_actions, axis=0)

     my_action_batch = np.zeros(my_handcards_batch.shape)
     for j, action in enumerate(infoset.legal_actions):
-        my_action_batch[j, :] = _cards2array(action)
+        my_action_batch[j, :] = _cards2array(action, compressed_form)

     landlord_up_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_up'], 25)
+        infoset.num_cards_left_dict['landlord_up'], 25, 8)
     landlord_up_num_cards_left_batch = np.repeat(
         landlord_up_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_front_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_front'], 25)
+        infoset.num_cards_left_dict['landlord_front'], 25, 8)
     landlord_front_num_cards_left_batch = np.repeat(
         landlord_front_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_down_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_down'], 25)
+        infoset.num_cards_left_dict['landlord_down'], 25, 8)
     landlord_down_num_cards_left_batch = np.repeat(
         landlord_down_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_up_played_cards = _cards2array(
-        infoset.played_cards['landlord_up'])
+        infoset.played_cards['landlord_up'], compressed_form)
     landlord_up_played_cards_batch = np.repeat(
         landlord_up_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_front_played_cards = _cards2array(
-        infoset.played_cards['landlord_front'])
+        infoset.played_cards['landlord_front'], compressed_form)
     landlord_front_played_cards_batch = np.repeat(
         landlord_front_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_down_played_cards = _cards2array(
-        infoset.played_cards['landlord_down'])
+        infoset.played_cards['landlord_down'], compressed_form)
     landlord_down_played_cards_batch = np.repeat(
         landlord_down_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)
@@ -492,7 +559,7 @@
                              landlord_down_num_cards_left,
                              bomb_num))
     z = _action_seq_list2array(_process_action_seq(
-        infoset.card_play_action_seq, 20, False), False)
+        infoset.card_play_action_seq, 20, False), False, compressed_form)
     z_batch = np.repeat(
         z[np.newaxis, :, :],
         num_legal_actions, axis=0)
@@ -506,75 +573,75 @@
     }
     return obs

-def _get_obs_landlord_up(infoset, use_legacy = False):
+def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False):
     """
     Obtain the landlord_up features. 
See Table 5 in https://arxiv.org/pdf/2106.06135.pdf """ num_legal_actions = len(infoset.legal_actions) - my_handcards = _cards2array(infoset.player_hand_cards) + my_handcards = _cards2array(infoset.player_hand_cards, compressed_form) my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], num_legal_actions, axis=0) - other_handcards = _cards2array(infoset.other_hand_cards) + other_handcards = _cards2array(infoset.other_hand_cards, compressed_form) other_handcards_batch = np.repeat(other_handcards[np.newaxis, :], num_legal_actions, axis=0) - last_action = _cards2array(infoset.last_move) + last_action = _cards2array(infoset.last_move, compressed_form) last_action_batch = np.repeat(last_action[np.newaxis, :], num_legal_actions, axis=0) my_action_batch = np.zeros(my_handcards_batch.shape) for j, action in enumerate(infoset.legal_actions): - my_action_batch[j, :] = _cards2array(action) + my_action_batch[j, :] = _cards2array(action, compressed_form) last_landlord_action = _cards2array( - infoset.last_move_dict['landlord']) + infoset.last_move_dict['landlord'], compressed_form) last_landlord_action_batch = np.repeat( last_landlord_action[np.newaxis, :], num_legal_actions, axis=0) landlord_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord'], 33) + infoset.num_cards_left_dict['landlord'], 33, 15) landlord_num_cards_left_batch = np.repeat( landlord_num_cards_left[np.newaxis, :], num_legal_actions, axis=0) landlord_played_cards = _cards2array( - infoset.played_cards['landlord']) + infoset.played_cards['landlord'], compressed_form) landlord_played_cards_batch = np.repeat( landlord_played_cards[np.newaxis, :], num_legal_actions, axis=0) last_teammate_action = _cards2array( - infoset.last_move_dict['landlord_down']) + infoset.last_move_dict['landlord_down'], compressed_form) last_teammate_action_batch = np.repeat( last_teammate_action[np.newaxis, :], num_legal_actions, axis=0) teammate_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord_down'], 25) + infoset.num_cards_left_dict['landlord_down'], 25, 8) teammate_num_cards_left_batch = np.repeat( teammate_num_cards_left[np.newaxis, :], num_legal_actions, axis=0) teammate_played_cards = _cards2array( - infoset.played_cards['landlord_down']) + infoset.played_cards['landlord_down'], compressed_form) teammate_played_cards_batch = np.repeat( teammate_played_cards[np.newaxis, :], num_legal_actions, axis=0) last_teammate_front_action = _cards2array( - infoset.last_move_dict['landlord_front']) + infoset.last_move_dict['landlord_front'], compressed_form) last_teammate_front_action_batch = np.repeat( last_teammate_front_action[np.newaxis, :], num_legal_actions, axis=0) teammate_front_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord_front'], 25) + infoset.num_cards_left_dict['landlord_front'], 25, 8) teammate_front_num_cards_left_batch = np.repeat( teammate_front_num_cards_left[np.newaxis, :], num_legal_actions, axis=0) teammate_front_played_cards = _cards2array( - infoset.played_cards['landlord_front']) + infoset.played_cards['landlord_front'], compressed_form) teammate_front_played_cards_batch = np.repeat( teammate_front_played_cards[np.newaxis, :], num_legal_actions, axis=0) @@ -613,7 +680,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False): teammate_front_num_cards_left, bomb_num)) z = _action_seq_list2array(_process_action_seq( - infoset.card_play_action_seq, 20, False), False) + infoset.card_play_action_seq, 20, False), False, compressed_form) z_batch = np.repeat( 
z[np.newaxis, :, :],
         num_legal_actions, axis=0)
@@ -627,75 +694,75 @@
     }
     return obs

-def _get_obs_landlord_front(infoset, use_legacy = False):
+def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False):
     """
     Obtain the landlord_front features.
     See Table 5 in https://arxiv.org/pdf/2106.06135.pdf
     """
     num_legal_actions = len(infoset.legal_actions)
-    my_handcards = _cards2array(infoset.player_hand_cards)
+    my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
     my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
                                    num_legal_actions, axis=0)

-    other_handcards = _cards2array(infoset.other_hand_cards)
+    other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
     other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
                                       num_legal_actions, axis=0)

-    last_action = _cards2array(infoset.last_move)
+    last_action = _cards2array(infoset.last_move, compressed_form)
     last_action_batch = np.repeat(last_action[np.newaxis, :],
                                   num_legal_actions, axis=0)

     my_action_batch = np.zeros(my_handcards_batch.shape)
     for j, action in enumerate(infoset.legal_actions):
-        my_action_batch[j, :] = _cards2array(action)
+        my_action_batch[j, :] = _cards2array(action, compressed_form)

     last_landlord_action = _cards2array(
-        infoset.last_move_dict['landlord'])
+        infoset.last_move_dict['landlord'], compressed_form)
     last_landlord_action_batch = np.repeat(
         last_landlord_action[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord'], 33)
+        infoset.num_cards_left_dict['landlord'], 33, 15)
     landlord_num_cards_left_batch = np.repeat(
         landlord_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_played_cards = _cards2array(
-        infoset.played_cards['landlord'])
+        infoset.played_cards['landlord'], compressed_form)
     landlord_played_cards_batch = np.repeat(
         landlord_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)

     last_teammate_action = _cards2array(
-        infoset.last_move_dict['landlord_down'])
+        infoset.last_move_dict['landlord_down'], compressed_form)
     last_teammate_action_batch = np.repeat(
         last_teammate_action[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_down'], 25)
+        infoset.num_cards_left_dict['landlord_down'], 25, 8)
     teammate_num_cards_left_batch = np.repeat(
         teammate_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_played_cards = _cards2array(
-        infoset.played_cards['landlord_down'])
+        infoset.played_cards['landlord_down'], compressed_form)
     teammate_played_cards_batch = np.repeat(
         teammate_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)

     last_teammate_front_action = _cards2array(
-        infoset.last_move_dict['landlord_front'])
+        infoset.last_move_dict['landlord_front'], compressed_form)
     last_teammate_front_action_batch = np.repeat(
         last_teammate_front_action[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_front_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_front'], 25)
+        infoset.num_cards_left_dict['landlord_front'], 25, 8)
     teammate_front_num_cards_left_batch = np.repeat(
         teammate_front_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_front_played_cards = _cards2array(
-        infoset.played_cards['landlord_front'])
+        infoset.played_cards['landlord_front'], compressed_form)
     teammate_front_played_cards_batch = np.repeat(
         teammate_front_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)
@@ -734,7 +801,7 @@
                              teammate_front_num_cards_left,
                              bomb_num))
     z = _action_seq_list2array(_process_action_seq(
-        infoset.card_play_action_seq, 20, False), False)
+        infoset.card_play_action_seq, 20, False), False, compressed_form)
     z_batch = np.repeat(
         z[np.newaxis, :, :],
         num_legal_actions, axis=0)
@@ -748,75 +815,75 @@
     }
     return obs

-def _get_obs_landlord_down(infoset, use_legacy = False):
+def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False):
     """
     Obtain the landlord_down features.
     See Table 5 in https://arxiv.org/pdf/2106.06135.pdf
     """
     num_legal_actions = len(infoset.legal_actions)
-    my_handcards = _cards2array(infoset.player_hand_cards)
+    my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
     my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
                                    num_legal_actions, axis=0)

-    other_handcards = _cards2array(infoset.other_hand_cards)
+    other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
     other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
                                       num_legal_actions, axis=0)

-    last_action = _cards2array(infoset.last_move)
+    last_action = _cards2array(infoset.last_move, compressed_form)
     last_action_batch = np.repeat(last_action[np.newaxis, :],
                                   num_legal_actions, axis=0)

     my_action_batch = np.zeros(my_handcards_batch.shape)
     for j, action in enumerate(infoset.legal_actions):
-        my_action_batch[j, :] = _cards2array(action)
+        my_action_batch[j, :] = _cards2array(action, compressed_form)

     last_landlord_action = _cards2array(
-        infoset.last_move_dict['landlord'])
+        infoset.last_move_dict['landlord'], compressed_form)
     last_landlord_action_batch = np.repeat(
         last_landlord_action[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord'], 33)
+        infoset.num_cards_left_dict['landlord'], 33, 15)
     landlord_num_cards_left_batch = np.repeat(
         landlord_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     landlord_played_cards = _cards2array(
-        infoset.played_cards['landlord'])
+        infoset.played_cards['landlord'], compressed_form)
     landlord_played_cards_batch = np.repeat(
         landlord_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)

     last_teammate_action = _cards2array(
-        infoset.last_move_dict['landlord_up'])
+        infoset.last_move_dict['landlord_up'], compressed_form)
     last_teammate_action_batch = np.repeat(
         last_teammate_action[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_up'], 25)
+        infoset.num_cards_left_dict['landlord_up'], 25, 8)
     teammate_num_cards_left_batch = np.repeat(
         teammate_num_cards_left[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_played_cards = _cards2array(
-        infoset.played_cards['landlord_up'])
+        infoset.played_cards['landlord_up'], compressed_form)
     teammate_played_cards_batch = np.repeat(
         teammate_played_cards[np.newaxis, :],
         num_legal_actions, axis=0)

     last_teammate_front_action = _cards2array(
-        infoset.last_move_dict['landlord_front'])
+        infoset.last_move_dict['landlord_front'], compressed_form)
     last_teammate_front_action_batch = np.repeat(
         last_teammate_front_action[np.newaxis, :],
         num_legal_actions, axis=0)

     teammate_front_num_cards_left = _get_one_hot_array(
-        infoset.num_cards_left_dict['landlord_front'], 25)
+        infoset.num_cards_left_dict['landlord_front'], 25, 8)
     teammate_front_num_cards_left_batch = np.repeat(
teammate_front_num_cards_left[np.newaxis, :], num_legal_actions, axis=0) teammate_front_played_cards = _cards2array( - infoset.played_cards['landlord_front']) + infoset.played_cards['landlord_front'], compressed_form) teammate_front_played_cards_batch = np.repeat( teammate_front_played_cards[np.newaxis, :], num_legal_actions, axis=0) @@ -855,7 +922,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False): teammate_front_num_cards_left, bomb_num)) z = _action_seq_list2array(_process_action_seq( - infoset.card_play_action_seq, 20, False), False) + infoset.card_play_action_seq, 20, False), False, compressed_form) z_batch = np.repeat( z[np.newaxis, :, :], num_legal_actions, axis=0) @@ -869,41 +936,41 @@ def _get_obs_landlord_down(infoset, use_legacy = False): } return obs -def _get_obs_general(infoset, position): +def _get_obs_general(infoset, position, compressed_form = False): num_legal_actions = len(infoset.legal_actions) - my_handcards = _cards2array(infoset.player_hand_cards) + my_handcards = _cards2array(infoset.player_hand_cards, compressed_form) my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], num_legal_actions, axis=0) - other_handcards = _cards2array(infoset.other_hand_cards) + other_handcards = _cards2array(infoset.other_hand_cards, compressed_form) my_action_batch = np.zeros(my_handcards_batch.shape) for j, action in enumerate(infoset.legal_actions): - my_action_batch[j, :] = _cards2array(action) + my_action_batch[j, :] = _cards2array(action, compressed_form) landlord_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord'], 33) + infoset.num_cards_left_dict['landlord'], 33, 15) landlord_up_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord_up'], 25) + infoset.num_cards_left_dict['landlord_up'], 25, 8) landlord_front_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord_front'], 25) + infoset.num_cards_left_dict['landlord_front'], 25, 8) landlord_down_num_cards_left = _get_one_hot_array( - infoset.num_cards_left_dict['landlord_down'], 25) + infoset.num_cards_left_dict['landlord_down'], 25, 8) landlord_played_cards = _cards2array( - infoset.played_cards['landlord']) + infoset.played_cards['landlord'], compressed_form) landlord_up_played_cards = _cards2array( - infoset.played_cards['landlord_up']) + infoset.played_cards['landlord_up'], compressed_form) landlord_front_played_cards = _cards2array( - infoset.played_cards['landlord_front']) + infoset.played_cards['landlord_front'], compressed_form) landlord_down_played_cards = _cards2array( - infoset.played_cards['landlord_down']) + infoset.played_cards['landlord_down'], compressed_form) bomb_num = _get_one_hot_bomb( infoset.bomb_num) @@ -911,10 +978,10 @@ def _get_obs_general(infoset, position): bomb_num[np.newaxis, :], num_legal_actions, axis=0) num_cards_left = np.hstack(( - landlord_num_cards_left, # 33 - landlord_up_num_cards_left, # 25 - landlord_front_num_cards_left, # 25 - landlord_down_num_cards_left)) + landlord_num_cards_left, # 33/18 + landlord_up_num_cards_left, # 25/17 + landlord_front_num_cards_left, # 25/17 + landlord_down_num_cards_left)) # 25/17 x_batch = np.hstack(( bomb_num_batch, # 56 @@ -924,21 +991,24 @@ def _get_obs_general(infoset, position): )) z =np.vstack(( - num_cards_left, - my_handcards, # 108 - other_handcards, # 108 - landlord_played_cards, # 108 - landlord_up_played_cards, # 108 - landlord_front_played_cards, # 108 - landlord_down_played_cards, # 108 - 
_action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32)) + num_cards_left, # 108 / 18+17*3=69 + my_handcards, # 108/69 + other_handcards, # 108/69 + landlord_played_cards, # 108/69 + landlord_up_played_cards, # 108/69 + landlord_front_played_cards, # 108/69 + landlord_down_played_cards, # 108/69 + _action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 32), True, compressed_form) )) _z_batch = np.repeat( z[np.newaxis, :, :], num_legal_actions, axis=0) my_action_batch = my_action_batch[:,np.newaxis,:] - z_batch = np.zeros([len(_z_batch),40,108],int) + if compressed_form: + z_batch = np.zeros([len(_z_batch), 40, 69], int) + else: + z_batch = np.zeros([len(_z_batch),40,108],int) for i in range(0,len(_z_batch)): z_batch[i] = np.vstack((my_action_batch[i],_z_batch[i])) obs = { diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index 5fdf6c1..73162cc 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -49,6 +49,7 @@ class DeepAgent: def __init__(self, position, model_path): self.use_legacy = True if "legacy" in model_path else False + self.lite_model = True if "lite" in model_path else False self.model_type = "general" if "resnet" in model_path else "old" self.model = _load_model(position, model_path, self.model_type, self.use_legacy) self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider']) @@ -59,7 +60,7 @@ class DeepAgent: if len(infoset.legal_actions) == 1: return infoset.legal_actions[0] - obs = get_obs(infoset, self.model_type == "general", self.use_legacy) + obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model) z_batch = torch.from_numpy(obs['z_batch']).float() x_batch = torch.from_numpy(obs['x_batch']).float()
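
Note on the compressed encodings (illustrative sketch, not part of the patch):
the snippet below mirrors NumOnes2ArrayCompressed and _get_one_hot_array from
douzero/env/env.py. With two decks, each rank occurs 0 to 8 times; a count maps
to a 5-bit code whose last bit means "plus 4" on top of the unary prefix in the
first four bits. A 5x14 matrix (13 ranks plus a joker column) flattened
column-major, with the never-used final joker slot dropped, gives 5*14-1 = 69
features instead of the full form's 108. The helper names below are
hypothetical; only the tables and the two-hot overflow rule come from the patch.

    import numpy as np

    # Count -> 5-bit code; bit 4 (last) adds 4 to the unary prefix in bits 0-3.
    NUM_ONES_COMPRESSED = {
        0: [0, 0, 0, 0, 0], 1: [1, 0, 0, 0, 0], 2: [1, 1, 0, 0, 0],
        3: [1, 1, 1, 0, 0], 4: [1, 1, 1, 1, 0], 5: [1, 0, 0, 0, 1],
        6: [1, 1, 0, 0, 1], 7: [1, 1, 1, 0, 1], 8: [1, 1, 1, 1, 1],
    }

    def cards2array_compressed(counts_per_column):
        """counts_per_column: 14 ints (13 ranks + joker column), each 0..8."""
        matrix = np.zeros([5, 14], dtype=np.int8)
        for col, n in enumerate(counts_per_column):
            matrix[:, col] = NUM_ONES_COMPRESSED[n]
        # Column-major flatten; the joker column never sets its fifth slot,
        # so the final entry is dropped: 5 * 14 - 1 = 69 features.
        return matrix.flatten('F')[:-1]

    def one_hot_compressed(n, max_n, compress):
        """Mirrors _get_one_hot_array: counts >= the shrunken size set the
        overflow slot (last entry) plus a remainder slot, i.e. two-hot."""
        size = max_n - compress
        v = np.zeros(size, dtype=np.int8)
        if n >= size:
            v[-1] = 1
            n -= size
        if n > 0:
            v[n - 1] = 1
        return v

    assert cards2array_compressed([8] + [0] * 13).shape == (69,)
    # Counts 0..25 compress from 25 slots to 17 without collisions:
    assert len({tuple(one_hot_compressed(n, 25, 8)) for n in range(26)}) == 26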