修复游戏规则BUG, 移除BID模型
This commit is contained in:
parent
5055a5c84d
commit
94d64889a7
|
@ -219,11 +219,15 @@ def train(flags):
|
|||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||
if flags.enable_onnx and position:
|
||||
model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
|
||||
onnx_params = learner_model.get_model(position).get_onnx_params()
|
||||
onnx_params = learner_model.get_model(position)\
|
||||
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
||||
torch.onnx.export(
|
||||
learner_model.get_model(position),
|
||||
onnx_params['args'],
|
||||
model_path,
|
||||
export_params=True,
|
||||
opset_version=10,
|
||||
do_constant_folding=True,
|
||||
input_names=onnx_params['input_names'],
|
||||
output_names=onnx_params['output_names'],
|
||||
dynamic_axes=onnx_params['dynamic_axes']
|
||||
|
|
|
@ -26,20 +26,23 @@ class LandlordLstmModel(nn.Module):
|
|||
self.dense5 = nn.Linear(512, 256)
|
||||
self.dense6 = nn.Linear(256, 1)
|
||||
|
||||
def get_onnx_params(self, device=None):
    """Describe the dummy inputs and axis names used to export this model to ONNX.

    The returned keys ('args', 'input_names', 'output_names',
    'dynamic_axes') are the ones the training loop forwards to
    ``torch.onnx.export``.

    Args:
        device: Device on which the dummy input tensors are created.
            ``None`` lets torch pick its default device.

    Returns:
        dict: Export description with a (1, 5, 432) ``z_batch`` dummy
        input and a (1, 887) ``x_batch`` dummy input.
    """
    # Axis 0 of every input and of the output is exported as a dynamic axis
    # under the same symbolic name.
    return {
        'args': (
            torch.randn(1, 5, 432, requires_grad=True, device=device),
            torch.randn(1, 887, requires_grad=True, device=device),
        ),
        'input_names': ['z_batch','x_batch'],
        'output_names': ['values'],
        'dynamic_axes': {
            'z_batch': {
                0: "batch_size"
            },
            'x_batch': {
                0: "batch_size"
            },
            'values': {
                0: "batch_size"
            }
        }
    }
|
||||
|
@ -72,11 +75,11 @@ class FarmerLstmModel(nn.Module):
|
|||
self.dense5 = nn.Linear(512, 256)
|
||||
self.dense6 = nn.Linear(256, 1)
|
||||
|
||||
def get_onnx_params(self):
|
||||
def get_onnx_params(self, device):
|
||||
return {
|
||||
'args': (
|
||||
torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32),
|
||||
torch.tensor(np.zeros((1, 1219)), dtype=torch.float32)
|
||||
torch.randn(1, 5, 432, requires_grad=True, device=device),
|
||||
torch.randn(1, 1219, requires_grad=True, device=device),
|
||||
),
|
||||
'input_names': ['z_batch','x_batch'],
|
||||
'output_names': ['values'],
|
||||
|
@ -86,6 +89,9 @@ class FarmerLstmModel(nn.Module):
|
|||
},
|
||||
'x_batch': {
|
||||
0: "legal_actions"
|
||||
},
|
||||
'values': {
|
||||
0: "batch_size"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,62 +113,6 @@ class FarmerLstmModel(nn.Module):
|
|||
x = self.dense6(x)
|
||||
return dict(values=x)
|
||||
|
||||
class LandlordLstmModelLegacy(nn.Module):
    """Legacy landlord value network: an LSTM encoder followed by a six-layer MLP.

    The LSTM consumes sequences with 432 features per step and produces a
    128-wide output; that output at the last step is concatenated with an
    860-wide feature vector before the dense stack.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(860 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def forward(self, z, x):
        """Score each row of the batch.

        Args:
            z: Sequence input of shape (batch, seq_len, 432).
            x: Flat feature input of shape (batch, 860).

        Returns:
            dict: ``{'values': tensor}`` where the tensor has shape (batch, 1).
        """
        # Only the per-step outputs are needed; the final (h, c) state is unused.
        lstm_out, _ = self.lstm(z)
        # batch_first=True, so index 1 is the time axis: keep the last step.
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        hidden_layers = (
            self.dense1,
            self.dense2,
            self.dense3,
            self.dense4,
            self.dense5,
        )
        for dense in hidden_layers:
            x = torch.relu(dense(x))
        # The output layer has no activation.
        x = self.dense6(x)
        return dict(values=x)
|
||||
|
||||
class FarmerLstmModelLegacy(nn.Module):
    """Legacy farmer value network: an LSTM encoder followed by a six-layer MLP.

    The LSTM consumes sequences with 432 features per step and produces a
    128-wide output; that output at the last step is concatenated with a
    1192-wide feature vector before the dense stack.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(1192 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def forward(self, z, x):
        """Score each row of the batch.

        Args:
            z: Sequence input of shape (batch, seq_len, 432).
            x: Flat feature input of shape (batch, 1192).

        Returns:
            dict: ``{'values': tensor}`` where the tensor has shape (batch, 1).
        """
        # Only the per-step outputs are needed; the final (h, c) state is unused.
        lstm_out, _ = self.lstm(z)
        # batch_first=True, so index 1 is the time axis: keep the last step.
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        hidden_layers = (
            self.dense1,
            self.dense2,
            self.dense3,
            self.dense4,
            self.dense5,
        )
        for dense in hidden_layers:
            x = torch.relu(dense(x))
        # The output layer has no activation.
        x = self.dense6(x)
        return dict(values=x)
|
||||
|
||||
# 用于ResNet18和34的残差块,用的是2个3x3的卷积
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
@ -221,20 +171,23 @@ class GeneralModel(nn.Module):
|
|||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_onnx_params(self, device=None):
    """Describe the dummy inputs and axis names used to export this model to ONNX.

    The returned keys ('args', 'input_names', 'output_names',
    'dynamic_axes') are the ones the training loop forwards to
    ``torch.onnx.export``.

    Args:
        device: Device on which the dummy input tensors are created.
            ``None`` lets torch pick its default device, so no CUDA
            device is required just to build the description.

    Returns:
        dict: Export description with a (1, 40, 108) ``z_batch`` dummy
        input and a (1, 56) ``x_batch`` dummy input.
    """
    # Axis 0 of every input and of the output is exported as a dynamic axis
    # under the same symbolic name.
    return {
        'args': (
            torch.randn(1, 40, 108, requires_grad=True, device=device),
            torch.randn(1, 56, requires_grad=True, device=device),
        ),
        'input_names': ['z_batch','x_batch'],
        'output_names': ['values'],
        'dynamic_axes': {
            'z_batch': {
                0: "batch_size"
            },
            'x_batch': {
                0: "batch_size"
            },
            'values': {
                0: "batch_size"
            }
        }
    }
|
||||
|
@ -261,69 +214,11 @@ model_dict['landlord'] = LandlordLstmModel
|
|||
# Remaining seats of the current LSTM table; the 'landlord' entry and the
# dict itself are set up earlier in the module.
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel

# Seat name -> network class for the legacy LSTM architecture.
model_dict_legacy = {
    'landlord': LandlordLstmModelLegacy,
    'landlord_up': FarmerLstmModelLegacy,
    'landlord_front': FarmerLstmModelLegacy,
    'landlord_down': FarmerLstmModelLegacy,
}

# Seat name -> network class; every seat uses GeneralModel.
model_dict_new = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
}

# Same mapping as model_dict_new, kept as a separate dict object.
model_dict_lstm = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
}
|
||||
|
||||
class General_Model:
    """
    Wrapper holding one network per seat ('landlord', 'landlord_up',
    'landlord_front', 'landlord_down'). It also wraps several
    interfaces such as share_memory, eval, etc.
    """

    # Seats handled by the wrapper, in the order the models are created.
    _POSITIONS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')

    def __init__(self, device=0):
        """Create one GeneralModel1 per seat on the requested device.

        Args:
            device: The string "cpu", or a CUDA device index that is
                turned into 'cuda:<index>'.
        """
        self.models = {}
        if device != "cpu":
            device = 'cuda:' + str(device)
        for position in self._POSITIONS:
            self.models[position] = GeneralModel1().to(torch.device(device))

    def forward(self, position, z, x, return_value=False, flags=None, debug=False):
        """Evaluate the network of ``position`` on a batch of candidates.

        Args:
            position: Key of the seat whose model is used.
            z, x: Inputs passed straight to the model's ``forward``.
            return_value: When true, return the raw values instead of an
                action index.
            flags: Optional object with an ``exp_epsilon`` attribute that
                enables epsilon-greedy exploration.
            debug: Accepted but not used by this method.

        Returns:
            dict: ``{'values': tensor}`` if ``return_value`` is true,
            otherwise ``{'action': index}``.
        """
        model = self.models[position]
        values = model.forward(z, x)['values']
        if return_value:
            return dict(values=values)

        # The random draw only happens when exploration is enabled, so the
        # numpy RNG is untouched for flags=None or exp_epsilon <= 0.
        explore = (
            flags is not None
            and flags.exp_epsilon > 0
            and np.random.rand() < flags.exp_epsilon
        )
        if explore:
            action = torch.randint(values.shape[0], (1,))[0]
        else:
            action = torch.argmax(values,dim=0)[0]
        return dict(action=action)

    def share_memory(self):
        """Call ``share_memory()`` on every seat's model."""
        for position in self._POSITIONS:
            self.models[position].share_memory()

    def eval(self):
        """Put every seat's model into evaluation mode."""
        for position in self._POSITIONS:
            self.models[position].eval()

    def parameters(self, position):
        """Return the parameters of the model for ``position``."""
        return self.models[position].parameters()

    def get_model(self, position):
        """Return the model for ``position``."""
        return self.models[position]

    def get_models(self):
        """Return the dict mapping every seat to its model."""
        return self.models
