调整obs,增加炸弹作为特征之一

This commit is contained in:
ZaneYork 2021-12-11 22:10:40 +08:00
parent 82e941a5eb
commit 64bf792d15
5 changed files with 79 additions and 53 deletions

View File

@ -27,6 +27,8 @@ parser.add_argument('--load_model', action='store_true',
help='Load an existing model')
parser.add_argument('--old_model', action='store_true',
help='Use vanilla model')
parser.add_argument('--lagacy_model', action='store_true',
help='Use lagacy bomb model')
parser.add_argument('--disable_checkpoint', action='store_true',
help='Disable saving checkpoint')
parser.add_argument('--savedir', default='douzero_checkpoints',

View File

@ -250,9 +250,10 @@ class BasicBlock(nn.Module):
class GeneralModel(nn.Module):
def __init__(self):
def __init__(self, use_legacy = False):
super().__init__()
self.in_planes = 80
self.use_legacy = use_legacy
#input 1*108*41
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*108*80
@ -264,7 +265,10 @@ class GeneralModel(nn.Module):
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
if self.use_legacy:
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
else:
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256)
@ -285,7 +289,10 @@ class GeneralModel(nn.Module):
out = self.layer3(out)
out = self.layer4(out)
out = out.flatten(1,2)
if self.use_legacy:
out = torch.cat([x,x,x,x,out], dim=-1)
else:
out = torch.cat([x,out], dim=-1)
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
@ -451,15 +458,15 @@ class Model:
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0):
def __init__(self, device=0, use_legacy = False):
self.models = {}
if not device == "cpu":
device = 'cuda:' + str(device)
# model = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel().to(torch.device(device))
self.models['landlord_up'] = GeneralModel().to(torch.device(device))
self.models['landlord_front'] = GeneralModel().to(torch.device(device))
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['landlord_up'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['landlord_front'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['landlord_down'] = GeneralModel(use_legacy).to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False):

View File

@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags):
return Env(flags.objective, flags.old_model)
return Env(flags.objective, flags.old_model, flags.lagacy_model)
def get_batch(b_queues, position, flags, lock):
"""

61
douzero/env/env.py vendored
View File

@ -34,7 +34,7 @@ class Env:
Doudizhu multi-agent wrapper
"""
def __init__(self, objective, old_model):
def __init__(self, objective, old_model, legacy_model=False):
"""
Objective is wp/adp/logadp. It indicates whether considers
bomb in reward calculation. Here, we use dummy agents.
@ -48,6 +48,7 @@ class Env:
"""
self.objective = objective
self.use_general = not old_model
self.use_legacy = legacy_model
# Initialize players
# We use three dummy player for the target position
@ -83,7 +84,7 @@ class Env:
card_play_data[key].sort()
self._env.card_play_init(card_play_data)
self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general)
return get_obs(self.infoset, self.use_general, self.use_legacy)
else:
self.total_round += 1
bid_done = False
@ -233,7 +234,7 @@ class Env:
print("发牌情况: %i/%i %.1f%%" % (self.force_bid, self.total_round, self.force_bid / self.total_round * 100))
self.force_bid = 0
self.total_round = 0
return get_obs(self.infoset, self.use_general), {
return get_obs(self.infoset, self.use_general, self.use_legacy), {
"bid_obs_buffer": bid_obs_buffer,
"multiply_obs_buffer": multiply_obs_buffer
}
@ -270,7 +271,7 @@ class Env:
}
obs = None
else:
obs = get_obs(self.infoset, self.use_general)
obs = get_obs(self.infoset, self.use_general, self.use_legacy)
return obs, reward, done, {}
def _get_reward(self, pos):
@ -381,7 +382,7 @@ class DummyAgent(object):
self.action = action
def get_obs(infoset, use_general=True):
def get_obs(infoset, use_general=True, use_legacy = False):
"""
This function obtains observations with imperfect information
from the infoset. It has three branches since we encode
@ -408,16 +409,16 @@ def get_obs(infoset, use_general=True):
if use_general:
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
raise ValueError('')
return _get_obs_general(infoset, infoset.player_position)
return _get_obs_general(infoset, infoset.player_position, use_legacy)
else:
if infoset.player_position == 'landlord':
return _get_obs_landlord(infoset)
return _get_obs_landlord(infoset, use_legacy)
elif infoset.player_position == 'landlord_up':
return _get_obs_landlord_up(infoset)
return _get_obs_landlord_up(infoset, use_legacy)
elif infoset.player_position == 'landlord_front':
return _get_obs_landlord_front(infoset)
return _get_obs_landlord_front(infoset, use_legacy)
elif infoset.player_position == 'landlord_down':
return _get_obs_landlord_down(infoset)
return _get_obs_landlord_down(infoset, use_legacy)
else:
raise ValueError('')
@ -524,17 +525,23 @@ def _process_action_seq(sequence, length=20, new_model=True):
return sequence
def _get_one_hot_bomb(bomb_num):
def _get_one_hot_bomb(bomb_num, use_legacy = False):
"""
A utility function to encode the number of bombs
into one-hot representation.
"""
if use_legacy:
one_hot = np.zeros(29)
one_hot[bomb_num[0] + bomb_num[1]] = 1
else:
one_hot = np.zeros(56) # 14 + 15 + 27
one_hot[bomb_num[0]] = 1
one_hot[14 + bomb_num[1]] = 1
one_hot[29 + bomb_num[2]] = 1
return one_hot
def _get_obs_landlord(infoset):
def _get_obs_landlord(infoset, use_legacy = False):
"""
Obttain the landlord features. See Table 4 in
https://arxiv.org/pdf/2106.06135.pdf
@ -593,7 +600,7 @@ def _get_obs_landlord(infoset):
num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num)
infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :],
num_legal_actions, axis=0)
@ -634,7 +641,7 @@ def _get_obs_landlord(infoset):
}
return obs
def _get_obs_landlord_up(infoset):
def _get_obs_landlord_up(infoset, use_legacy = False):
"""
Obttain the landlord_up features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf
@ -708,7 +715,7 @@ def _get_obs_landlord_up(infoset):
num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num)
infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :],
num_legal_actions, axis=0)
@ -755,7 +762,7 @@ def _get_obs_landlord_up(infoset):
}
return obs
def _get_obs_landlord_front(infoset):
def _get_obs_landlord_front(infoset, use_legacy = False):
"""
Obttain the landlord_front features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf
@ -829,7 +836,7 @@ def _get_obs_landlord_front(infoset):
num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num)
infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :],
num_legal_actions, axis=0)
@ -876,7 +883,7 @@ def _get_obs_landlord_front(infoset):
}
return obs
def _get_obs_landlord_down(infoset):
def _get_obs_landlord_down(infoset, use_legacy = False):
"""
Obttain the landlord_down features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf
@ -950,7 +957,7 @@ def _get_obs_landlord_down(infoset):
num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num)
infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :],
num_legal_actions, axis=0)
@ -1212,7 +1219,7 @@ def _get_obs_general1(infoset, position):
}
return obs
def _get_obs_general(infoset, position):
def _get_obs_general(infoset, position, use_legacy = False):
num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
@ -1306,7 +1313,7 @@ def _get_obs_general(infoset, position):
num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num)
infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :],
num_legal_actions, axis=0)
@ -1316,12 +1323,22 @@ def _get_obs_general(infoset, position):
landlord_front_num_cards_left, # 25
landlord_down_num_cards_left))
if use_legacy:
x_batch = np.hstack((
bid_info_batch, # 16
bid_info_batch, # 20
multiply_info_batch)) # 4
x_no_action = np.hstack((
bid_info,
multiply_info))
else:
x_batch = np.hstack((
bomb_num_batch, # 56
bid_info_batch, # 20
multiply_info_batch)) # 4
x_no_action = np.hstack((
bomb_num, # 56
bid_info,
multiply_info))
z =np.vstack((
num_cards_left,
my_handcards, # 108

30
douzero/env/game.py vendored
View File

@ -11,19 +11,16 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, '10': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
bombs = [
[[3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7],
[8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11],
[12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17],
[3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7, 7],
[8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11, 11],
[12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14, 14],
[17, 17, 17, 17, 17, 17, 17]],
[[3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6, 6],
[7, 7, 7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10, 10],
[11, 11, 11, 11, 11, 11, 11, 11], [12, 12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13, 13],
[14, 14, 14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17, 17, 17],
[20, 20, 30, 30]]]
cards_idx = [x for x in range(3, 15)]
cards_idx.extend([17, 20, 30])
bombs = [[[x] * 6 for x in cards_idx[:-2]], [[x] * 8 for x in cards_idx[:-2]], [[x] * 4 for x in cards_idx[:-2]]]
# Rocket bomb
bombs[0].extend([[x] * 7 for x in cards_idx[:-2]])
# King bomb
bombs[1].extend([[20, 20, 30, 30]])
# Normal bomb
bombs[2].extend([[x] * 5 for x in cards_idx[:-2]])
class GameEnv(object):
@ -63,7 +60,7 @@ class GameEnv(object):
'landlord_front': InfoSet('landlord_front'),
'landlord_down': InfoSet('landlord_down')}
self.bomb_num = [0, 0]
self.bomb_num = [0, 0, 0]
self.pos_bomb_num = {
"landlord": 0,
"landlord_up": 0,
@ -162,6 +159,9 @@ class GameEnv(object):
self.bomb_num[1] += 1
self.pos_bomb_num[self.acting_player_position] += 1
if action in bombs[2]:
self.bomb_num[2] += 1
self.last_move_dict[
self.acting_player_position] = action.copy()
@ -367,7 +367,7 @@ class GameEnv(object):
'landlord_front': InfoSet('landlord_front'),
'landlord_down': InfoSet('landlord_down')}
self.bomb_num = [0, 0]
self.bomb_num = [0, 0, 0]
self.pos_bomb_num = {
"landlord": 0,
"landlord_up": 0,