2021-09-07 16:38:34 +08:00
|
|
|
|
"""
|
|
|
|
|
This file includes the torch models. We wrap the three
|
|
|
|
|
models into one class for convenience.
|
|
|
|
|
"""
|
2021-12-15 22:09:18 +08:00
|
|
|
|
import os
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import torch
|
2021-12-14 22:55:03 +08:00
|
|
|
|
import onnxruntime
|
|
|
|
|
from onnxruntime.datasets import get_example
|
2021-09-07 16:38:34 +08:00
|
|
|
|
from torch import nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
class LandlordLstmModel(nn.Module):
    """
    LSTM + MLP network registered for the 'landlord' position.

    The last LSTM output step (128 features) is concatenated with the
    887-feature ``x`` input and reduced to one value per row.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 432 features per step, 128 hidden units.
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        # MLP head; dense1 consumes the LSTM summary plus the flat features.
        self.dense1 = nn.Linear(887 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self, device):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 5, 432, requires_grad=True, device=device)
        sample_x = torch.randn(1, 887, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # Only the leading axis of every tensor is dynamic.
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """
        Return ``dict(values=...)`` with one value per row.

        ``z`` is a batch-first sequence with 432 features per step and
        ``x`` has 887 features per row.
        """
        sequence_out, _ = self.lstm(z)
        # Keep only the final time step of the LSTM output.
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return {'values': self.dense6(hidden)}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
|
|
|
|
class FarmerLstmModel(nn.Module):
    """
    LSTM + MLP network registered for the three non-landlord positions.

    The last LSTM output step (128 features) is concatenated with the
    1219-feature ``x`` input and reduced to one value per row.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 432 features per step, 128 hidden units.
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        # MLP head; dense1 consumes the LSTM summary plus the flat features.
        self.dense1 = nn.Linear(1219 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self, device):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 5, 432, requires_grad=True, device=device)
        sample_x = torch.randn(1, 1219, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # NOTE(review): inputs label axis 0 "legal_actions" while the
            # output labels it "batch_size"; kept exactly as found.
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """
        Return ``dict(values=...)`` with one value per row.

        ``z`` is a batch-first sequence with 432 features per step and
        ``x`` has 1219 features per row.
        """
        sequence_out, _ = self.lstm(z)
        # Keep only the final time step of the LSTM output.
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return {'values': self.dense6(hidden)}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
2021-12-23 09:55:49 +08:00
|
|
|
|
|
|
|
|
|
class LandlordLstmModelLite(nn.Module):
    """
    Smaller LSTM + MLP network registered for the 'landlord' position.

    The last LSTM output step (128 features) is concatenated with the
    590-feature ``x`` input and reduced to one value per row.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 276 features per step, 128 hidden units.
        self.lstm = nn.LSTM(276, 128, batch_first=True)
        # MLP head with a constant hidden width of 768.
        self.dense1 = nn.Linear(590 + 128, 768)
        self.dense2 = nn.Linear(768, 768)
        self.dense3 = nn.Linear(768, 768)
        self.dense4 = nn.Linear(768, 768)
        self.dense5 = nn.Linear(768, 768)
        self.dense6 = nn.Linear(768, 1)

    def get_onnx_params(self, device):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 5, 276, requires_grad=True, device=device)
        sample_x = torch.randn(1, 590, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # Only the leading axis of every tensor is dynamic.
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """
        Return ``dict(values=...)`` with one value per row.

        ``z`` is a batch-first sequence with 276 features per step and
        ``x`` has 590 features per row.
        """
        sequence_out, _ = self.lstm(z)
        # Keep only the final time step of the LSTM output.
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return {'values': self.dense6(hidden)}
|
|
|
|
|
|
|
|
|
|
class FarmerLstmModelLite(nn.Module):
    """
    Smaller LSTM + MLP network registered for the non-landlord positions.

    The last LSTM output step (128 features) is concatenated with the
    798-feature ``x`` input and reduced to one value per row.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 276 features per step, 128 hidden units.
        self.lstm = nn.LSTM(276, 128, batch_first=True)
        # MLP head with a constant hidden width of 768.
        self.dense1 = nn.Linear(798 + 128, 768)
        self.dense2 = nn.Linear(768, 768)
        self.dense3 = nn.Linear(768, 768)
        self.dense4 = nn.Linear(768, 768)
        self.dense5 = nn.Linear(768, 768)
        self.dense6 = nn.Linear(768, 1)

    def get_onnx_params(self, device):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 5, 276, requires_grad=True, device=device)
        sample_x = torch.randn(1, 798, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # NOTE(review): inputs label axis 0 "legal_actions" while the
            # output labels it "batch_size"; kept exactly as found.
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """
        Return ``dict(values=...)`` with one value per row.

        ``z`` is a batch-first sequence with 276 features per step and
        ``x`` has 798 features per row.
        """
        sequence_out, _ = self.lstm(z)
        # Keep only the final time step of the LSTM output.
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return {'values': self.dense6(hidden)}
|
|
|
|
|
|
|
|
|
|
|
2021-12-20 10:02:55 +08:00
|
|
|
|
class LandlordLstmModelLegacy(nn.Module):
    """
    Legacy LSTM + MLP network registered for the 'landlord' position.

    The last LSTM output step (128 features) is concatenated with the
    860-feature ``x`` input and reduced to one value per row.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 432 features per step, 128 hidden units.
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        # MLP head; dense1 consumes the LSTM summary plus the flat features.
        self.dense1 = nn.Linear(860 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self, device):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 5, 432, requires_grad=True, device=device)
        sample_x = torch.randn(1, 860, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # Only the leading axis of every tensor is dynamic.
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """
        Return ``dict(values=...)`` with one value per row.

        ``z`` is a batch-first sequence with 432 features per step and
        ``x`` has 860 features per row.
        """
        sequence_out, _ = self.lstm(z)
        # Keep only the final time step of the LSTM output.
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return {'values': self.dense6(hidden)}
|
|
|
|
|
|
|
|
|
|
class FarmerLstmModelLegacy(nn.Module):
    """
    Legacy LSTM + MLP network registered for the non-landlord positions.

    The last LSTM output step (128 features) is concatenated with the
    1192-feature ``x`` input and reduced to one value per row.
    """

    def __init__(self):
        super().__init__()
        # Sequence encoder: 432 features per step, 128 hidden units.
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        # MLP head; dense1 consumes the LSTM summary plus the flat features.
        self.dense1 = nn.Linear(1192 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self, device):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 5, 432, requires_grad=True, device=device)
        sample_x = torch.randn(1, 1192, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # NOTE(review): inputs label axis 0 "legal_actions" while the
            # output labels it "batch_size"; kept exactly as found.
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """
        Return ``dict(values=...)`` with one value per row.

        ``z`` is a batch-first sequence with 432 features per step and
        ``x`` has 1192 features per row.
        """
        sequence_out, _ = self.lstm(z)
        # Keep only the final time step of the LSTM output.
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return {'values': self.dense6(hidden)}
|
|
|
|
|
|
2021-09-07 16:38:34 +08:00
|
|
|
|
# Residual block of the kind used in ResNet18/34: two kernel-size-3 convolutions.
class BasicBlock(nn.Module):
    """
    1-D residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv1d(
            in_planes, planes,
            kernel_size=(3,), stride=(stride,), padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes, planes,
            kernel_size=(3,), stride=(1,), padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)

        # The main-path result must match the input in both length and
        # channel count before the two can be added. When it does not,
        # a 1x1 convolution + BN projects the input to the same shape;
        # otherwise the shortcut is an empty (identity) Sequential.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(
                    in_planes, out_planes,
                    kernel_size=(1,), stride=(stride,), bias=False,
                ),
                nn.BatchNorm1d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual block to ``x`` (batch, channels, length)."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.conv2(main)
        main = self.bn2(main)
        main += self.shortcut(x)
        return F.relu(main)
