diff --git a/douzero/dmc/arguments.py b/douzero/dmc/arguments.py index 9f2b65e..02f23b0 100644 --- a/douzero/dmc/arguments.py +++ b/douzero/dmc/arguments.py @@ -27,6 +27,8 @@ parser.add_argument('--load_model', action='store_true', help='Load an existing model') parser.add_argument('--old_model', action='store_true', help='Use vanilla model') +parser.add_argument('--lagacy_model', action='store_true', + help='Use lagacy bomb model') parser.add_argument('--disable_checkpoint', action='store_true', help='Disable saving checkpoint') parser.add_argument('--savedir', default='douzero_checkpoints', diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index b9f84d2..18554bd 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -250,9 +250,10 @@ class BasicBlock(nn.Module): class GeneralModel(nn.Module): - def __init__(self): + def __init__(self, use_legacy = False): super().__init__() self.in_planes = 80 + self.use_legacy = use_legacy #input 1*108*41 self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,), stride=(2,), padding=1, bias=False) #1*108*80 @@ -264,7 +265,10 @@ class GeneralModel(nn.Module): self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320 self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640 # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048) + if self.use_legacy: + self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048) + else: + self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048) self.linear2 = nn.Linear(2048, 1024) self.linear3 = nn.Linear(1024, 512) self.linear4 = nn.Linear(512, 256) @@ -285,7 +289,10 @@ class GeneralModel(nn.Module): out = self.layer3(out) out = self.layer4(out) out = out.flatten(1,2) - out = torch.cat([x,x,x,x,out], dim=-1) + if self.use_legacy: + out = torch.cat([x,x,x,x,out], dim=-1) + else: + out = torch.cat([x,out], dim=-1) out = F.leaky_relu_(self.linear1(out)) out = F.leaky_relu_(self.linear2(out)) out = F.leaky_relu_(self.linear3(out)) @@ -451,15 +458,15 @@ class Model: The wrapper for the three models. We also wrap several interfaces such as share_memory, eval, etc. """ - def __init__(self, device=0): + def __init__(self, device=0, use_legacy = False): self.models = {} if not device == "cpu": device = 'cuda:' + str(device) # model = GeneralModel().to(torch.device(device)) - self.models['landlord'] = GeneralModel().to(torch.device(device)) - self.models['landlord_up'] = GeneralModel().to(torch.device(device)) - self.models['landlord_front'] = GeneralModel().to(torch.device(device)) - self.models['landlord_down'] = GeneralModel().to(torch.device(device)) + self.models['landlord'] = GeneralModel(use_legacy).to(torch.device(device)) + self.models['landlord_up'] = GeneralModel(use_legacy).to(torch.device(device)) + self.models['landlord_front'] = GeneralModel(use_legacy).to(torch.device(device)) + self.models['landlord_down'] = GeneralModel(use_legacy).to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device)) def forward(self, position, z, x, training=False, flags=None, debug=False): diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index ea003f6..1b350b2 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -41,7 +41,7 @@ log.setLevel(logging.INFO) Buffers = typing.Dict[str, typing.List[torch.Tensor]] def create_env(flags): - return Env(flags.objective, flags.old_model) + return Env(flags.objective, flags.old_model, flags.lagacy_model) def get_batch(b_queues, position, flags, lock): """ diff --git a/douzero/env/env.py b/douzero/env/env.py index 666a49e..da07f65 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -34,7 +34,7 @@ class Env: Doudizhu multi-agent wrapper """ - def __init__(self, objective, old_model): + def __init__(self, objective, old_model, legacy_model=False): """ Objective is wp/adp/logadp. It indicates whether considers bomb in reward calculation. Here, we use dummy agents. @@ -48,6 +48,7 @@ class Env: """ self.objective = objective self.use_general = not old_model + self.use_legacy = legacy_model # Initialize players # We use three dummy player for the target position @@ -83,7 +84,7 @@ class Env: card_play_data[key].sort() self._env.card_play_init(card_play_data) self.infoset = self._game_infoset - return get_obs(self.infoset, self.use_general) + return get_obs(self.infoset, self.use_general, self.use_legacy) else: self.total_round += 1 bid_done = False @@ -233,7 +234,7 @@ class Env: print("发牌情况: %i/%i %.1f%%" % (self.force_bid, self.total_round, self.force_bid / self.total_round * 100)) self.force_bid = 0 self.total_round = 0 - return get_obs(self.infoset, self.use_general), { + return get_obs(self.infoset, self.use_general, self.use_legacy), { "bid_obs_buffer": bid_obs_buffer, "multiply_obs_buffer": multiply_obs_buffer } @@ -270,7 +271,7 @@ class Env: } obs = None else: - obs = get_obs(self.infoset, self.use_general) + obs = get_obs(self.infoset, self.use_general, self.use_legacy) return obs, reward, done, {} def _get_reward(self, pos): @@ -381,7 +382,7 @@ class DummyAgent(object): self.action = action -def get_obs(infoset, use_general=True): +def get_obs(infoset, use_general=True, use_legacy = False): """ This function obtains observations with imperfect information from the infoset. It has three branches since we encode @@ -408,16 +409,16 @@ def get_obs(infoset, use_general=True): if use_general: if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]: raise ValueError('') - return _get_obs_general(infoset, infoset.player_position) + return _get_obs_general(infoset, infoset.player_position, use_legacy) else: if infoset.player_position == 'landlord': - return _get_obs_landlord(infoset) + return _get_obs_landlord(infoset, use_legacy) elif infoset.player_position == 'landlord_up': - return _get_obs_landlord_up(infoset) + return _get_obs_landlord_up(infoset, use_legacy) elif infoset.player_position == 'landlord_front': - return _get_obs_landlord_front(infoset) + return _get_obs_landlord_front(infoset, use_legacy) elif infoset.player_position == 'landlord_down': - return _get_obs_landlord_down(infoset) + return _get_obs_landlord_down(infoset, use_legacy) else: raise ValueError('') @@ -524,17 +525,23 @@ def _process_action_seq(sequence, length=20, new_model=True): return sequence -def _get_one_hot_bomb(bomb_num): +def _get_one_hot_bomb(bomb_num, use_legacy = False): """ A utility function to encode the number of bombs into one-hot representation. """ - one_hot = np.zeros(29) - one_hot[bomb_num[0] + bomb_num[1]] = 1 + if use_legacy: + one_hot = np.zeros(29) + one_hot[bomb_num[0] + bomb_num[1]] = 1 + else: + one_hot = np.zeros(56) # 14 + 15 + 27 + one_hot[bomb_num[0]] = 1 + one_hot[14 + bomb_num[1]] = 1 + one_hot[29 + bomb_num[2]] = 1 return one_hot -def _get_obs_landlord(infoset): +def _get_obs_landlord(infoset, use_legacy = False): """ Obttain the landlord features. See Table 4 in https://arxiv.org/pdf/2106.06135.pdf @@ -593,7 +600,7 @@ def _get_obs_landlord(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -634,7 +641,7 @@ def _get_obs_landlord(infoset): } return obs -def _get_obs_landlord_up(infoset): +def _get_obs_landlord_up(infoset, use_legacy = False): """ Obttain the landlord_up features. See Table 5 in https://arxiv.org/pdf/2106.06135.pdf @@ -708,7 +715,7 @@ def _get_obs_landlord_up(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -755,7 +762,7 @@ def _get_obs_landlord_up(infoset): } return obs -def _get_obs_landlord_front(infoset): +def _get_obs_landlord_front(infoset, use_legacy = False): """ Obttain the landlord_front features. See Table 5 in https://arxiv.org/pdf/2106.06135.pdf @@ -829,7 +836,7 @@ def _get_obs_landlord_front(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -876,7 +883,7 @@ def _get_obs_landlord_front(infoset): } return obs -def _get_obs_landlord_down(infoset): +def _get_obs_landlord_down(infoset, use_legacy = False): """ Obttain the landlord_down features. See Table 5 in https://arxiv.org/pdf/2106.06135.pdf @@ -950,7 +957,7 @@ def _get_obs_landlord_down(infoset): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -1212,7 +1219,7 @@ def _get_obs_general1(infoset, position): } return obs -def _get_obs_general(infoset, position): +def _get_obs_general(infoset, position, use_legacy = False): num_legal_actions = len(infoset.legal_actions) my_handcards = _cards2array(infoset.player_hand_cards) my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], @@ -1306,7 +1313,7 @@ def _get_obs_general(infoset, position): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, use_legacy) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -1316,12 +1323,22 @@ def _get_obs_general(infoset, position): landlord_front_num_cards_left, # 25 landlord_down_num_cards_left)) - x_batch = np.hstack(( - bid_info_batch, # 16 - multiply_info_batch)) # 4 - x_no_action = np.hstack(( - bid_info, - multiply_info)) + if use_legacy: + x_batch = np.hstack(( + bid_info_batch, # 20 + multiply_info_batch)) # 4 + x_no_action = np.hstack(( + bid_info, + multiply_info)) + else: + x_batch = np.hstack(( + bomb_num_batch, # 56 + bid_info_batch, # 20 + multiply_info_batch)) # 4 + x_no_action = np.hstack(( + bomb_num, # 56 + bid_info, + multiply_info)) z =np.vstack(( num_cards_left, my_handcards, # 108 diff --git a/douzero/env/game.py b/douzero/env/game.py index c898a27..4aea422 100644 --- a/douzero/env/game.py +++ b/douzero/env/game.py @@ -11,19 +11,16 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '10': 10, 'J': 11, 'Q': 12, 'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30} -bombs = [ - [[3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7], - [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11], - [12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17], - [3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7, 7], - [8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11, 11], - [12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14, 14], - [17, 17, 17, 17, 17, 17, 17]], - [[3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6, 6], - [7, 7, 7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10, 10], - [11, 11, 11, 11, 11, 11, 11, 11], [12, 12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13, 13], - [14, 14, 14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17, 17, 17], - [20, 20, 30, 30]]] +cards_idx = [x for x in range(3, 15)] +cards_idx.extend([17, 20, 30]) + +bombs = [[[x] * 6 for x in cards_idx[:-2]], [[x] * 8 for x in cards_idx[:-2]], [[x] * 4 for x in cards_idx[:-2]]] +# Rocket bomb +bombs[0].extend([[x] * 7 for x in cards_idx[:-2]]) +# King bomb +bombs[1].extend([[20, 20, 30, 30]]) +# Normal bomb +bombs[2].extend([[x] * 5 for x in cards_idx[:-2]]) class GameEnv(object): @@ -63,7 +60,7 @@ class GameEnv(object): 'landlord_front': InfoSet('landlord_front'), 'landlord_down': InfoSet('landlord_down')} - self.bomb_num = [0, 0] + self.bomb_num = [0, 0, 0] self.pos_bomb_num = { "landlord": 0, "landlord_up": 0, @@ -162,6 +159,9 @@ class GameEnv(object): self.bomb_num[1] += 1 self.pos_bomb_num[self.acting_player_position] += 1 + if action in bombs[2]: + self.bomb_num[2] += 1 + self.last_move_dict[ self.acting_player_position] = action.copy() @@ -367,7 +367,7 @@ class GameEnv(object): 'landlord_front': InfoSet('landlord_front'), 'landlord_down': InfoSet('landlord_down')} - self.bomb_num = [0, 0] + self.bomb_num = [0, 0, 0] self.pos_bomb_num = { "landlord": 0, "landlord_up": 0,