旧模型压缩模式
This commit is contained in:
parent
177cb04c03
commit
fb22ab2649
|
@ -94,7 +94,7 @@ def train(flags):
|
|||
models = {}
|
||||
for device in device_iterator:
|
||||
if flags.old_model:
|
||||
model = OldModel(device="cpu", flags = flags)
|
||||
model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
|
||||
else:
|
||||
model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
|
||||
model.share_memory()
|
||||
|
@ -109,7 +109,7 @@ def train(flags):
|
|||
|
||||
# Learner model for training
|
||||
if flags.old_model:
|
||||
learner_model = OldModel(device=flags.training_device)
|
||||
learner_model = OldModel(device=flags.training_device, lite_model = flags.lite_model)
|
||||
else:
|
||||
learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)
|
||||
|
||||
|
|
|
@ -113,6 +113,106 @@ class FarmerLstmModel(nn.Module):
|
|||
x = self.dense6(x)
|
||||
return dict(values=x)
|
||||
|
||||
|
||||
class LandlordLstmModelLite(nn.Module):
    """Lite landlord value network: one LSTM followed by six linear layers.

    ``forward`` runs ``z`` through the LSTM, keeps the output of the last
    time step, concatenates it with ``x`` and maps the result to one value
    per row.
    """

    def __init__(self):
        super().__init__()
        lstm_width = 128
        hidden_width = 768
        # The attribute names (lstm, dense1..dense6) double as state_dict
        # keys, and the creation order fixes the order in which the RNG is
        # consumed for weight initialisation; both are kept stable.
        self.lstm = nn.LSTM(276, lstm_width, batch_first=True)
        # dense1 consumes the 590 features of ``x`` plus the LSTM output.
        self.dense1 = nn.Linear(590 + lstm_width, hidden_width)
        self.dense2 = nn.Linear(hidden_width, hidden_width)
        self.dense3 = nn.Linear(hidden_width, hidden_width)
        self.dense4 = nn.Linear(hidden_width, hidden_width)
        self.dense5 = nn.Linear(hidden_width, hidden_width)
        self.dense6 = nn.Linear(hidden_width, 1)

    def get_onnx_params(self, device):
        """Build example inputs and naming metadata for this model.

        The keys mirror keyword arguments of ``torch.onnx.export``
        (presumably the intended consumer -- confirm against the caller).

        Args:
            device: Device on which the two example tensors are allocated.

        Returns:
            dict with ``args`` (example ``z`` of shape (1, 5, 276) and
            example ``x`` of shape (1, 590)), ``input_names``,
            ``output_names`` and ``dynamic_axes`` (dim 0 of every tensor is
            declared dynamic under the name ``"batch_size"``).
        """
        example_z = torch.randn(1, 5, 276, requires_grad=True, device=device)
        example_x = torch.randn(1, 590, requires_grad=True, device=device)
        return {
            'args': (example_z, example_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                tensor_name: {0: "batch_size"}
                for tensor_name in ('z_batch', 'x_batch', 'values')
            },
        }

    def forward(self, z, x):
        """Compute one value per row of the batch.

        Args:
            z: Sequence input for the LSTM, batch-first, 276 features per
                step.
            x: Flat input with 590 features per row.

        Returns:
            dict with a single key ``values`` holding the output of the
            final linear layer (last dimension of size 1).
        """
        sequence_out, _ = self.lstm(z)
        # Only the LSTM output at the final time step is used.
        features = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        hidden_layers = (
            self.dense1,
            self.dense2,
            self.dense3,
            self.dense4,
            self.dense5,
        )
        for layer in hidden_layers:
            features = torch.relu(layer(features))
        # No activation after the output layer.
        return dict(values=self.dense6(features))
