取回legacy_mode参数代码

This commit is contained in:
zhiyang7 2021-12-20 10:02:55 +08:00
parent 08cf434500
commit dba179db0e
4 changed files with 137 additions and 26 deletions

View File

@ -113,6 +113,104 @@ class FarmerLstmModel(nn.Module):
x = self.dense6(x) x = self.dense6(x)
return dict(values=x) return dict(values=x)
class LandlordLstmModelLegacy(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(860 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 432, requires_grad=True, device=device),
torch.randn(1, 887, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class FarmerLstmModelLegacy(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(1192 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 432, requires_grad=True, device=device),
torch.randn(1, 1219, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
# 用于ResNet18和34的残差块用的是2个3x3的卷积 # 用于ResNet18和34的残差块用的是2个3x3的卷积
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
expansion = 1 expansion = 1
@ -214,6 +312,11 @@ model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel model_dict['landlord_down'] = FarmerLstmModel
model_dict_legacy = {}
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
model_dict_new = {} model_dict_new = {}
model_dict_new['landlord'] = GeneralModel model_dict_new['landlord'] = GeneralModel
model_dict_new['landlord_up'] = GeneralModel model_dict_new['landlord_up'] = GeneralModel

View File

@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
Buffers = typing.Dict[str, typing.List[torch.Tensor]] Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags): def create_env(flags):
return Env(flags.objective, flags.old_model) return Env(flags.objective, flags.old_model, flags.lagacy_model)
def get_batch(b_queues, position, flags, lock): def get_batch(b_queues, position, flags, lock):
""" """

44
douzero/env/env.py vendored
View File

@ -33,7 +33,7 @@ class Env:
Doudizhu multi-agent wrapper Doudizhu multi-agent wrapper
""" """
def __init__(self, objective, old_model): def __init__(self, objective, old_model, legacy_model=False):
""" """
Objective is wp/adp/logadp. It indicates whether considers Objective is wp/adp/logadp. It indicates whether considers
bomb in reward calculation. Here, we use dummy agents. bomb in reward calculation. Here, we use dummy agents.
@ -108,7 +108,7 @@ class Env:
self._env.info_sets[pos].player_id = pid self._env.info_sets[pos].player_id = pid
self.infoset = self._game_infoset self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general) return get_obs(self.infoset, self.use_general, self.use_legacy)
def step(self, action): def step(self, action):
""" """
@ -234,7 +234,7 @@ class DummyAgent(object):
self.action = action self.action = action
def get_obs(infoset, use_general=True): def get_obs(infoset, use_general=True, use_legacy = False):
""" """
This function obtains observations with imperfect information This function obtains observations with imperfect information
from the infoset. It has three branches since we encode from the infoset. It has three branches since we encode
@ -264,13 +264,13 @@ def get_obs(infoset, use_general=True):
return _get_obs_general(infoset, infoset.player_position) return _get_obs_general(infoset, infoset.player_position)
else: else:
if infoset.player_position == 'landlord': if infoset.player_position == 'landlord':
return _get_obs_landlord(infoset) return _get_obs_landlord(infoset, use_legacy)
elif infoset.player_position == 'landlord_up': elif infoset.player_position == 'landlord_up':
return _get_obs_landlord_up(infoset) return _get_obs_landlord_up(infoset, use_legacy)
elif infoset.player_position == 'landlord_front': elif infoset.player_position == 'landlord_front':
return _get_obs_landlord_front(infoset) return _get_obs_landlord_front(infoset, use_legacy)
elif infoset.player_position == 'landlord_down': elif infoset.player_position == 'landlord_down':
return _get_obs_landlord_down(infoset) return _get_obs_landlord_down(infoset, use_legacy)
else: else:
raise ValueError('') raise ValueError('')
@ -377,19 +377,23 @@ def _process_action_seq(sequence, length=20, new_model=True):
return sequence return sequence
def _get_one_hot_bomb(bomb_num): def _get_one_hot_bomb(bomb_num, use_legacy = False):
""" """
A utility function to encode the number of bombs A utility function to encode the number of bombs
into one-hot representation. into one-hot representation.
""" """
one_hot = np.zeros(56) # 14 + 15 + 27 if use_legacy:
one_hot[bomb_num[0]] = 1 one_hot = np.zeros(29)
one_hot[14 + bomb_num[1]] = 1 one_hot[bomb_num[0] + bomb_num[1]] = 1
one_hot[29 + bomb_num[2]] = 1 else:
one_hot = np.zeros(56) # 14 + 15 + 27
one_hot[bomb_num[0]] = 1
one_hot[14 + bomb_num[1]] = 1
one_hot[29 + bomb_num[2]] = 1
return one_hot return one_hot
def _get_obs_landlord(infoset): def _get_obs_landlord(infoset, use_legacy = False):
""" """
Obttain the landlord features. See Table 4 in Obttain the landlord features. See Table 4 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -448,7 +452,7 @@ def _get_obs_landlord(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -489,7 +493,7 @@ def _get_obs_landlord(infoset):
} }
return obs return obs
def _get_obs_landlord_up(infoset): def _get_obs_landlord_up(infoset, use_legacy = False):
""" """
Obttain the landlord_up features. See Table 5 in Obttain the landlord_up features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -563,7 +567,7 @@ def _get_obs_landlord_up(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -610,7 +614,7 @@ def _get_obs_landlord_up(infoset):
} }
return obs return obs
def _get_obs_landlord_front(infoset): def _get_obs_landlord_front(infoset, use_legacy = False):
""" """
Obttain the landlord_front features. See Table 5 in Obttain the landlord_front features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -684,7 +688,7 @@ def _get_obs_landlord_front(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -731,7 +735,7 @@ def _get_obs_landlord_front(infoset):
} }
return obs return obs
def _get_obs_landlord_down(infoset): def _get_obs_landlord_down(infoset, use_legacy = False):
""" """
Obttain the landlord_down features. See Table 5 in Obttain the landlord_down features. See Table 5 in
https://arxiv.org/pdf/2106.06135.pdf https://arxiv.org/pdf/2106.06135.pdf
@ -805,7 +809,7 @@ def _get_obs_landlord_down(infoset):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, use_legacy)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)