|
||||
|
||||
class OldModel:
|
||||
"""
|
||||
|
@ -334,10 +229,11 @@ class OldModel:
|
|||
self.models = {}
|
||||
if not device == "cpu":
|
||||
device = 'cuda:' + str(device)
|
||||
self.models['landlord'] = LandlordLstmModel().to(torch.device(device))
|
||||
self.models['landlord_up'] = FarmerLstmModel().to(torch.device(device))
|
||||
self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
|
||||
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
|
||||
self.device = torch.device(device)
|
||||
self.models['landlord'] = LandlordLstmModel().to(self.device)
|
||||
self.models['landlord_up'] = FarmerLstmModel().to(self.device)
|
||||
self.models['landlord_front'] = FarmerLstmModel().to(self.device)
|
||||
self.models['landlord_down'] = FarmerLstmModel().to(self.device)
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
|
@ -349,7 +245,7 @@ class OldModel:
|
|||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path))
|
||||
|
||||
def get_onnx_params(self, position):
    """Return the ONNX export description of the model for ``position``.

    Delegates to that model's own ``get_onnx_params``, passing this
    wrapper's device so the dummy inputs are created there.

    Args:
        position: Key into ``self.models``.

    Returns:
        Whatever the underlying model's ``get_onnx_params`` returns.
    """
    # The result must be returned: without `return` callers always get None.
    return self.models[position].get_onnx_params(self.device)
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None):
|
||||
model = self.onnx_models[position]
|
||||
|
@ -403,6 +299,7 @@ class Model:
|
|||
self.flags = flags
|
||||
if not device == "cpu":
|
||||
device = 'cuda:' + str(device)
|
||||
self.device = torch.device(device)
|
||||
# model = GeneralModel().to(torch.device(device))
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
if flags is not None and flags.enable_onnx:
|
||||
|
@ -410,7 +307,7 @@ class Model:
|
|||
self.models[position] = None
|
||||
else:
|
||||
for position in positions:
|
||||
self.models[position] = GeneralModel().to(torch.device(device))
|
||||
self.models[position] = GeneralModel().to(self.device)
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
|
@ -425,7 +322,7 @@ class Model:
|
|||
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
|
||||
def get_onnx_params(self, position):
    """Return the ONNX export description of the model for ``position``.

    Delegates to that model's own ``get_onnx_params``, passing this
    wrapper's device so the dummy inputs are created there.

    Args:
        position: Key into ``self.models``.

    Returns:
        Whatever the underlying model's ``get_onnx_params`` returns.
    """
    # The result must be returned: without `return` callers always get None.
    return self.models[position].get_onnx_params(self.device)
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
|
||||
if self.flags.enable_onnx and len(self.onnx_models) == 0:
|
||||
|
|
Loading…
Reference in New Issue