|
||||
|
||||
class FarmerLstmModelLite(nn.Module):
    """Lite farmer value network: one LSTM followed by six linear layers.

    ``forward`` runs ``z`` through the LSTM, keeps the output of the last
    time step, concatenates it with ``x`` and maps the result to one value
    per row.
    """

    def __init__(self):
        super().__init__()
        # The attribute names (lstm, dense1..dense6) double as state_dict
        # keys, so they must stay stable for checkpoints to load.
        self.lstm = nn.LSTM(276, 128, batch_first=True)
        # dense1 consumes the 798 features of ``x`` plus the 128 LSTM outputs.
        self.dense1 = nn.Linear(798 + 128, 768)
        self.dense2 = nn.Linear(768, 768)
        self.dense3 = nn.Linear(768, 768)
        self.dense4 = nn.Linear(768, 768)
        self.dense5 = nn.Linear(768, 768)
        self.dense6 = nn.Linear(768, 1)

    def get_onnx_params(self, device):
        """Build example inputs and naming metadata for this model.

        The keys mirror keyword arguments of ``torch.onnx.export``
        (presumably the intended consumer -- confirm against the caller).

        Args:
            device: Device on which the two example tensors are allocated.

        Returns:
            dict with ``args`` (example ``z`` of shape (1, 5, 276) and
            example ``x`` of shape (1, 798)), ``input_names``,
            ``output_names`` and ``dynamic_axes``.
        """
        # Dim 0 of z, x and the output is the same axis: forward()
        # concatenates along the last dim only and the dense layers keep
        # dim 0. It therefore gets one symbolic name everywhere, instead of
        # "legal_actions" on the inputs and "batch_size" on the output.
        # "batch_size" matches LandlordLstmModelLite.
        leading_axis_name = "batch_size"
        return {
            'args': (
                torch.randn(1, 5, 276, requires_grad=True, device=device),
                torch.randn(1, 798, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: leading_axis_name},
                'x_batch': {0: leading_axis_name},
                'values': {0: leading_axis_name},
            },
        }

    def forward(self, z, x):
        """Compute one value per row of the batch.

        Args:
            z: Sequence input for the LSTM, batch-first, 276 features per
                step.
            x: Flat input with 798 features per row.

        Returns:
            dict with a single key ``values`` holding the output of the
            final linear layer (last dimension of size 1).
        """
        # The LSTM's final hidden/cell state is not needed here.
        lstm_out, _ = self.lstm(z)
        # Only the LSTM output at the final time step is used.
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            x = torch.relu(layer(x))
        # No activation after the output layer.
        x = self.dense6(x)
        return dict(values=x)
|
||||
|
||||
|
||||
class LandlordLstmModelLegacy(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -396,7 +496,7 @@ class OldModel:
|
|||
The wrapper for the three models. We also wrap several
|
||||
interfaces such as share_memory, eval, etc.
|
||||
"""
|
||||
def __init__(self, device=0, flags=None):
|
||||
def __init__(self, device=0, flags=None, lite_model = False):
|
||||
self.models = {}
|
||||
self.onnx_models = {}
|
||||
self.flags = flags
|
||||
|
@ -409,8 +509,8 @@ class OldModel:
|
|||
self.models[position] = None
|
||||
else:
|
||||
for position in positions[1:]:
|
||||
self.models[position] = FarmerLstmModel().to(self.device)
|
||||
self.models['landlord'] = LandlordLstmModel().to(self.device)
|
||||
self.models[position] = FarmerLstmModelLite().to(self.device) if lite_model else FarmerLstmModel().to(self.device)
|
||||
self.models['landlord'] = LandlordLstmModelLite().to(self.device) if lite_model else LandlordLstmModel().to(self.device)
|
||||
self.onnx_models = {
|
||||
'landlord': None,
|
||||
'landlord_up': None,
|
||||
|
|
|
@ -496,7 +496,7 @@ def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
|
|||
my_action_batch[j, :] = _cards2array(action, compressed_form)
|
||||
|
||||
landlord_up_num_cards_left = _get_one_hot_array(
|
||||
infoset.num_cards_left_dict['landlord_up'], 25, 15 if compressed_form else 0)
|
||||
infoset.num_cards_left_dict['landlord_up'], 25, 8 if compressed_form else 0)
|
||||
landlord_up_num_cards_left_batch = np.repeat(
|
||||
landlord_up_num_cards_left[np.newaxis, :],
|
||||
num_legal_actions, axis=0)
|
||||
|
|
Loading…
Reference in New Issue