# Source listing: 68 lines, 2.0 KiB, Python.
import torch
|
|
from torch import nn
|
|
|
|
class LandlordLstmModel(nn.Module):
    """Network used for the 'landlord' position.

    An LSTM runs over the sequence input ``z``. Its output at the final
    time step is concatenated with the flat feature tensor ``x`` and passed
    through five 512-unit ReLU layers. A final linear layer then produces
    one scalar per batch row.
    """

    def __init__(self):
        super().__init__()
        # input_size=162, hidden_size=128; batch_first=True means the
        # sequence input is laid out as (batch, seq_len, 162).
        self.lstm = nn.LSTM(162, 128, batch_first=True)
        # dense1 consumes the 128-dim LSTM summary concatenated with the
        # 373 features of x.
        self.dense1 = nn.Linear(373 + 128, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        """Compute the network output for a batch.

        Args:
            z: Sequence tensor of shape ``(batch, seq_len, 162)``.
            x: Feature tensor of shape ``(batch, 373)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        lstm_out, _ = self.lstm(z)
        # Keep only the output at the last time step -> (batch, 128).
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            x = torch.relu(layer(x))
        # No activation on the output layer.
        return self.dense6(x)
|
|
|
|
class FarmerLstmModel(nn.Module):
    """Network used for the 'landlord_up' and 'landlord_down' positions.

    It has the same architecture as the landlord network except that ``x``
    carries 484 features. An LSTM runs over the sequence input ``z``. Its
    final-time-step output is concatenated with ``x`` and passed through
    five 512-unit ReLU layers. A final linear layer then produces one
    scalar per batch row.
    """

    def __init__(self):
        super().__init__()
        # input_size=162, hidden_size=128; batch_first=True means the
        # sequence input is laid out as (batch, seq_len, 162).
        self.lstm = nn.LSTM(162, 128, batch_first=True)
        # dense1 consumes the 128-dim LSTM summary concatenated with the
        # 484 features of x.
        self.dense1 = nn.Linear(484 + 128, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        """Compute the network output for a batch.

        Args:
            z: Sequence tensor of shape ``(batch, seq_len, 162)``.
            x: Feature tensor of shape ``(batch, 484)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        lstm_out, _ = self.lstm(z)
        # Keep only the output at the last time step -> (batch, 128).
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            x = torch.relu(layer(x))
        # No activation on the output layer.
        return self.dense6(x)
|
|
|
|
# Position key -> model class used for that position. 'landlord_up' and
# 'landlord_down' share the FarmerLstmModel architecture.
model_dict = {
    'landlord': LandlordLstmModel,
    'landlord_up': FarmerLstmModel,
    'landlord_down': FarmerLstmModel,
}
|