|
|
|
|
|
|
2021-12-22 21:19:10 +08:00
|
|
|
|
|
|
|
|
|
class GeneralModelLite(nn.Module):
    """
    Smaller 1-D ResNet-style network shared by every position.

    ``z`` (40 channels) goes through a strided convolution and three
    residual stages; the flattened result is concatenated with the
    56-feature ``x`` and reduced to one value per row by four linear
    layers.
    """

    def __init__(self):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 80
        # With a length-69 input (the shape used in get_onnx_params),
        # every stride-2 stage roughly halves the length: 69 -> 35 here.
        self.conv1 = nn.Conv1d(
            40, 80,
            kernel_size=(3,), stride=(2,), padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(80)

        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # length 18
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # length 9
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # length 5

        # 320 channels * length 5, plus the 56 features of x.
        self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
        self.linear2 = nn.Linear(1536, 768)
        self.linear3 = nn.Linear(768, 384)
        self.linear4 = nn.Linear(384, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` so the next block or stage sees the
        right input channel count.
        """
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def get_onnx_params(self, device=None):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 40, 69, requires_grad=True, device=device)
        sample_x = torch.randn(1, 56, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # Only the leading axis of every tensor is dynamic.
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """Return ``dict(values=...)`` with one value per row."""
        features = F.relu(self.bn1(self.conv1(z)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        # Merge the channel and length axes, then prepend x.
        features = features.flatten(1, 2)
        features = torch.cat([x, features], dim=-1)
        # Leaky ReLU follows every linear layer, including the last one.
        for linear in (self.linear1, self.linear2, self.linear3, self.linear4):
            features = F.leaky_relu_(linear(features))
        return {'values': features}
|
|
|
|
|
|
|
|
|
|
|
2021-12-12 14:01:40 +08:00
|
|
|
|
class GeneralModel(nn.Module):
    """
    1-D ResNet-style network shared by every position.

    ``z`` (40 channels) goes through a strided convolution and four
    residual stages; the flattened result is concatenated with the
    56-feature ``x`` and reduced to one value per row by five linear
    layers.
    """

    def __init__(self):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 80
        # With a length-108 input (the shape used in get_onnx_params),
        # every stride-2 stage roughly halves the length: 108 -> 54 here.
        self.conv1 = nn.Conv1d(
            40, 80,
            kernel_size=(3,), stride=(2,), padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(80)

        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # length 27
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # length 14
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # length 7
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # length 4

        # 640 channels * length 4, plus the 56 features of x.
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 56, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` so the next block or stage sees the
        right input channel count.
        """
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def get_onnx_params(self, device=None):
        """
        Build sample inputs and naming metadata for an ONNX export.

        ``device`` is where the random sample tensors are allocated.
        Presumably the dict is forwarded to ``torch.onnx.export`` --
        confirm against the caller.
        """
        sample_z = torch.randn(1, 40, 108, requires_grad=True, device=device)
        sample_x = torch.randn(1, 56, requires_grad=True, device=device)
        return {
            'args': (sample_z, sample_x),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            # Only the leading axis of every tensor is dynamic.
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """Return ``dict(values=...)`` with one value per row."""
        features = F.relu(self.bn1(self.conv1(z)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Merge the channel and length axes, then prepend x.
        features = features.flatten(1, 2)
        features = torch.cat([x, features], dim=-1)
        # Leaky ReLU follows every linear layer, including the last one.
        for linear in (self.linear1, self.linear2, self.linear3,
                       self.linear4, self.linear5):
            features = F.leaky_relu_(linear(features))
        return {'values': features}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model dict is only used in evaluation but not training.
# Each table maps a position name to the network class used for it.
model_dict = {
    'landlord': LandlordLstmModel,
    'landlord_up': FarmerLstmModel,
    'landlord_front': FarmerLstmModel,
    'landlord_down': FarmerLstmModel,
}

# Smaller LSTM variants.
model_dict_lite = {
    'landlord': LandlordLstmModelLite,
    'landlord_up': FarmerLstmModelLite,
    'landlord_front': FarmerLstmModelLite,
    'landlord_down': FarmerLstmModelLite,
}

# Legacy LSTM variants.
model_dict_legacy = {
    'landlord': LandlordLstmModelLegacy,
    'landlord_up': FarmerLstmModelLegacy,
    'landlord_front': FarmerLstmModelLegacy,
    'landlord_down': FarmerLstmModelLegacy,
}

# Convolutional model: one class for every position.
model_dict_new = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
}

# Smaller convolutional model: one class for every position.
model_dict_new_lite = {
    'landlord': GeneralModelLite,
    'landlord_up': GeneralModelLite,
    'landlord_front': GeneralModelLite,
    'landlord_down': GeneralModelLite,
}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
2022-01-04 11:12:36 +08:00
|
|
|
|
def forward_logic(self_model, position, z, x, device='cpu', return_value=False, flags=None):
    """
    Score every row of ``z``/``x`` with the model for ``position``.

    Args:
        self_model: wrapper object exposing ``models`` and ``onnx_models``
            dicts keyed by position.
        position: key selecting the model (or ONNX session) to run.
        z, x: batched inputs with the same leading length; ``len(z)`` is
            the number of rows scored (presumably one per legal action --
            confirm against the caller). On the ONNX path they are fed to
            the session unchanged; otherwise they are converted with
            ``torch.tensor``.
        device: device for the converted tensors (torch path only).
        return_value: when True, return the values instead of an action.
        flags: optional settings object. ``flags.enable_onnx`` selects the
            ONNX session and ``flags.exp_epsilon`` is the probability of
            picking a random row. ``None`` means the torch path with no
            random exploration.

    Returns:
        ``dict(values=ndarray)`` when ``return_value`` is True, otherwise
        ``dict(action=index)`` with the chosen row index.
    """
    # flags defaults to None, so it must be checked before it is read;
    # a missing flags object means "ONNX disabled".
    use_onnx = flags is not None and flags.enable_onnx
    legal_count = len(z)

    if use_onnx:
        session = self_model.onnx_models[position]
    else:
        model = self_model.models[position]
        z = torch.tensor(z, device=device)
        x = torch.tensor(x, device=device)

    if legal_count >= 80:
        # Large batches are evaluated in roughly 40-row partitions and the
        # per-partition results are gathered into one numpy array.
        partition_count = int(legal_count / 40)
        if use_onnx:
            sub_z = np.array_split(z, partition_count)
            sub_x = np.array_split(x, partition_count)
        else:
            # Same partitioning rule as np.array_split, without relying on
            # numpy duck-typing torch tensors.
            sub_z = torch.tensor_split(z, partition_count)
            sub_x = torch.tensor_split(x, partition_count)

        values = np.empty((legal_count, 1))
        offset = 0
        for part_z, part_x in zip(sub_z, sub_x):
            if use_onnx:
                part_values = session.run(None, {'z_batch': part_z, 'x_batch': part_x})[0]
            else:
                part_values = model.forward(part_z, part_x)['values'].cpu().detach().numpy()
            values[offset:offset + len(part_z)] = part_values
            offset += len(part_z)
    elif use_onnx:
        values = session.run(None, {'z_batch': z, 'x_batch': x})[0]
    else:
        values = model.forward(z, x)['values']

    if return_value:
        return dict(values=values.cpu().detach().numpy() if torch.is_tensor(values) else values)

    # Epsilon-greedy: with probability exp_epsilon pick a random row.
    explore = flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon
    if explore:
        if torch.is_tensor(values):
            action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
        else:
            action = np.random.randint(0, values.shape[0], (1,))[0]
    else:
        if torch.is_tensor(values):
            action = torch.argmax(values, dim=0)[0].cpu().detach().numpy()
        else:
            action = np.argmax(values, axis=0)[0]
    return dict(action=action)
|
|
|
|
|
|
2021-09-07 16:38:34 +08:00
|
|
|
|
class OldModel:
    """
    The wrapper for the four per-position LSTM models (landlord,
    landlord_up, landlord_front, landlord_down). We also wrap several
    interfaces such as share_memory, eval, etc.
    """

    def __init__(self, device=0, flags=None, lite_model = False):
        """
        Create one model per position, or placeholders when ONNX is on.

        ``device`` is either the string "cpu" or a CUDA device index.
        When ``flags.enable_onnx`` is set no torch model is built and
        every entry of ``self.models`` is None. ``lite_model`` selects the
        Lite network classes.
        """
        self.models = {}
        self.onnx_models = {}
        self.flags = flags
        if not device == "cpu":
            device = 'cuda:' + str(device)
        self.device = torch.device(device)
        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
        if flags is not None and flags.enable_onnx:
            for position in positions:
                self.models[position] = None
        else:
            # The three non-landlord positions are created first, then the
            # landlord; this fixes the key order of self.models.
            for position in positions[1:]:
                self.models[position] = FarmerLstmModelLite().to(self.device) if lite_model else FarmerLstmModel().to(self.device)
            self.models['landlord'] = LandlordLstmModelLite().to(self.device) if lite_model else LandlordLstmModel().to(self.device)
        # ONNX sessions are filled in later by set_onnx_model.
        self.onnx_models = {
            'landlord': None,
            'landlord_up': None,
            'landlord_front': None,
            'landlord_down': None
        }

    def set_onnx_model(self, device='cpu'):
        """
        Load one ONNX inference session per position.

        Files are read from ``<flags.onnx_model_path>/<flags.xpid>/
        model_<position>.onnx``. The CPU execution provider is used when
        ``device`` is 'cpu', the CUDA provider otherwise.
        """
        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
        for position in positions:
            model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
            self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=providers)

    def get_onnx_params(self, position):
        """Return the ONNX export parameters of the model for ``position``."""
        # The result used to be computed and then dropped, so callers
        # always received None.
        return self.models[position].get_onnx_params(self.device)

    def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
        """
        Evaluate ``z``/``x`` for ``position``; see ``forward_logic``.

        The ``flags`` argument is passed through as given; ``self.flags``
        is not used here.
        """
        return forward_logic(self, position, z, x, device, return_value, flags)

    def share_memory(self):
        """Call ``share_memory()`` on every model (no-op in ONNX mode)."""
        # The landlord entry is None exactly when no torch model was built.
        if self.models['landlord'] is not None:
            for position in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down'):
                self.models[position].share_memory()

    def eval(self):
        """Call ``eval()`` on every model (no-op in ONNX mode)."""
        if self.models['landlord'] is not None:
            for position in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down'):
                self.models[position].eval()

    def parameters(self, position):
        """Return the parameters of the model for ``position``."""
        return self.models[position].parameters()

    def get_model(self, position):
        """Return the model for ``position`` (None in ONNX mode)."""
        return self.models[position]

    def get_models(self):
        """Return the position -> model dict itself (not a copy)."""
        return self.models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model:
    """
    The wrapper for the four per-position convolutional models (landlord,
    landlord_up, landlord_front, landlord_down). We also wrap several
    interfaces such as share_memory, eval, etc.
    """

    def __init__(self, device=0, flags=None, lite_model = False):
        """
        Create one model per position, or placeholders when ONNX is on.

        ``device`` is either the string "cpu" or a CUDA device index.
        When ``flags.enable_onnx`` is set no torch model is built and
        every entry of ``self.models`` is None. ``lite_model`` selects
        GeneralModelLite instead of GeneralModel.
        """
        self.models = {}
        self.onnx_models = {}
        self.flags = flags
        if not device == "cpu":
            device = 'cuda:' + str(device)
        self.device = torch.device(device)
        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
        if flags is not None and flags.enable_onnx:
            for position in positions:
                self.models[position] = None
        else:
            # Every position gets its own instance of the same class.
            for position in positions:
                if lite_model:
                    self.models[position] = GeneralModelLite().to(self.device)
                else:
                    self.models[position] = GeneralModel().to(self.device)
        # ONNX sessions are filled in later by set_onnx_model.
        self.onnx_models = {
            'landlord': None,
            'landlord_up': None,
            'landlord_front': None,
            'landlord_down': None
        }

    def set_onnx_model(self, device='cpu'):
        """
        Load one ONNX inference session per position.

        Files are read from ``<flags.onnx_model_path>/<flags.xpid>/
        model_<position>.onnx``. The CPU execution provider is used when
        ``device`` is 'cpu', the CUDA provider otherwise.
        """
        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
        for position in positions:
            model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
            self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=providers)

    def get_onnx_params(self, position):
        """Return the ONNX export parameters of the model for ``position``."""
        # The result used to be computed and then dropped, so callers
        # always received None.
        return self.models[position].get_onnx_params(self.device)

    def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
        """
        Evaluate ``z``/``x`` for ``position``; see ``forward_logic``.

        The ``flags`` argument is passed through as given; ``self.flags``
        is not used here.
        """
        return forward_logic(self, position, z, x, device, return_value, flags)

    def share_memory(self):
        """Call ``share_memory()`` on every model (no-op in ONNX mode)."""
        # The landlord entry is None exactly when no torch model was built.
        if self.models['landlord'] is not None:
            for position in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down'):
                self.models[position].share_memory()

    def eval(self):
        """Call ``eval()`` on every model (no-op in ONNX mode)."""
        if self.models['landlord'] is not None:
            for position in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down'):
                self.models[position].eval()

    def parameters(self, position):
        """Return the parameters of the model for ``position``."""
        return self.models[position].parameters()

    def get_model(self, position):
        """Return the model for ``position`` (None in ONNX mode)."""
        return self.models[position]

    def get_models(self):
        """Return the position -> model dict itself (not a copy)."""
        return self.models
|
|
|
|
|
|