旧模型压缩模式

This commit is contained in:
zhiyang7 2021-12-23 09:55:49 +08:00
parent 177cb04c03
commit fb22ab2649
3 changed files with 106 additions and 6 deletions

View File

@ -94,7 +94,7 @@ def train(flags):
models = {}
for device in device_iterator:
if flags.old_model:
model = OldModel(device="cpu", flags = flags)
model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
else:
model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
model.share_memory()
@ -109,7 +109,7 @@ def train(flags):
# Learner model for training
if flags.old_model:
learner_model = OldModel(device=flags.training_device)
learner_model = OldModel(device=flags.training_device, lite_model = flags.lite_model)
else:
learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)

View File

@ -113,6 +113,106 @@ class FarmerLstmModel(nn.Module):
x = self.dense6(x)
return dict(values=x)
class LandlordLstmModelLite(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(276, 128, batch_first=True)
self.dense1 = nn.Linear(590 + 128, 768)
self.dense2 = nn.Linear(768, 768)
self.dense3 = nn.Linear(768, 768)
self.dense4 = nn.Linear(768, 768)
self.dense5 = nn.Linear(768, 768)
self.dense6 = nn.Linear(768, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 276, requires_grad=True, device=device),
torch.randn(1, 590, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class FarmerLstmModelLite(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(276, 128, batch_first=True)
self.dense1 = nn.Linear(798 + 128, 768)
self.dense2 = nn.Linear(768, 768)
self.dense3 = nn.Linear(768, 768)
self.dense4 = nn.Linear(768, 768)
self.dense5 = nn.Linear(768, 768)
self.dense6 = nn.Linear(768, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 276, requires_grad=True, device=device),
torch.randn(1, 798, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class LandlordLstmModelLegacy(nn.Module):
def __init__(self):
super().__init__()
@ -396,7 +496,7 @@ class OldModel:
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0, flags=None):
def __init__(self, device=0, flags=None, lite_model = False):
self.models = {}
self.onnx_models = {}
self.flags = flags
@ -409,8 +509,8 @@ class OldModel:
self.models[position] = None
else:
for position in positions[1:]:
self.models[position] = FarmerLstmModel().to(self.device)
self.models['landlord'] = LandlordLstmModel().to(self.device)
self.models[position] = FarmerLstmModelLite().to(self.device) if lite_model else FarmerLstmModel().to(self.device)
self.models['landlord'] = LandlordLstmModelLite().to(self.device) if lite_model else LandlordLstmModel().to(self.device)
self.onnx_models = {
'landlord': None,
'landlord_up': None,

2
douzero/env/env.py vendored
View File

@ -496,7 +496,7 @@ def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
my_action_batch[j, :] = _cards2array(action, compressed_form)
landlord_up_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_up'], 25, 15 if compressed_form else 0)
infoset.num_cards_left_dict['landlord_up'], 25, 8 if compressed_form else 0)
landlord_up_num_cards_left_batch = np.repeat(
landlord_up_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0)