取回legacy_mode参数代码
This commit is contained in:
parent
08cf434500
commit
dba179db0e
|
@ -113,6 +113,104 @@ class FarmerLstmModel(nn.Module):
|
||||||
x = self.dense6(x)
|
x = self.dense6(x)
|
||||||
return dict(values=x)
|
return dict(values=x)
|
||||||
|
|
||||||
|
class LandlordLstmModelLegacy(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.lstm = nn.LSTM(432, 128, batch_first=True)
|
||||||
|
self.dense1 = nn.Linear(860 + 128, 1024)
|
||||||
|
self.dense2 = nn.Linear(1024, 1024)
|
||||||
|
self.dense3 = nn.Linear(1024, 768)
|
||||||
|
self.dense4 = nn.Linear(768, 512)
|
||||||
|
self.dense5 = nn.Linear(512, 256)
|
||||||
|
self.dense6 = nn.Linear(256, 1)
|
||||||
|
|
||||||
|
def get_onnx_params(self, device):
|
||||||
|
return {
|
||||||
|
'args': (
|
||||||
|
torch.randn(1, 5, 432, requires_grad=True, device=device),
|
||||||
|
torch.randn(1, 887, requires_grad=True, device=device),
|
||||||
|
),
|
||||||
|
'input_names': ['z_batch','x_batch'],
|
||||||
|
'output_names': ['values'],
|
||||||
|
'dynamic_axes': {
|
||||||
|
'z_batch': {
|
||||||
|
0: "batch_size"
|
||||||
|
},
|
||||||
|
'x_batch': {
|
||||||
|
0: "batch_size"
|
||||||
|
},
|
||||||
|
'values': {
|
||||||
|
0: "batch_size"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, z, x):
|
||||||
|
lstm_out, (h_n, _) = self.lstm(z)
|
||||||
|
lstm_out = lstm_out[:,-1,:]
|
||||||
|
x = torch.cat([lstm_out,x], dim=-1)
|
||||||
|
x = self.dense1(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense2(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense3(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense4(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense5(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense6(x)
|
||||||
|
return dict(values=x)
|
||||||
|
|
||||||
|
class FarmerLstmModelLegacy(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.lstm = nn.LSTM(432, 128, batch_first=True)
|
||||||
|
self.dense1 = nn.Linear(1192 + 128, 1024)
|
||||||
|
self.dense2 = nn.Linear(1024, 1024)
|
||||||
|
self.dense3 = nn.Linear(1024, 768)
|
||||||
|
self.dense4 = nn.Linear(768, 512)
|
||||||
|
self.dense5 = nn.Linear(512, 256)
|
||||||
|
self.dense6 = nn.Linear(256, 1)
|
||||||
|
|
||||||
|
def get_onnx_params(self, device):
|
||||||
|
return {
|
||||||
|
'args': (
|
||||||
|
torch.randn(1, 5, 432, requires_grad=True, device=device),
|
||||||
|
torch.randn(1, 1219, requires_grad=True, device=device),
|
||||||
|
),
|
||||||
|
'input_names': ['z_batch','x_batch'],
|
||||||
|
'output_names': ['values'],
|
||||||
|
'dynamic_axes': {
|
||||||
|
'z_batch': {
|
||||||
|
0: "legal_actions"
|
||||||
|
},
|
||||||
|
'x_batch': {
|
||||||
|
0: "legal_actions"
|
||||||
|
},
|
||||||
|
'values': {
|
||||||
|
0: "batch_size"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, z, x):
|
||||||
|
lstm_out, (h_n, _) = self.lstm(z)
|
||||||
|
lstm_out = lstm_out[:,-1,:]
|
||||||
|
x = torch.cat([lstm_out,x], dim=-1)
|
||||||
|
x = self.dense1(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense2(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense3(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense4(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense5(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.dense6(x)
|
||||||
|
return dict(values=x)
|
||||||
|
|
||||||
# 用于ResNet18和34的残差块,用的是2个3x3的卷积
|
# 用于ResNet18和34的残差块,用的是2个3x3的卷积
|
||||||
class BasicBlock(nn.Module):
|
class BasicBlock(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
@ -214,6 +312,11 @@ model_dict['landlord'] = LandlordLstmModel
|
||||||
model_dict['landlord_up'] = FarmerLstmModel
|
model_dict['landlord_up'] = FarmerLstmModel
|
||||||
model_dict['landlord_front'] = FarmerLstmModel
|
model_dict['landlord_front'] = FarmerLstmModel
|
||||||
model_dict['landlord_down'] = FarmerLstmModel
|
model_dict['landlord_down'] = FarmerLstmModel
|
||||||
|
model_dict_legacy = {}
|
||||||
|
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
|
||||||
|
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
|
||||||
|
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
|
||||||
|
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
|
||||||
model_dict_new = {}
|
model_dict_new = {}
|
||||||
model_dict_new['landlord'] = GeneralModel
|
model_dict_new['landlord'] = GeneralModel
|
||||||
model_dict_new['landlord_up'] = GeneralModel
|
model_dict_new['landlord_up'] = GeneralModel
|
||||||
|
|
|
@ -41,7 +41,7 @@ log.setLevel(logging.INFO)
|
||||||
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
||||||
|
|
||||||
def create_env(flags):
|
def create_env(flags):
|
||||||
return Env(flags.objective, flags.old_model)
|
return Env(flags.objective, flags.old_model, flags.lagacy_model)
|
||||||
|
|
||||||
def get_batch(b_queues, position, flags, lock):
|
def get_batch(b_queues, position, flags, lock):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -33,7 +33,7 @@ class Env:
|
||||||
Doudizhu multi-agent wrapper
|
Doudizhu multi-agent wrapper
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, objective, old_model):
|
def __init__(self, objective, old_model, legacy_model=False):
|
||||||
"""
|
"""
|
||||||
Objective is wp/adp/logadp. It indicates whether considers
|
Objective is wp/adp/logadp. It indicates whether considers
|
||||||
bomb in reward calculation. Here, we use dummy agents.
|
bomb in reward calculation. Here, we use dummy agents.
|
||||||
|
@ -108,7 +108,7 @@ class Env:
|
||||||
self._env.info_sets[pos].player_id = pid
|
self._env.info_sets[pos].player_id = pid
|
||||||
self.infoset = self._game_infoset
|
self.infoset = self._game_infoset
|
||||||
|
|
||||||
return get_obs(self.infoset, self.use_general)
|
return get_obs(self.infoset, self.use_general, self.use_legacy)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""
|
"""
|
||||||
|
@ -234,7 +234,7 @@ class DummyAgent(object):
|
||||||
self.action = action
|
self.action = action
|
||||||
|
|
||||||
|
|
||||||
def get_obs(infoset, use_general=True):
|
def get_obs(infoset, use_general=True, use_legacy = False):
|
||||||
"""
|
"""
|
||||||
This function obtains observations with imperfect information
|
This function obtains observations with imperfect information
|
||||||
from the infoset. It has three branches since we encode
|
from the infoset. It has three branches since we encode
|
||||||
|
@ -264,13 +264,13 @@ def get_obs(infoset, use_general=True):
|
||||||
return _get_obs_general(infoset, infoset.player_position)
|
return _get_obs_general(infoset, infoset.player_position)
|
||||||
else:
|
else:
|
||||||
if infoset.player_position == 'landlord':
|
if infoset.player_position == 'landlord':
|
||||||
return _get_obs_landlord(infoset)
|
return _get_obs_landlord(infoset, use_legacy)
|
||||||
elif infoset.player_position == 'landlord_up':
|
elif infoset.player_position == 'landlord_up':
|
||||||
return _get_obs_landlord_up(infoset)
|
return _get_obs_landlord_up(infoset, use_legacy)
|
||||||
elif infoset.player_position == 'landlord_front':
|
elif infoset.player_position == 'landlord_front':
|
||||||
return _get_obs_landlord_front(infoset)
|
return _get_obs_landlord_front(infoset, use_legacy)
|
||||||
elif infoset.player_position == 'landlord_down':
|
elif infoset.player_position == 'landlord_down':
|
||||||
return _get_obs_landlord_down(infoset)
|
return _get_obs_landlord_down(infoset, use_legacy)
|
||||||
else:
|
else:
|
||||||
raise ValueError('')
|
raise ValueError('')
|
||||||
|
|
||||||
|
@ -377,19 +377,23 @@ def _process_action_seq(sequence, length=20, new_model=True):
|
||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
|
|
||||||
def _get_one_hot_bomb(bomb_num):
|
def _get_one_hot_bomb(bomb_num, use_legacy = False):
|
||||||
"""
|
"""
|
||||||
A utility function to encode the number of bombs
|
A utility function to encode the number of bombs
|
||||||
into one-hot representation.
|
into one-hot representation.
|
||||||
"""
|
"""
|
||||||
one_hot = np.zeros(56) # 14 + 15 + 27
|
if use_legacy:
|
||||||
one_hot[bomb_num[0]] = 1
|
one_hot = np.zeros(29)
|
||||||
one_hot[14 + bomb_num[1]] = 1
|
one_hot[bomb_num[0] + bomb_num[1]] = 1
|
||||||
one_hot[29 + bomb_num[2]] = 1
|
else:
|
||||||
|
one_hot = np.zeros(56) # 14 + 15 + 27
|
||||||
|
one_hot[bomb_num[0]] = 1
|
||||||
|
one_hot[14 + bomb_num[1]] = 1
|
||||||
|
one_hot[29 + bomb_num[2]] = 1
|
||||||
return one_hot
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def _get_obs_landlord(infoset):
|
def _get_obs_landlord(infoset, use_legacy = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord features. See Table 4 in
|
Obttain the landlord features. See Table 4 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
|
@ -448,7 +452,7 @@ def _get_obs_landlord(infoset):
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
bomb_num = _get_one_hot_bomb(
|
bomb_num = _get_one_hot_bomb(
|
||||||
infoset.bomb_num)
|
infoset.bomb_num, use_legacy)
|
||||||
bomb_num_batch = np.repeat(
|
bomb_num_batch = np.repeat(
|
||||||
bomb_num[np.newaxis, :],
|
bomb_num[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -489,7 +493,7 @@ def _get_obs_landlord(infoset):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_landlord_up(infoset):
|
def _get_obs_landlord_up(infoset, use_legacy = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord_up features. See Table 5 in
|
Obttain the landlord_up features. See Table 5 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
|
@ -563,7 +567,7 @@ def _get_obs_landlord_up(infoset):
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
bomb_num = _get_one_hot_bomb(
|
bomb_num = _get_one_hot_bomb(
|
||||||
infoset.bomb_num)
|
infoset.bomb_num, use_legacy)
|
||||||
bomb_num_batch = np.repeat(
|
bomb_num_batch = np.repeat(
|
||||||
bomb_num[np.newaxis, :],
|
bomb_num[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -610,7 +614,7 @@ def _get_obs_landlord_up(infoset):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_landlord_front(infoset):
|
def _get_obs_landlord_front(infoset, use_legacy = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord_front features. See Table 5 in
|
Obttain the landlord_front features. See Table 5 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
|
@ -684,7 +688,7 @@ def _get_obs_landlord_front(infoset):
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
bomb_num = _get_one_hot_bomb(
|
bomb_num = _get_one_hot_bomb(
|
||||||
infoset.bomb_num)
|
infoset.bomb_num, use_legacy)
|
||||||
bomb_num_batch = np.repeat(
|
bomb_num_batch = np.repeat(
|
||||||
bomb_num[np.newaxis, :],
|
bomb_num[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
@ -731,7 +735,7 @@ def _get_obs_landlord_front(infoset):
|
||||||
}
|
}
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _get_obs_landlord_down(infoset):
|
def _get_obs_landlord_down(infoset, use_legacy = False):
|
||||||
"""
|
"""
|
||||||
Obttain the landlord_down features. See Table 5 in
|
Obttain the landlord_down features. See Table 5 in
|
||||||
https://arxiv.org/pdf/2106.06135.pdf
|
https://arxiv.org/pdf/2106.06135.pdf
|
||||||
|
@ -805,7 +809,7 @@ def _get_obs_landlord_down(infoset):
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
||||||
bomb_num = _get_one_hot_bomb(
|
bomb_num = _get_one_hot_bomb(
|
||||||
infoset.bomb_num)
|
infoset.bomb_num, use_legacy)
|
||||||
bomb_num_batch = np.repeat(
|
bomb_num_batch = np.repeat(
|
||||||
bomb_num[np.newaxis, :],
|
bomb_num[np.newaxis, :],
|
||||||
num_legal_actions, axis=0)
|
num_legal_actions, axis=0)
|
||||||
|
|
|
@ -3,13 +3,16 @@ import numpy as np
|
||||||
|
|
||||||
from douzero.env.env import get_obs
|
from douzero.env.env import get_obs
|
||||||
|
|
||||||
def _load_model(position, model_path, model_type):
|
def _load_model(position, model_path, model_type, use_legacy):
|
||||||
from douzero.dmc.models import model_dict_new, model_dict, model_dict_new_legacy, model_dict_legacy
|
from douzero.dmc.models import model_dict_new, model_dict, model_dict_legacy
|
||||||
model = None
|
model = None
|
||||||
if model_type == "general":
|
if model_type == "general":
|
||||||
model = model_dict_new[position]()
|
model = model_dict_new[position]()
|
||||||
else:
|
else:
|
||||||
model = model_dict[position]()
|
if use_legacy:
|
||||||
|
model = model_dict_legacy[position]()
|
||||||
|
else:
|
||||||
|
model = model_dict[position]()
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
pretrained = torch.load(model_path, map_location='cuda:0')
|
pretrained = torch.load(model_path, map_location='cuda:0')
|
||||||
|
@ -27,8 +30,9 @@ def _load_model(position, model_path, model_type):
|
||||||
class DeepAgent:
|
class DeepAgent:
|
||||||
|
|
||||||
def __init__(self, position, model_path):
|
def __init__(self, position, model_path):
|
||||||
|
self.use_legacy = True if "legacy" in model_path else False
|
||||||
self.model_type = "general" if "resnet" in model_path else "old"
|
self.model_type = "general" if "resnet" in model_path else "old"
|
||||||
self.model = _load_model(position, model_path, self.model_type)
|
self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
|
||||||
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||||
|
@ -36,7 +40,7 @@ class DeepAgent:
|
||||||
if len(infoset.legal_actions) == 1:
|
if len(infoset.legal_actions) == 1:
|
||||||
return infoset.legal_actions[0]
|
return infoset.legal_actions[0]
|
||||||
|
|
||||||
obs = get_obs(infoset, self.model_type == "general")
|
obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
|
||||||
|
|
||||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||||
|
|
Loading…
Reference in New Issue