View File

@ -3,13 +3,16 @@ import numpy as np
from douzero.env.env import get_obs from douzero.env.env import get_obs
def _load_model(position, model_path, model_type): def _load_model(position, model_path, model_type, use_legacy):
from douzero.dmc.models import model_dict_new, model_dict, model_dict_new_legacy, model_dict_legacy from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy
model = None model = None
if model_type == "general": if model_type == "general":
model = model_dict_new[position]() model = model_dict_new[position]()
else: else:
model = model_dict[position]() if use_legacy:
model = model_dict_legacy[position]()
else:
model = model_dict[position]()
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
if torch.cuda.is_available(): if torch.cuda.is_available():
pretrained = torch.load(model_path, map_location='cuda:0') pretrained = torch.load(model_path, map_location='cuda:0')
@ -27,8 +30,9 @@ def _load_model(position, model_path, model_type):
class DeepAgent: class DeepAgent:
def __init__(self, position, model_path): def __init__(self, position, model_path):
self.use_legacy = True if "legacy" in model_path else False
self.model_type = "general" if "resnet" in model_path else "old" self.model_type = "general" if "resnet" in model_path else "old"
self.model = _load_model(position, model_path, self.model_type) self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'} 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
@ -36,7 +40,7 @@ class DeepAgent:
if len(infoset.legal_actions) == 1: if len(infoset.legal_actions) == 1:
return infoset.legal_actions[0] return infoset.legal_actions[0]
obs = get_obs(infoset, self.model_type == "general") obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
z_batch = torch.from_numpy(obs['z_batch']).float() z_batch = torch.from_numpy(obs['z_batch']).float()
x_batch = torch.from_numpy(obs['x_batch']).float() x_batch = torch.from_numpy(obs['x_batch']).float()