rlcard-showdown/pve_server/models.py

199 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
import torch.nn.functional as F
class LandlordLstmModel(nn.Module):
    """LSTM + MLP value network registered for the ``'landlord'`` position.

    A sequence input ``z`` is summarised by an LSTM; the LSTM output at the
    last time step is concatenated with the flat feature vector ``x`` and fed
    through five ReLU-activated dense layers followed by a linear head.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(162, 128, batch_first=True)
        # 373 flat features + 128 LSTM features.
        self.dense1 = nn.Linear(373 + 128, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        # Built exactly once: earlier revisions re-assigned this attribute
        # several times, which only wasted allocation and initialisation.
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        """Score a batch of inputs.

        Args:
            z: Tensor of shape ``(batch, seq_len, 162)`` (batch-first LSTM input).
            x: Tensor of shape ``(batch, 373)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        lstm_out, _ = self.lstm(z)
        # Keep only the LSTM output of the final time step.
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        for dense in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            x = torch.relu(dense(x))
        return self.dense6(x)
class FarmerLstmModel(nn.Module):
    """LSTM + MLP value network registered for the two farmer positions.

    Same architecture as the landlord network except that the flat feature
    vector ``x`` has 484 entries. The LSTM output at the last time step is
    concatenated with ``x`` and passed through five ReLU-activated dense
    layers followed by a linear head.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(162, 128, batch_first=True)
        # 484 flat features + 128 LSTM features.
        self.dense1 = nn.Linear(484 + 128, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        # Built exactly once: earlier revisions re-assigned this attribute
        # several times, which only wasted allocation and initialisation.
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        """Score a batch of inputs.

        Args:
            z: Tensor of shape ``(batch, seq_len, 162)`` (batch-first LSTM input).
            x: Tensor of shape ``(batch, 484)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        lstm_out, _ = self.lstm(z)
        # Keep only the LSTM output of the final time step.
        lstm_out = lstm_out[:, -1, :]
        x = torch.cat([lstm_out, x], dim=-1)
        for dense in (self.dense1, self.dense2, self.dense3,
                      self.dense4, self.dense5):
            x = torch.relu(dense(x))
        return self.dense6(x)
# Residual block in the style of ResNet-18/34: two kernel-size-3 convolutions
# (1-D here) plus a skip connection.
class BasicBlock(nn.Module):
    """1-D residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU."""

    # Ratio of output channels to ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of channels produced by both convolutions.
            stride: Stride of the first convolution (and of the shortcut
                projection when one is needed).
        """
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)

        # The shortcut must match the main path in both length and channel
        # count; when it would not, project it with a 1x1 convolution + BN.
        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_planes),
            )

    def forward(self, x):
        """Apply the residual block to ``x`` of shape ``(batch, in_planes, length)``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(hidden))
        main += self.shortcut(x)
        return F.relu(main)
class GeneralModel(nn.Module):
    """1-D ResNet value network shared by all positions in ``model_dict_new``.

    ``z`` (40 channels) goes through a strided convolution and four residual
    stages; the flattened result is concatenated with the flat feature vector
    ``x`` (80 entries) and passed through five linear layers, each followed by
    a leaky ReLU.
    """

    def __init__(self):
        super().__init__()
        # Channel count fed to the next residual stage; updated by _make_layer.
        self.in_planes = 80
        # Shape comments are (channels, length) for an input z of (40, 108).
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)  # (80, 54)
        self.bn1 = nn.BatchNorm1d(80)

        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # (80, 27)
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # (160, 14)
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # (320, 7)
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # (640, 4)
        # Flattened conv features (640 * 4) plus the 80 entries of x.
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first one is strided.

        Reads and updates ``self.in_planes`` so consecutive calls chain their
        channel counts.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def get_onnx_params(self, device='cuda:0'):
        """Return dummy inputs and axis metadata for ONNX export.

        The keys mirror ``torch.onnx.export`` keyword names (presumably the
        dict is unpacked into that call; confirm against the caller).

        Args:
            device: Device on which the dummy input tensors are created.
                Defaults to ``'cuda:0'``, the previously hard-coded value.

        Returns:
            Dict with ``args``, ``input_names``, ``output_names`` and
            ``dynamic_axes`` entries.
        """
        # torch.zeros replaces np.zeros: numpy is not imported by this module.
        return {
            'args': (
                torch.zeros((1, 40, 108), dtype=torch.float32, device=device),
                torch.zeros((1, 80), dtype=torch.float32, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {
                    0: "legal_actions"
                },
                'x_batch': {
                    0: "legal_actions"
                }
            }
        }

    def forward(self, z, x):
        """Score a batch of inputs.

        Args:
            z: Tensor of shape ``(batch, 40, 108)``.
            x: Tensor of shape ``(batch, 80)``.

        Returns:
            ``{'values': tensor}`` with a tensor of shape ``(batch, 1)``.
        """
        out = F.relu(self.bn1(self.conv1(z)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # (batch, 640, 4) -> (batch, 2560)
        out = out.flatten(1, 2)
        out = torch.cat([x, out], dim=-1)
        # The leaky ReLU is applied after every linear layer, including the last.
        for linear in (self.linear1, self.linear2, self.linear3,
                       self.linear4, self.linear5):
            out = F.leaky_relu_(linear(out))
        return dict(values=out)
class BidModel(nn.Module):
    """MLP registered under ``'bidding'``: five leaky-ReLU layers and a linear head."""

    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(208, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        """Score a batch of flat feature vectors.

        Args:
            z: Unused; accepted so the call signature matches the other models.
            x: Tensor of shape ``(batch, 208)``.

        Returns:
            ``{'values': tensor}`` with a tensor of shape ``(batch, 1)``.
        """
        hidden_layers = (self.dense1, self.dense2, self.dense3,
                         self.dense4, self.dense5)
        for layer in hidden_layers:
            x = F.leaky_relu(layer(x))
        return {'values': self.dense6(x)}
# Position name -> model class for the LSTM-based networks.
model_dict = {
    'landlord': LandlordLstmModel,
    'landlord_up': FarmerLstmModel,
    'landlord_down': FarmerLstmModel,
}

# Position name -> model class for the ResNet-based networks, plus the
# bidding network.
model_dict_new = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
    'bidding': BidModel,
}