调整obs,增加炸弹作为特征之一 (Adjust obs: add bomb counts as one of the features)

This commit is contained in:
ZaneYork 2021-12-11 22:10:40 +08:00
parent 82e941a5eb
commit 64bf792d15
5 changed files with 79 additions and 53 deletions

View File

@ -27,6 +27,8 @@ parser.add_argument('--load_model', action='store_true',
help='Load an existing model') help='Load an existing model')
parser.add_argument('--old_model', action='store_true', parser.add_argument('--old_model', action='store_true',
help='Use vanilla model') help='Use vanilla model')
parser.add_argument('--lagacy_model', action='store_true',
help='Use lagacy bomb model')
parser.add_argument('--disable_checkpoint', action='store_true', parser.add_argument('--disable_checkpoint', action='store_true',
help='Disable saving checkpoint') help='Disable saving checkpoint')
parser.add_argument('--savedir', default='douzero_checkpoints', parser.add_argument('--savedir', default='douzero_checkpoints',

View File

@ -250,9 +250,10 @@ class BasicBlock(nn.Module):
class GeneralModel(nn.Module): class GeneralModel(nn.Module):
def __init__(self): def __init__(self, use_legacy = False):
super().__init__() super().__init__()
self.in_planes = 80 self.in_planes = 80
self.use_legacy = use_legacy
#input 1*108*41 #input 1*108*41
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,), self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*108*80 stride=(2,), padding=1, bias=False) #1*108*80
@ -264,7 +265,10 @@ class GeneralModel(nn.Module):
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320 self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640 self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048) if self.use_legacy:
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
else:
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
self.linear2 = nn.Linear(2048, 1024) self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512) self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256) self.linear4 = nn.Linear(512, 256)
@ -285,7 +289,10 @@ class GeneralModel(nn.Module):
out = self.layer3(out) out = self.layer3(out)
out = self.layer4(out) out = self.layer4(out)
out = out.flatten(1,2) out = out.flatten(1,2)
out = torch.cat([x,x,x,x,out], dim=-1) if self.use_legacy:
out = torch.cat([x,x,x,x,out], dim=-1)
else:
out = torch.cat([x,out], dim=-1)
out = F.leaky_relu_(self.linear1(out)) out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out)) out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out)) out = F.leaky_relu_(self.linear3(out))
@ -451,15 +458,15 @@ class Model:
The wrapper for the three models. We also wrap several The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc. interfaces such as share_memory, eval, etc.
""" """
def __init__(self, device=0): def __init__(self, device=0, use_legacy = False):
self.models = {} self.models = {}
if not device == "cpu": if not device == "cpu":
device = 'cuda:' + str(device) device = 'cuda:' + str(device)
# model = GeneralModel().to(torch.device(device)) # model = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel().to(torch.device(device)) self.models['landlord'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['landlord_up'] = GeneralModel().to(torch.device(device)) self.models['landlord_up'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['landlord_front'] = GeneralModel().to(torch.device(device)) self.models['landlord_front'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['landlord_down'] = GeneralModel().to(torch.device(device)) self.models['landlord_down'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False): def forward(self, position, z, x, training=False, flags=None, debug=False):

View File

@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
Buffers = typing.Dict[str, typing.List[torch.Tensor]] Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags): def create_env(flags):
return Env(flags.objective, flags.old_model) return Env(flags.objective, flags.old_model, flags.lagacy_model)
def get_batch(b_queues, position, flags, lock): def get_batch(b_queues, position, flags, lock):
""" """

75
douzero/env/env.py vendored
View File

