取回legacy_mode参数代码
This commit is contained in:
parent
08cf434500
commit
dba179db0e
|
@ -113,6 +113,104 @@ class FarmerLstmModel(nn.Module):
|
|||
x = self.dense6(x)
|
||||
return dict(values=x)
|
||||
|
||||
class LandlordLstmModelLegacy(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(432, 128, batch_first=True)
|
||||
self.dense1 = nn.Linear(860 + 128, 1024)
|
||||
self.dense2 = nn.Linear(1024, 1024)
|
||||
self.dense3 = nn.Linear(1024, 768)
|
||||
self.dense4 = nn.Linear(768, 512)
|
||||
self.dense5 = nn.Linear(512, 256)
|
||||
self.dense6 = nn.Linear(256, 1)
|
||||
|
||||
def get_onnx_params(self, device):
|
||||
return {
|
||||
'args': (
|
||||
torch.randn(1, 5, 432, requires_grad=True, device=device),
|
||||
torch.randn(1, 887, requires_grad=True, device=device),
|
||||
),
|
||||
'input_names': ['z_batch','x_batch'],
|
||||
'output_names': ['values'],
|
||||
'dynamic_axes': {
|
||||
'z_batch': {
|
||||
0: "batch_size"
|
||||
},
|
||||
'x_batch': {
|
||||
0: "batch_size"
|
||||
},
|
||||
'values': {
|
||||
0: "batch_size"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def forward(self, z, x):
|
||||
lstm_out, (h_n, _) = self.lstm(z)
|
||||
lstm_out = lstm_out[:,-1,:]
|
||||
x = torch.cat([lstm_out,x], dim=-1)
|
||||
x = self.dense1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense2(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense3(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense4(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense5(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense6(x)
|
||||
return dict(values=x)
|
||||
|
||||
class FarmerLstmModelLegacy(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(432, 128, batch_first=True)
|
||||
self.dense1 = nn.Linear(1192 + 128, 1024)
|
||||
self.dense2 = nn.Linear(1024, 1024)
|
||||
self.dense3 = nn.Linear(1024, 768)
|
||||
self.dense4 = nn.Linear(768, 512)
|
||||
self.dense5 = nn.Linear(512, 256)
|
||||
self.dense6 = nn.Linear(256, 1)
|
||||
|
||||
def get_onnx_params(self, device):
|
||||
return {
|
||||
'args': (
|
||||
torch.randn(1, 5, 432, requires_grad=True, device=device),
|
||||
torch.randn(1, 1219, requires_grad=True, device=device),
|
||||
),
|
||||
'input_names': ['z_batch','x_batch'],
|
||||
'output_names': ['values'],
|
||||
'dynamic_axes': {
|
||||
'z_batch': {
|
||||
0: "legal_actions"
|
||||
},
|
||||
'x_batch': {
|
||||
0: "legal_actions"
|
||||
},
|
||||
'values': {
|
||||
0: "batch_size"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def forward(self, z, x):
|
||||
lstm_out, (h_n, _) = self.lstm(z)
|
||||
lstm_out = lstm_out[:,-1,:]
|
||||
x = torch.cat([lstm_out,x], dim=-1)
|
||||
x = self.dense1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense2(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense3(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense4(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense5(x)
|
||||
x = torch.relu(x)
|
||||
x = self.dense6(x)
|
||||
return dict(values=x)
|
||||
|
||||
# 用于ResNet18和34的残差块,用的是2个3x3的卷积
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
@ -214,6 +312,11 @@ model_dict['landlord'] = LandlordLstmModel
|
|||
model_dict['landlord_up'] = FarmerLstmModel
|
||||
model_dict['landlord_front'] = FarmerLstmModel
|
||||
model_dict['landlord_down'] = FarmerLstmModel
|
||||
model_dict_legacy = {}
|
||||
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
|
||||
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
|
||||
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
|
||||
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
|
||||
model_dict_new = {}
|
||||
model_dict_new['landlord'] = GeneralModel
|
||||
model_dict_new['landlord_up'] = GeneralModel
|
||||
|
|
|
@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
|
|||
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
||||
|
||||
def create_env(flags):
|
||||
return Env(flags.objective, flags.old_model)
|
||||
return Env(flags.objective, flags.old_model, flags.lagacy_model)
|
||||
|
||||
def get_batch(b_queues, position, flags, lock):
|
||||
"""
|
||||
|
|
|
@ -33,7 +33,7 @@ class Env:
|
|||
Doudizhu multi-agent wrapper
|
||||
"""
|
||||
|
||||
def __init__(self, objective, old_model):
|
||||
def __init__(self, objective, old_model, legacy_model=False):
|
||||
"""
|
||||
Objective is wp/adp/logadp. It indicates whether considers
|
||||
bomb in reward calculation. Here, we use dummy agents.
|
||||
|
@ -108,7 +108,7 @@ class Env:
|
|||
self._env.info_sets[pos].player_id = pid
|
||||
self.infoset = self._game_infoset
|
||||
|
||||
return get_obs(self.infoset, self.use_general)
|
||||
return get_obs(self.infoset, self.use_general, self.use_legacy)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
|
@ -234,7 +234,7 @@ class DummyAgent(object):
|
|||
self.action = action
|
||||
|
||||
|
||||
def get_obs(infoset, use_general=True):
|
||||
def get_obs(infoset, use_general=True, use_legacy = False):
|
||||
"""
|
||||
This function obtains observations with imperfect information
|
||||
from the infoset. It has three branches since we encode
|
||||
|
@ -264,13 +264,13 @@ def get_obs(infoset, use_general=True):
|
|||
return _get_obs_general(infoset, infoset.player_position)
|
||||
else:
|
||||
if infoset.player_position == 'landlord':
|
||||
return _get_obs_landlord(infoset)
|
||||
return _get_obs_landlord(infoset, use_legacy)
|
||||
elif infoset.player_position == 'landlord_up':
|
||||
return _get_obs_landlord_up(infoset)
|
||||
return _get_obs_landlord_up(infoset, use_legacy)
|
||||
elif infoset.player_position == 'landlord_front':
|
||||
return _get_obs_landlord_front(infoset)
|
||||
return _get_obs_landlord_front(infoset, use_legacy)
|
||||
elif infoset.player_position == 'landlord_down':
|
||||
return _get_obs_landlord_down(infoset)
|
||||
return _get_obs_landlord_down(infoset, use_legacy)
|
||||
else:
|
||||
raise ValueError('')
|
||||
|
||||
|
@ -377,11 +377,15 @@ def _process_action_seq(sequence, length=20, new_model=True):
|
|||
return sequence
|
||||
|
||||
|
||||
def _get_one_hot_bomb(bomb_num):
|
||||
def _get_one_hot_bomb(bomb_num, use_legacy = False):
|
||||
"""
|
||||
A utility function to encode the number of bombs
|
||||
into one-hot representation.
|
||||
"""
|
||||
if use_legacy:
|
||||
one_hot = np.zeros(29)
|
||||
one_hot[bomb_num[0] + bomb_num[1]] = 1
|
||||
else:
|
||||
one_hot = np.zeros(56) # 14 + 15 + 27
|
||||
one_hot[bomb_num[0]] = 1
|
||||
one_hot[14 + bomb_num[1]] = 1
|
||||
|
@ -389,7 +393,7 @@ def _get_one_hot_bomb(bomb_num):
|
|||
return one_hot
|
||||
|
||||
|
||||
def _get_obs_landlord(infoset):
|
||||
def _get_obs_landlord(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord features. See Table 4 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -448,7 +452,7 @@ def _get_obs_landlord(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -489,7 +493,7 @@ def _get_obs_landlord(infoset):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_up(infoset):
|
||||
def _get_obs_landlord_up(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord_up features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -563,7 +567,7 @@ def _get_obs_landlord_up(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -610,7 +614,7 @@ def _get_obs_landlord_up(infoset):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_front(infoset):
|
||||
def _get_obs_landlord_front(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord_front features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -684,7 +688,7 @@ def _get_obs_landlord_front(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
@ -731,7 +735,7 @@ def _get_obs_landlord_front(infoset):
|
|||
}
|
||||
return obs
|
||||
|
||||
def _get_obs_landlord_down(infoset):
|
||||
def _get_obs_landlord_down(infoset, use_legacy = False):
|
||||
"""
|
||||
Obttain the landlord_down features. See Table 5 in
|
||||
https://arxiv.org/pdf/2106.06135.pdf
|
||||
|
@ -805,7 +809,7 @@ def _get_obs_landlord_down(infoset):
|
|||
num_legal_actions, axis=0)
|
||||
|
||||
bomb_num = _get_one_hot_bomb(
|
||||
infoset.bomb_num)
|
||||
infoset.bomb_num, use_legacy)
|
||||
bomb_num_batch = np.repeat(
|
||||
bomb_num[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
|
|
@ -3,11 +3,14 @@ import numpy as np
|
|||
|
||||
from douzero.env.env import get_obs
|
||||
|
||||
def _load_model(position, model_path, model_type):
|
||||
from douzero.dmc.models import model_dict_new, model_dict, model_dict_new_legacy, model_dict_legacy
|
||||
def _load_model(position, model_path, model_type, use_legacy):
|
||||
from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy
|
||||
model = None
|
||||
if model_type == "general":
|
||||
model = model_dict_new[position]()
|
||||
else:
|
||||
if use_legacy:
|
||||
model = model_dict_legacy[position]()
|
||||
else:
|
||||
model = model_dict[position]()
|
||||
model_state_dict = model.state_dict()
|
||||
|
@ -27,8 +30,9 @@ def _load_model(position, model_path, model_type):
|
|||
class DeepAgent:
|
||||
|
||||
def __init__(self, position, model_path):
|
||||
self.use_legacy = True if "legacy" in model_path else False
|
||||
self.model_type = "general" if "resnet" in model_path else "old"
|
||||
self.model = _load_model(position, model_path, self.model_type)
|
||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
||||
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||
|
@ -36,7 +40,7 @@ class DeepAgent:
|
|||
if len(infoset.legal_actions) == 1:
|
||||
return infoset.legal_actions[0]
|
||||
|
||||
obs = get_obs(infoset, self.model_type == "general")
|
||||
obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
|
||||
|
||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
|
|
Loading…
Reference in New Issue