调整obs,增加炸弹作为特征之一
This commit is contained in:
parent
82e941a5eb
commit
64bf792d15
|
@ -27,6 +27,8 @@ parser.add_argument('--load_model', action='store_true',
|
|||
help='Load an existing model')
|
||||
parser.add_argument('--old_model', action='store_true',
|
||||
help='Use vanilla model')
|
||||
parser.add_argument('--lagacy_model', action='store_true',
|
||||
help='Use lagacy bomb model')
|
||||
parser.add_argument('--disable_checkpoint', action='store_true',
|
||||
help='Disable saving checkpoint')
|
||||
parser.add_argument('--savedir', default='douzero_checkpoints',
|
||||
|
|
|
@ -250,9 +250,10 @@ class BasicBlock(nn.Module):
|
|||
|
||||
|
||||
class GeneralModel(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, use_legacy = False):
|
||||
super().__init__()
|
||||
self.in_planes = 80
|
||||
self.use_legacy = use_legacy
|
||||
#input 1*108*41
|
||||
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
|
||||
stride=(2,), padding=1, bias=False) #1*108*80
|
||||
|
@ -264,7 +265,10 @@ class GeneralModel(nn.Module):
|
|||
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
|
||||
self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
|
||||
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
if self.use_legacy:
|
||||
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
|
||||
else:
|
||||
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
|
||||
self.linear2 = nn.Linear(2048, 1024)
|
||||
self.linear3 = nn.Linear(1024, 512)
|
||||
self.linear4 = nn.Linear(512, 256)
|
||||
|
@ -285,7 +289,10 @@ class GeneralModel(nn.Module):
|
|||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = out.flatten(1,2)
|
||||
if self.use_legacy:
|
||||
out = torch.cat([x,x,x,x,out], dim=-1)
|
||||
else:
|
||||
out = torch.cat([x,out], dim=-1)
|
||||
out = F.leaky_relu_(self.linear1(out))
|
||||
out = F.leaky_relu_(self.linear2(out))
|
||||
out = F.leaky_relu_(self.linear3(out))
|
||||
|
@ -451,15 +458,15 @@ class Model:
|
|||
The wrapper for the three models. We also wrap several
|
||||
interfaces such as share_memory, eval, etc.
|
||||
"""
|
||||
def __init__(self, device=0):
|
||||
def __init__(self, device=0, use_legacy = False):
|
||||
self.models = {}
|
||||
if not device == "cpu":
|
||||
device = 'cuda:' + str(device)
|
||||
# model = GeneralModel().to(torch.device(device))
|
||||
self.models['landlord'] = GeneralModel().to(torch.device(device))
|
||||
self.models['landlord_up'] = GeneralModel().to(torch.device(device))
|
||||
self.models['landlord_front'] = GeneralModel().to(torch.device(device))
|
||||
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
|
||||
self.models['landlord'] = GeneralModel(use_legacy).to(torch.device(device))
|
||||
self.models['landlord_up'] = GeneralModel(use_legacy).to(torch.device(device))
|
||||
self.models['landlord_front'] = GeneralModel(use_legacy).to(torch.device(device))
|
||||
self.models['landlord_down'] = GeneralModel(use_legacy).to(torch.device(device))
|
||||
self.models['bidding'] = BidModel().to(torch.device(device))
|
||||
|
||||
def forward(self, position, z, x, training=False, flags=None, debug=False):
|
||||
|
|
|
@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
|
|||
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
||||
|
||||
def create_env(flags):
|
||||
return Env(flags.objective, flags.old_model)
|
||||
return Env(flags.objective, flags.old_model, flags.lagacy_model)
|
||||
|
||||
def get_batch(b_queues, position, flags, lock):
|
||||
"""
|
||||
|
|
|
@ -34,7 +34,7 @@ class Env:
|
|||
Doudizhu multi-agent wrapper
|
||||
"""
|
||||
|
||||
def __init__(self, objective, old_model):
|
||||
def __init__(self, objective, old_model, legacy_model=False):
|
||||
"""
|
||||
Objective is wp/adp/logadp. It indicates whether considers
|
||||
bomb in reward calculation. Here, we use dummy agents.
|
||||
|
@ -48,6 +48,7 @@ class Env:
|
|||
"""
|
||||
self.objective = objective
|
||||
self.use_general = not old_model
|
||||
self.use_legacy = legacy_model
|
||||
|
||||
# Initialize players
|
||||
# We use three dummy player for the target position
|
||||
|
@ -83,7 +84,7 @@ class Env:
|
|||
card_play_data[key].sort()
|
||||
self._env.card_play_init(card_play_data)
|
||||
self.infoset = self._game_infoset
|
||||
return get_obs(self.infoset, self.use_general)
|
||||
return get_obs(self.infoset, self.use_general, self.use_legacy)
|
||||
else:
|
||||
self.total_round += 1
|
||||
bid_done = False
|
||||
|
@ -233,7 +234,7 @@ class Env:
|
|||
print("发牌情况: %i/%i %.1f%%" % (self.force_bid, self.total_round, self.force_bid / self.total_round * 100))
|
||||
self.force_bid = 0
|
||||
self.total_round = 0
|
||||
return get_obs(self.infoset, self.use_general), {
|
||||
return get_obs(self.infoset, self.use_general, self.use_legacy), {
|
||||
"bid_obs_buffer": bid_obs_buffer,
|
||||
"multiply_obs_buffer": multiply_obs_buffer
|
||||
}
|
||||
|
@ -270,7 +271,7 @@ class Env:
|
|||
}
|
||||
obs = None
|
||||
else:
|
||||
obs = get_obs(self.infoset, self.use_general)
|
||||
obs = get_obs(self.infoset, self.use_general, self.use_legacy)
|
||||
return obs, reward, done, {}
|
||||
|
||||
def _get_reward(self, pos):
|
||||
|
@ -381,7 +382,7 @@ class DummyAgent(object):
|
|||
self.action = action
|
||||
|
||||
|
||||
def get_obs(infoset, use_general=True):
|
||||
def get_obs(infoset, use_general=True, use_legacy = False):
|
||||
"""
|
||||
This function obtains observations with imperfect information
|
||||
from the infoset. It has three branches since we encode
|
||||
|
@ -408,16 +409,16 @@ def get_obs(infoset, use_general=True):
|
|||
if use_general:
|
||||
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
|
||||
raise ValueError('')
|
||||
return _get_obs_general(infoset, infoset.player_position)
|
||||
return _get_obs_general(infoset, infoset.player_position, use_legacy)
|
||||
else:
|
||||
if infoset.player_position == 'landlord':
|
||||
return _get_obs_landlord(infoset)
|
||||
return _get_obs_landlord(infoset, use_legacy)
|
||||
elif infoset.player_position == 'landlord_up':
|
||||
return _get_obs_landlord_up(infoset)
|
||||
return _get_obs_landlord_up(infoset, use_legacy)
|
||||
elif infoset.player_position == 'landlord_front':
|
||||
return _get_obs_landlord_front(infoset)
|
||||
return _get_obs_landlord_front(infoset, use_legacy)
|
||||
elif infoset.player_position == 'landlord_down':
|
||||
return _get_obs_landlord_down(infoset)
|
||||
return _get_obs_landlord_down(infoset, use_legacy)
|
||||
else:
|
||||
raise ValueError('')
|
||||
|
||||
|
@ -524,17 +525,23 @@ def _process_action_seq(sequence, length=20, new_model=True):
|
|||
return sequence
|
||||
|
||||
|
||||
def _get_one_hot_bomb(bomb_num):
|
||||
def _get_one_hot_bomb(bomb_num, use_legacy = False):
|
||||
"""
|
||||
A utility function to encode the number of bombs
|
||||
into one-hot representation.
|
||||
"""
|
||||
if use_legacy:
|
||||
one_hot = np.zeros(29)
|
||||
one_hot[bomb_num[0] + bomb_num[1]] = 1
|
||||
else:
|
||||
one_hot = np.zeros(56) # 14 + 15 + 27
|
||||
one_hot[bomb_num[0]] = 1
|
||||
one_hot[14 + bomb_num[1]] = 1
|
||||
one_hot[29 + bomb_num[2]] = 1
|
||||
return one_hot
|
||||
|
||||
|
||||
def _get_obs_landlord(infoset):
|
||||
def _get_obs_landlord(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord features. See Table 4 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -593,7 +600,7 @@ def _get_obs_landlord(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -634,7 +641,7 @@ def _get_obs_landlord(infoset):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_up(infoset):
|
||||
def _get_obs_landlord_up(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord_up features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -708,7 +715,7 @@ def _get_obs_landlord_up(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -755,7 +762,7 @@ def _get_obs_landlord_up(infoset):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_front(infoset):
|
||||
def _get_obs_landlord_front(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord_front features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -829,7 +836,7 @@ def _get_obs_landlord_front(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -876,7 +883,7 @@ def _get_obs_landlord_front(infoset):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_down(infoset):
|
||||
def _get_obs_landlord_down(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord_down features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -950,7 +957,7 @@ def _get_obs_landlord_down(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -1212,7 +1219,7 @@ def _get_obs_general1(infoset, position):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_general(infoset, position):
|
||||
def _get_obs_general(infoset, position, use_legacy = False):
|
||||
num_legal_actions = len(infoset.legal_actions)
|
||||
my_handcards = _cards2array(infoset.player_hand_cards)
|
||||
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
|
||||
|
@ -1306,7 +1313,7 @@ def _get_obs_general(infoset, position):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -1316,12 +1323,22 @@ def _get_obs_general(infoset, position):
|
|||
landlord_front_num_cards_left, # 25
|
||||
landlord_down_num_cards_left))
|
||||
|
||||
if use_legacy:
|
||||
x_batch = np.hstack((
|
||||
bid_info_batch, # 16
|
||||
bid_info_batch, # 20
|
||||
multiply_info_batch)) # 4
|
||||
x_no_action = np.hstack((
|
||||
bid_info,
|
||||
multiply_info))
|
||||
else:
|
||||
x_batch = np.hstack((
|
||||
bomb_num_batch, # 56
|
||||
bid_info_batch, # 20
|
||||
multiply_info_batch)) # 4
|
||||
x_no_action = np.hstack((
|
||||
bomb_num, # 56
|
||||
bid_info,
|
||||
multiply_info))
|
||||
z =np.vstack((
|
||||
num_cards_left,
|
||||
my_handcards, # 108
|
||||
|
|
|
@ -11,19 +11,16 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
|||
'8': 8, '9': 9, '10': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
|
||||
|
||||
bombs = [
|
||||
[[3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7],
|
||||
[8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11],
|
||||
[12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17],
|
||||
[3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7, 7],
|
||||
[8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10], [11, 11, 11, 11, 11, 11, 11],
|
||||
[12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13], [14, 14, 14, 14, 14, 14, 14],
|
||||
[17, 17, 17, 17, 17, 17, 17]],
|
||||
[[3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6, 6, 6],
|
||||
[7, 7, 7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9, 9, 9], [10, 10, 10, 10, 10, 10, 10, 10],
|
||||
[11, 11, 11, 11, 11, 11, 11, 11], [12, 12, 12, 12, 12, 12, 12, 12], [13, 13, 13, 13, 13, 13, 13, 13],
|
||||
[14, 14, 14, 14, 14, 14, 14, 14], [17, 17, 17, 17, 17, 17, 17, 17],
|
||||
[20, 20, 30, 30]]]
|
||||
cards_idx = [x for x in range(3, 15)]
|
||||
cards_idx.extend([17, 20, 30])
|
||||
|
||||
bombs = [[[x] * 6 for x in cards_idx[:-2]], [[x] * 8 for x in cards_idx[:-2]], [[x] * 4 for x in cards_idx[:-2]]]
|
||||
# Rocket bomb
|
||||
bombs[0].extend([[x] * 7 for x in cards_idx[:-2]])
|
||||
# King bomb
|
||||
bombs[1].extend([[20, 20, 30, 30]])
|
||||
# Normal bomb
|
||||
bombs[2].extend([[x] * 5 for x in cards_idx[:-2]])
|
||||
|
||||
class GameEnv(object):
|
||||
|
||||
|
@ -63,7 +60,7 @@ class GameEnv(object):
|
|||
'landlord_front': InfoSet('landlord_front'),
|
||||
'landlord_down': InfoSet('landlord_down')}
|
||||
|
||||
self.bomb_num = [0, 0]
|
||||
self.bomb_num = [0, 0, 0]
|
||||
self.pos_bomb_num = {
|
||||
"landlord": 0,
|
||||
"landlord_up": 0,
|
||||
|
@ -162,6 +159,9 @@ class GameEnv(object):
|
|||
self.bomb_num[1] += 1
|
||||
self.pos_bomb_num[self.acting_player_position] += 1
|
||||
|
||||
if action in bombs[2]:
|
||||
self.bomb_num[2] += 1
|
||||
|
||||
self.last_move_dict[
|
||||
self.acting_player_position] = action.copy()
|
||||
|
||||
|
@ -367,7 +367,7 @@ class GameEnv(object):
|
|||
'landlord_front': InfoSet('landlord_front'),
|
||||
'landlord_down': InfoSet('landlord_down')}
|
||||
|
||||
self.bomb_num = [0, 0]
|
||||
self.bomb_num = [0, 0, 0]
|
||||
self.pos_bomb_num = {
|
||||
"landlord": 0,
|
||||
"landlord_up": 0,
|
||||
|
|
Loading…
Reference in New Issue