@ -34,7 +34,7 @@ class Env:
Doudizhu multi-agent wrapper Doudizhu multi-agent wrapper
""" """
def __init__(self, objective, old_model): def __init__(self, objective, old_model, legacy_model=False):
""" """
Objective is wp/adp/logadp. It indicates whether considers Objective is wp/adp/logadp. It indicates whether considers
bomb in reward calculation. Here, we use dummy agents. bomb in reward calculation. Here, we use dummy agents.
@ -48,6 +48,7 @@ class Env:
""" """
self.objective = objective self.objective = objective
self.use_general = not old_model self.use_general = not old_model
self.use_legacy = legacy_model
# Initialize players # Initialize players
# We use three dummy player for the target position # We use three dummy player for the target position
@ -83,7 +84,7 @@ class Env:
card_play_data[key].sort() card_play_data[key].sort()
self._env.card_play_init(card_play_data) self._env.card_play_init(card_play_data)
self.infoset = self._game_infoset self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general) return get_obs(self.infoset, self.use_general, self.use_legacy)
else: else:
self.total_round += 1 self.total_round += 1
bid_done = False bid_done = False
@ -233,7 +234,7 @@ class Env:
print("发牌情况: %i/%i %.1f%%" % (self.force_bid, self.total_round, self.force_bid / self.total_round * 100)) print("发牌情况: %i/%i %.1f%%" % (self.force_bid, self.total_round, self.force_bid / self.total_round * 100))
self.force_bid = 0 self.force_bid = 0
self.total_round = 0 self.total_round = 0
return get_obs(self.infoset, self.use_general), { return get_obs(self.infoset, self.use_general, self.use_legacy), {
"bid_obs_buffer": bid_obs_buffer, "bid_obs_buffer": bid_obs_buffer,
"multiply_obs_buffer": multiply_obs_buffer "multiply_obs_buffer": multiply_obs_buffer
} }
@ -270,7 +271,7 @@ class Env:
} }
obs = None obs = None
else: else:
obs = get_obs(self.infoset, self.use_general) obs = get_obs(self.infoset, self.use_general, self.use_legacy)
return obs, reward, done, {} return obs, reward, done, {}
def _get_reward(self, pos): def _get_reward(self, pos):
@ -381,7 +382,7 @@ class DummyAgent(object):
self.action = action self.action = action
def get_obs(infoset, use_general=True): def get_obs(infoset, use_general=True, use_legacy = False):
""" """
This function obtains observations with imperfect information This function obtains observations with imperfect information
from the infoset. It has three branches since we encode from the infoset. It has three branches since we encode
@ -408,16 +409,16 @@ def get_obs(infoset, use_general=True):
if use_general: if use_general:
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]: if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
raise ValueError('') raise ValueError('')
return _get_obs_general(infoset, infoset.player_position) return _get_obs_general(infoset, infoset.player_position, use_legacy)
else: else:
if infoset.player_position == 'landlord': if infoset.player_position == 'landlord':
return _get_obs_landlord(infoset) return _get_obs_landlord(infoset, use_legacy)
elif infoset.player_position == 'landlord_up': elif infoset.player_position == 'landlord_up':
return _get_obs_landlord_up(infoset) return _get_obs_landlord_up(infoset, use_legacy)
elif infoset.player_position == 'landlord_front': elif infoset.player_position == 'landlord_front':
return _get_obs_landlord_front(infoset) return _get_obs_landlord_front(infoset, use_legacy)
elif infoset.player_position == 'landlord_down': elif infoset.player_position == 'landlord_down':
return _get_obs_landlord_down(infoset) return _get_obs_landlord_down(infoset, use_legacy)
else: else:
raise ValueError('') raise ValueError('')
@ -524,17 +525,23 @@ def _process_action_seq(sequence, length=20, new_model=True):
return sequence return sequence
def _get_one_hot_bomb(bomb_num): def _get_one_hot_bomb(bomb_num, use_legacy = False):
""" """
A utility function to encode the number of bombs A utility function to encode the number of bombs
into one-hot representation. into one-hot representation.
""" """
one_hot = np.zeros(29) if use_legacy:
one_hot[bomb_num[0] + bomb_num[1]] = 1 one_hot = np.zeros(29)
one_hot[bomb_num[0] + bomb_num[1]] = 1
else:
one_hot = np.zeros(56) # 14 + 15 + 27
one_hot[bomb_num[0]] = 1
one_hot[14 + bomb_num[1]] = 1
one_hot[29 + bomb_num[2]] = 1
return one_hot return one_hot
def _get_obs_landlord(infoset): def _get_obs_landlord(infoset, use_legacy = False):
""" """
Obtain the landlord features. See Table 4 in Obtain the landlord features. See Table 4 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -593,7 +600,7 @@ def _get_obs_landlord(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -634,7 +641,7 @@ def _get_obs_landlord(infoset):
} }
return obs return obs
def _get_obs_landlord_up(infoset): def _get_obs_landlord_up(infoset, use_legacy = False):
""" """
Obtain the landlord_up features. See Table 5 in Obtain the landlord_up features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -708,7 +715,7 @@ def _get_obs_landlord_up(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -755,7 +762,7 @@ def _get_obs_landlord_up(infoset):
} }
return obs return obs
def _get_obs_landlord_front(infoset): def _get_obs_landlord_front(infoset, use_legacy = False):
""" """
Obtain the landlord_front features. See Table 5 in Obtain the landlord_front features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -829,7 +836,7 @@ def _get_obs_landlord_front(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -876,7 +883,7 @@ def _get_obs_landlord_front(infoset):
} }
return obs return obs
def _get_obs_landlord_down(infoset): def _get_obs_landlord_down(infoset, use_legacy = False):
""" """
Obtain the landlord_down features. See Table 5 in Obtain the landlord_down features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -950,7 +957,7 @@ def _get_obs_landlord_down(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -1212,7 +1219,7 @@ def _get_obs_general1(infoset, position):
} }
return obs return obs
def _get_obs_general(infoset, position): def _get_obs_general(infoset, position, use_legacy = False):
num_legal_actions = len(infoset.legal_actions) num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards) my_handcards = _cards2array(infoset.player_hand_cards)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :], my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
@ -1306,7 +1313,7 @@ def _get_obs_general(infoset, position):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -1316,12 +1323,22 @@ def _get_obs_general(infoset, position):
landlord_front_num_cards_left, # 25 landlord_front_num_cards_left, # 25
landlord_down_num_cards_left)) landlord_down_num_cards_left))
x_batch = np.hstack(( if use_legacy:
bid_info_batch, # 16 x_batch = np.hstack((
multiply_info_batch)) # 4 bid_info_batch, # 20
x_no_action = np.hstack(( multiply_info_batch)) # 4
bid_info, x_no_action = np.hstack((
multiply_info)) bid_info,
multiply_info))
else:
x_batch = np.hstack((
bomb_num_batch, # 56
bid_info_batch, # 20
multiply_info_batch)) # 4
x_no_action = np.hstack((
bomb_num, # 56
bid_info,
multiply_info))
z =np.vstack(( z =np.vstack((
num_cards_left, num_cards_left,
my_handcards, # 108 my_handcards, # 108

30
douzero/env/game.py vendored
View File

@ -11,19 +11,16 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, '10': 10, 'J': 11, 'Q': 12, '8': 8, '9': 9, '10': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30} 'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
bombs = [ cards_idx = [x for x in range(3, 15)]
[[3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7], cards_idx.extend([17, 20, 30])
[8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11],
[12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17], bombs = [[[x] * 6 for x in cards_idx[:-2]], [[x] * 8 for x in cards_idx[:-2]], [[x] * 4 for x in cards_idx[:-2]]]
[3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7, 7], # Rocket bomb
[8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11, 11], bombs[0].extend([[x] * 7 for x in cards_idx[:-2]])
[12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14, 14], # King bomb
[17, 17, 17, 17, 17, 17, 17]], bombs[1].extend([[20, 20, 30, 30]])
[[3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6, 6], # Normal bomb
[7, 7, 7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10, 10], bombs[2].extend([[x] * 5 for x in cards_idx[:-2]])
[11, 11, 11, 11, 11, 11, 11, 11], [12, 12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13, 13],
[14, 14, 14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17, 17, 17],
[20, 20, 30, 30]]]
class GameEnv(object): class GameEnv(object):
@ -63,7 +60,7 @@ class GameEnv(object):
'landlord_front': InfoSet('landlord_front'), 'landlord_front': InfoSet('landlord_front'),
'landlord_down': InfoSet('landlord_down')} 'landlord_down': InfoSet('landlord_down')}
self.bomb_num = [0, 0] self.bomb_num = [0, 0, 0]
self.pos_bomb_num = { self.pos_bomb_num = {
"landlord": 0, "landlord": 0,
"landlord_up": 0, "landlord_up": 0,
@ -162,6 +159,9 @@ class GameEnv(object):
self.bomb_num[1] += 1 self.bomb_num[1] += 1
self.pos_bomb_num[self.acting_player_position] += 1 self.pos_bomb_num[self.acting_player_position] += 1
if action in bombs[2]:
self.bomb_num[2] += 1
self.last_move_dict[ self.last_move_dict[
self.acting_player_position] = action.copy() self.acting_player_position] = action.copy()
@ -367,7 +367,7 @@ class GameEnv(object):
'landlord_front': InfoSet('landlord_front'), 'landlord_front': InfoSet('landlord_front'),
'landlord_down': InfoSet('landlord_down')} 'landlord_down': InfoSet('landlord_down')}
self.bomb_num = [0, 0] self.bomb_num = [0, 0, 0]
self.pos_bomb_num = { self.pos_bomb_num = {
"landlord": 0, "landlord": 0,
"landlord_up": 0, "landlord_up": 0,