import torch
from torch import nn


class _LstmMlpModel(nn.Module):
    """LSTM encoder over a sequence ``z`` followed by a 6-layer MLP.

    The last LSTM timestep output is concatenated with a flat feature
    tensor ``x`` and passed through five ReLU-activated hidden layers and a
    final linear layer that emits one scalar per sample.

    Submodule attribute names (``lstm``, ``dense1`` ... ``dense6``) are part
    of the ``state_dict`` key layout and must not be renamed, or existing
    checkpoints will no longer load.

    Args:
        x_dim: Width of the last dimension of ``x`` passed to ``forward``.
        z_dim: Width of the last dimension of ``z`` (LSTM input size).
        lstm_hidden: LSTM hidden size.
        mlp_hidden: Width of each hidden MLP layer.
    """

    def __init__(self, x_dim, z_dim=162, lstm_hidden=128, mlp_hidden=512):
        super().__init__()
        self.lstm = nn.LSTM(z_dim, lstm_hidden, batch_first=True)
        self.dense1 = nn.Linear(x_dim + lstm_hidden, mlp_hidden)
        self.dense2 = nn.Linear(mlp_hidden, mlp_hidden)
        self.dense3 = nn.Linear(mlp_hidden, mlp_hidden)
        self.dense4 = nn.Linear(mlp_hidden, mlp_hidden)
        self.dense5 = nn.Linear(mlp_hidden, mlp_hidden)
        self.dense6 = nn.Linear(mlp_hidden, 1)

    def forward(self, z, x):
        """Return a ``(batch, 1)`` tensor for the given inputs.

        Args:
            z: Sequence input; since the LSTM is ``batch_first``, this is
                ``(batch, seq_len, z_dim)``.
            x: Flat features of shape ``(batch, x_dim)``; concatenated with
                the LSTM output along the last dimension.
        """
        lstm_out, _ = self.lstm(z)
        # Keep only the output at the final timestep: (batch, lstm_hidden).
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            x = torch.relu(layer(x))
        return self.dense6(x)


class LandlordLstmModel(_LstmMlpModel):
    """LSTM + MLP model whose flat input ``x`` has 373 features."""

    def __init__(self):
        super().__init__(x_dim=373)


class FarmerLstmModel(_LstmMlpModel):
    """LSTM + MLP model whose flat input ``x`` has 484 features."""

    def __init__(self):
        super().__init__(x_dim=484)


# Maps a position name to the model class used for it; both non-landlord
# positions share the same architecture.
model_dict = {}
model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel