Bring back the legacy_mode parameter code

zhiyang7 2021-12-20 10:02:55 +08:00
parent 08cf434500
commit dba179db0e
4 changed files with 137 additions and 26 deletions

View File

@@ -113,6 +113,104 @@ class FarmerLstmModel(nn.Module):
        x = self.dense6(x)
        return dict(values=x)

class LandlordLstmModelLegacy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(860 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self, device):
        return {
            'args': (
                torch.randn(1, 5, 432, requires_grad=True, device=device),
                torch.randn(1, 887, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {
                    0: "batch_size"
                },
                'x_batch': {
                    0: "batch_size"
                },
                'values': {
                    0: "batch_size"
                }
            }
        }

    def forward(self, z, x):
        lstm_out, (h_n, _) = self.lstm(z)
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        x = torch.relu(x)
        x = self.dense3(x)
        x = torch.relu(x)
        x = self.dense4(x)
        x = torch.relu(x)
        x = self.dense5(x)
        x = torch.relu(x)
        x = self.dense6(x)
        return dict(values=x)

class FarmerLstmModelLegacy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(1192 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self, device):
        return {
            'args': (
                torch.randn(1, 5, 432, requires_grad=True, device=device),
                torch.randn(1, 1219, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {
                    0: "legal_actions"
                },
                'x_batch': {
                    0: "legal_actions"
                },
                'values': {
                    0: "batch_size"
                }
            }
        }

    def forward(self, z, x):
        lstm_out, (h_n, _) = self.lstm(z)
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        x = torch.relu(x)
        x = self.dense3(x)
        x = torch.relu(x)
        x = self.dense4(x)
        x = torch.relu(x)
        x = self.dense5(x)
        x = torch.relu(x)
        x = self.dense6(x)
        return dict(values=x)
# The residual block used by ResNet18 and ResNet34: two 3x3 convolutions
class BasicBlock(nn.Module):
    expansion = 1
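The get_onnx_params dict above mirrors the keyword arguments of torch.onnx.export. Below is a minimal sketch of how such a dict could be consumed; the tiny stand-in network, the output file name, and the opset version are illustrative assumptions and not part of this commit.

import torch
import torch.nn as nn

# Stand-in model with the same (z, x) -> values interface as the legacy models above.
class TinyValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(4, 8, batch_first=True)
        self.dense = nn.Linear(8 + 6, 1)

    def get_onnx_params(self, device):
        return {
            'args': (
                torch.randn(1, 5, 4, requires_grad=True, device=device),
                torch.randn(1, 6, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        lstm_out, _ = self.lstm(z)
        return self.dense(torch.cat([lstm_out[:, -1, :], x], dim=-1))

device = torch.device("cpu")
model = TinyValueNet().to(device).eval()
params = model.get_onnx_params(device)
# Feed the stored settings straight into torch.onnx.export.
torch.onnx.export(
    model,
    params['args'],
    "tiny_value_net.onnx",          # assumed output path
    input_names=params['input_names'],
    output_names=params['output_names'],
    dynamic_axes=params['dynamic_axes'],
    opset_version=11,
)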
@@ -214,6 +312,11 @@ model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel
model_dict_legacy = {}
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
model_dict_new = {}
model_dict_new['landlord'] = GeneralModel
model_dict_new['landlord_up'] = GeneralModel
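With this commit there are three registries keyed by position: model_dict (current LSTM models), model_dict_legacy (the restored legacy LSTM models), and model_dict_new (the ResNet/general models). A hedged sketch of how they might be used together when instantiating a network; the select_model helper and its flag names are illustrative, not part of this commit.

# Illustrative helper (not in this commit): pick a network class for a position
# based on the general/legacy switches, mirroring the registries above.
from douzero.dmc.models import model_dict, model_dict_legacy, model_dict_new

def select_model(position, use_general=False, use_legacy=False):
    if use_general:
        return model_dict_new[position]()
    if use_legacy:
        return model_dict_legacy[position]()
    return model_dict[position]()

# e.g. a legacy farmer network for the seat above the landlord:
# net = select_model('landlord_up', use_legacy=True)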

View File

@@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
Buffers = typing.Dict[str, typing.List[torch.Tensor]]

def create_env(flags):
    return Env(flags.objective, flags.old_model)
    return Env(flags.objective, flags.old_model, flags.lagacy_model)

def get_batch(b_queues, position, flags, lock):
    """

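create_env now forwards a third attribute from flags into Env. A minimal sketch of a flags object that satisfies this call, assuming nothing beyond the three attributes read above; SimpleNamespace stands in for the real argument-parser namespace used in training.

from types import SimpleNamespace

# Stand-in for the parsed training flags; only the attributes read by
# create_env above are filled in. Note the attribute is spelled
# "lagacy_model" in this commit, while Env's parameter is legacy_model.
flags = SimpleNamespace(objective='adp', old_model=False, lagacy_model=True)

# create_env(flags) then resolves to:
#   Env(flags.objective, flags.old_model, flags.lagacy_model)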
douzero/env/env.py
View File

@@ -33,7 +33,7 @@ class Env:
    Doudizhu multi-agent wrapper
    """

    def __init__(self, objective, old_model):
    def __init__(self, objective, old_model, legacy_model=False):
        """
        Objective is wp/adp/logadp. It indicates whether bombs are
        considered in the reward calculation. Here, we use dummy agents.
@@ -108,7 +108,7 @@ class Env:
            self._env.info_sets[pos].player_id = pid
        self.infoset = self._game_infoset
        return get_obs(self.infoset, self.use_general)
        return get_obs(self.infoset, self.use_general, self.use_legacy)

    def step(self, action):
        """
@@ -234,7 +234,7 @@ class DummyAgent(object):
        self.action = action

def get_obs(infoset, use_general=True):
def get_obs(infoset, use_general=True, use_legacy=False):
    """
    This function obtains observations with imperfect information
    from the infoset. It has three branches since we encode
@@ -264,13 +264,13 @@ def get_obs(infoset, use_general=True):
        return _get_obs_general(infoset, infoset.player_position)
    else:
        if infoset.player_position == 'landlord':
            return _get_obs_landlord(infoset)
            return _get_obs_landlord(infoset, use_legacy)
        elif infoset.player_position == 'landlord_up':
            return _get_obs_landlord_up(infoset)
            return _get_obs_landlord_up(infoset, use_legacy)
        elif infoset.player_position == 'landlord_front':
            return _get_obs_landlord_front(infoset)
            return _get_obs_landlord_front(infoset, use_legacy)
        elif infoset.player_position == 'landlord_down':
            return _get_obs_landlord_down(infoset)
            return _get_obs_landlord_down(infoset, use_legacy)
        else:
            raise ValueError('')
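For readers skimming the branch above: with use_general=False the per-position encoder is selected by infoset.player_position, and use_legacy is forwarded to it unchanged. A small runnable stand-in that only traces which helper each position maps to; the real helpers build feature arrays from a full infoset supplied by the game engine.

# Illustrative stand-in for the dispatch above (not part of this commit).
def dispatch_sketch(player_position, use_legacy=False):
    table = {
        'landlord': '_get_obs_landlord',
        'landlord_up': '_get_obs_landlord_up',
        'landlord_front': '_get_obs_landlord_front',
        'landlord_down': '_get_obs_landlord_down',
    }
    try:
        return table[player_position], use_legacy
    except KeyError:
        raise ValueError('') from None

for pos in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down'):
    print(dispatch_sketch(pos, use_legacy=True))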
@@ -377,11 +377,15 @@ def _process_action_seq(sequence, length=20, new_model=True):
    return sequence

def _get_one_hot_bomb(bomb_num):
def _get_one_hot_bomb(bomb_num, use_legacy=False):
    """
    A utility function to encode the number of bombs
    into one-hot representation.
    """
    if use_legacy:
        one_hot = np.zeros(29)
        one_hot[bomb_num[0] + bomb_num[1]] = 1
    else:
        one_hot = np.zeros(56)  # 14 + 15 + 27
        one_hot[bomb_num[0]] = 1
        one_hot[14 + bomb_num[1]] = 1
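A runnable sketch of the two encodings, re-stated standalone from the lines visible in this hunk only (the non-legacy branch may continue beyond the cut, as the "14 + 15 + 27" comment suggests). bomb_num is assumed to be a pair of per-deck bomb counters, as the indexing implies; the width comparison at the end is an inference from this diff, not something the code asserts.

import numpy as np

def one_hot_bomb_legacy(bomb_num):
    # 29 slots indexed by the total bomb count across both counters.
    one_hot = np.zeros(29)
    one_hot[bomb_num[0] + bomb_num[1]] = 1
    return one_hot

def one_hot_bomb_current(bomb_num):
    # 56 slots split into separate ranges per counter (per the lines shown above).
    one_hot = np.zeros(56)
    one_hot[bomb_num[0]] = 1
    one_hot[14 + bomb_num[1]] = 1
    return one_hot

bombs = (2, 3)
print(one_hot_bomb_legacy(bombs).shape, one_hot_bomb_current(bombs).shape)  # (29,) (56,)
# The 27-dim gap (56 - 29) matches the smaller x widths of the legacy models
# above (860 vs 887 for the landlord, 1192 vs 1219 for the farmers) -- an
# inference from this diff, not stated anywhere in the code.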
@@ -389,7 +393,7 @@ def _get_one_hot_bomb(bomb_num):
    return one_hot

def _get_obs_landlord(infoset):
def _get_obs_landlord(infoset, use_legacy=False):
    """
    Obtain the landlord features. See Table 4 in
    https://arxiv.org/pdf/2106.06135.pdf
@@ -448,7 +452,7 @@ def _get_obs_landlord(infoset):
        num_legal_actions, axis=0)
    bomb_num = _get_one_hot_bomb(
        infoset.bomb_num)
        infoset.bomb_num, use_legacy)
    bomb_num_batch = np.repeat(
        bomb_num[np.newaxis, :],
        num_legal_actions, axis=0)
@@ -489,7 +493,7 @@ def _get_obs_landlord(infoset):
    }
    return obs

def _get_obs_landlord_up(infoset):
def _get_obs_landlord_up(infoset, use_legacy=False):
    """
    Obtain the landlord_up features. See Table 5 in
    https://arxiv.org/pdf/2106.06135.pdf
@@ -563,7 +567,7 @@ def _get_obs_landlord_up(infoset):
        num_legal_actions, axis=0)
    bomb_num = _get_one_hot_bomb(
        infoset.bomb_num)
        infoset.bomb_num, use_legacy)
    bomb_num_batch = np.repeat(
        bomb_num[np.newaxis, :],
        num_legal_actions, axis=0)
@@ -610,7 +614,7 @@ def _get_obs_landlord_up(infoset):
    }
    return obs

def _get_obs_landlord_front(infoset):
def _get_obs_landlord_front(infoset, use_legacy=False):
    """
    Obtain the landlord_front features. See Table 5 in
    https://arxiv.org/pdf/2106.06135.pdf
@@ -684,7 +688,7 @@ def _get_obs_landlord_front(infoset):
        num_legal_actions, axis=0)
    bomb_num = _get_one_hot_bomb(
        infoset.bomb_num)
        infoset.bomb_num, use_legacy)
    bomb_num_batch = np.repeat(
        bomb_num[np.newaxis, :],
        num_legal_actions, axis=0)
@@ -731,7 +735,7 @@ def _get_obs_landlord_front(infoset):
    }
    return obs

def _get_obs_landlord_down(infoset):
def _get_obs_landlord_down(infoset, use_legacy=False):
    """
    Obtain the landlord_down features. See Table 5 in
    https://arxiv.org/pdf/2106.06135.pdf
@@ -805,7 +809,7 @@ def _get_obs_landlord_down(infoset):
        num_legal_actions, axis=0)
    bomb_num = _get_one_hot_bomb(
        infoset.bomb_num)
        infoset.bomb_num, use_legacy)
    bomb_num_batch = np.repeat(
        bomb_num[np.newaxis, :],
        num_legal_actions, axis=0)

View File

@@ -3,11 +3,14 @@ import numpy as np
from douzero.env.env import get_obs

def _load_model(position, model_path, model_type):
    from douzero.dmc.models import model_dict_new, model_dict, model_dict_new_legacy, model_dict_legacy
def _load_model(position, model_path, model_type, use_legacy):
    from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy
    model = None
    if model_type == "general":
        model = model_dict_new[position]()
    else:
        if use_legacy:
            model = model_dict_legacy[position]()
        else:
            model = model_dict[position]()
    model_state_dict = model.state_dict()
@@ -27,8 +30,9 @@ def _load_model(position, model_path, model_type):
class DeepAgent:

    def __init__(self, position, model_path):
        self.use_legacy = True if "legacy" in model_path else False
        self.model_type = "general" if "resnet" in model_path else "old"
        self.model = _load_model(position, model_path, self.model_type)
        self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
        self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                                 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                                 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
@@ -36,7 +40,7 @@ class DeepAgent:
        if len(infoset.legal_actions) == 1:
            return infoset.legal_actions[0]

        obs = get_obs(infoset, self.model_type == "general")
        obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
        z_batch = torch.from_numpy(obs['z_batch']).float()
        x_batch = torch.from_numpy(obs['x_batch']).float()
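DeepAgent configures itself purely from substrings of the checkpoint path: "resnet" selects the general (ResNet) network, "legacy" selects the restored legacy feature layout. A hedged restatement of that rule as a standalone helper so it is easy to see; the helper name and the example checkpoint paths are illustrative, not from the repository.

# Illustrative restatement of the path-based switches in DeepAgent.__init__ above.
def infer_model_flags(model_path):
    use_legacy = "legacy" in model_path
    model_type = "general" if "resnet" in model_path else "old"
    return model_type, use_legacy

print(infer_model_flags("baselines/legacy/landlord_up.ckpt"))  # ('old', True)
print(infer_model_flags("baselines/resnet/landlord.ckpt"))     # ('general', False)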