Douzero_Resnet/douzero/dmc/models.py

631 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
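This file includes the torch models. The wrapper classes bundle the per-position
models (landlord, landlord_up, landlord_front, landlord_down) and the bidding
model into one object for convenience.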
"""
import os
import numpy as np
import torch
import onnxruntime
from onnxruntime.datasets import get_example
from torch import nn
import torch.nn.functional as F
def to_numpy(tensor):
    """
    Convert ``tensor`` to a NumPy array in host memory.

    Tensors that require grad are detached first, because ``.numpy()`` refuses
    to operate on them directly.
    """
    if not tensor.requires_grad:
        return tensor.cpu().numpy()
    return tensor.detach().cpu().numpy()
class LandlordLstmModel(nn.Module):
    """
    LSTM value network used for the 'landlord' position.

    An LSTM (input size 432, hidden size 128, batch-first) runs over ``z``.
    The output of its last time step is concatenated (LSTM features first)
    with the 887-dim ``x``. The result passes through five ReLU layers and a
    final linear layer producing one value per row: ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(887 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self):
        """
        Return ONNX export settings (example ``args``, input/output names and
        dynamic axes); the keys presumably feed ``torch.onnx.export``.

        The example inputs are zeros created on the same device as the model's
        parameters. Placing them on CPU unconditionally would mismatch a model
        that has been moved to CUDA. Axis 0 of both inputs is exported as the
        dynamic "legal_actions" axis.
        """
        device = next(self.parameters()).device
        return {
            'args': (
                torch.zeros((1, 5, 432), dtype=torch.float32, device=device),
                torch.zeros((1, 887), dtype=torch.float32, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
            },
        }

    def forward(self, z, x):
        lstm_out, _ = self.lstm(z)
        # Only the final time step's output summarises the sequence.
        x = torch.cat([lstm_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            x = torch.relu(layer(x))
        return dict(values=self.dense6(x))
class FarmerLstmModel(nn.Module):
    """
    LSTM value network used for 'landlord_up', 'landlord_front' and
    'landlord_down'.

    An LSTM (input size 432, hidden size 128, batch-first) runs over ``z``.
    The output of its last time step is concatenated (LSTM features first)
    with the 1219-dim ``x``. The result passes through five ReLU layers and a
    final linear layer producing one value per row: ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(1219 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def get_onnx_params(self):
        """
        Return ONNX export settings (example ``args``, input/output names and
        dynamic axes); the keys presumably feed ``torch.onnx.export``.

        The example inputs are zeros created on the same device as the model's
        parameters, so they match the model whether it lives on CPU or CUDA.
        Axis 0 of both inputs is exported as the dynamic "legal_actions" axis.
        """
        device = next(self.parameters()).device
        return {
            'args': (
                torch.zeros((1, 5, 432), dtype=torch.float32, device=device),
                torch.zeros((1, 1219), dtype=torch.float32, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
            },
        }

    def forward(self, z, x):
        lstm_out, _ = self.lstm(z)
        # Only the final time step's output summarises the sequence.
        x = torch.cat([lstm_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            x = torch.relu(layer(x))
        return dict(values=self.dense6(x))
class LandlordLstmModelLegacy(nn.Module):
    """
    Legacy LSTM value network registered for 'landlord' in the legacy model dict.

    LSTM (432 -> 128, batch-first) over ``z``. The last step's output is
    concatenated ahead of the 860-dim ``x``, followed by five ReLU layers and a
    final linear layer: ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(860 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def forward(self, z, x):
        sequence_out, _ = self.lstm(z)
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return dict(values=self.dense6(hidden))
class FarmerLstmModelLegacy(nn.Module):
    """
    Legacy LSTM value network registered for 'landlord_up', 'landlord_front'
    and 'landlord_down' in the legacy model dict.

    LSTM (432 -> 128, batch-first) over ``z``. The last step's output is
    concatenated ahead of the 1192-dim ``x``, followed by five ReLU layers and
    a final linear layer: ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(432, 128, batch_first=True)
        self.dense1 = nn.Linear(1192 + 128, 1024)
        self.dense2 = nn.Linear(1024, 1024)
        self.dense3 = nn.Linear(1024, 768)
        self.dense4 = nn.Linear(768, 512)
        self.dense5 = nn.Linear(512, 256)
        self.dense6 = nn.Linear(256, 1)

    def forward(self, z, x):
        sequence_out, _ = self.lstm(z)
        hidden = torch.cat([sequence_out[:, -1, :], x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return dict(values=self.dense6(hidden))
class GeneralModel1(nn.Module):
    """
    Convolutional value network.

    ``z`` is expected as (B, 32, 57) and ``x`` as (B, 519). Four stride-2
    max-pools reduce the 32 rows of ``z`` to 2. That leaves 512 * 2 = 1024
    conv features, which are concatenated ahead of ``x`` and fed through five
    ReLU layers and a final linear layer: ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()

        def stage(conv, norm_cls, channels):
            # Conv -> ReLU -> BatchNorm, kept at Sequential indices 0, 1, 2.
            return torch.nn.Sequential(conv, nn.ReLU(inplace=True), norm_cls(channels))

        # (B, 1, 32, 57) -> (B, 64, 32, 1): each kernel spans a full 57-wide row.
        self.conv_z_1 = stage(nn.Conv2d(1, 64, kernel_size=(1,57)), nn.BatchNorm2d, 64)
        # The length-preserving 1-D convs below run at lengths 16, 8 and 4.
        self.conv_z_2 = stage(nn.Conv1d(64, 128, kernel_size=(5,), padding=2), nn.BatchNorm1d, 128)
        self.conv_z_3 = stage(nn.Conv1d(128, 256, kernel_size=(3,), padding=1), nn.BatchNorm1d, 256)
        self.conv_z_4 = stage(nn.Conv1d(256, 512, kernel_size=(3,), padding=1), nn.BatchNorm1d, 512)
        self.dense1 = nn.Linear(519 + 1024, 1024)
        self.dense2 = nn.Linear(1024, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        # (B, 32, 57) -> (B, 64, 32) -> pooled to (B, 64, 16)
        feat = torch.max_pool1d(self.conv_z_1(z.unsqueeze(1)).squeeze(-1), 2)
        # Lengths 16 -> 8 -> 4 -> 2 while channels grow to 512.
        for block in (self.conv_z_2, self.conv_z_3, self.conv_z_4):
            feat = torch.max_pool1d(block(feat), 2)
        hidden = torch.cat([feat.flatten(1, 2), x], dim=-1)
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            hidden = torch.relu(layer(hidden))
        return dict(values=self.dense6(hidden))
# Residual block in the style of ResNet-18/34: two kernel-size-3 convolutions
# (1-D here) plus a shortcut connection.
class BasicBlock(nn.Module):
    """
    1-D residual block: conv(3, stride) -> BN -> ReLU -> conv(3) -> BN, added
    to a shortcut, then ReLU.

    The shortcut is the identity unless the stride or channel count changes. In
    that case a 1x1 strided conv + BN projects the input to the output shape.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=(3,),
                               stride=(stride,), padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=(3,),
                               stride=(1,), padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        # The residual branch output must match the input in length and depth;
        # otherwise project the input with conv + BN.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, out_planes,
                          kernel_size=(1,), stride=(stride,), bias=False),
                nn.BatchNorm1d(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        residual += self.shortcut(x)
        return F.relu(residual)
class GeneralModelLegacy(nn.Module):
    """
    Legacy ResNet-style value network.

    ``z`` must have 40 channels, i.e. shape (B, 40, L). ``linear1`` expects the
    residual stages to end at length 4, e.g. L = 108:
    108 -> 54 (stem) -> 27 -> 14 -> 7 -> 4.
    Four copies of ``x`` (24 features each, per ``linear1``'s width) are
    concatenated ahead of the 640 * 4 conv features. The result goes through
    five linear layers, each followed by an in-place leaky ReLU (the last one
    included). Returns ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()
        self.in_planes = 80
        # Stem: (B, 40, L) -> (B, 80, ceil(L / 2))
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(80)
        # Every stage halves the length (stride 2) and sets the channel width.
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # (B, 80, 27) for L=108
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # (B, 160, 14)
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # (B, 320, 7)
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # (B, 640, 4)
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``."""
        blocks = []
        for index in range(num_blocks):
            step = stride if index == 0 else 1
            blocks.append(block(self.in_planes, planes, step))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, z, x):
        feat = F.relu(self.bn1(self.conv1(z)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
        hidden = torch.cat([x] * 4 + [feat.flatten(1, 2)], dim=-1)
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4, self.linear5):
            hidden = F.leaky_relu_(layer(hidden))
        return dict(values=hidden)
class GeneralModel(nn.Module):
    """
    ResNet-style value network (registered for the four positions in
    ``model_dict_new`` and built per position by ``Model``).

    ``z`` is (B, 40, 108), as in :meth:`get_onnx_params`. The stride-2 stem and
    four stride-2 residual stages shrink the length 108 -> 54 -> 27 -> 14 ->
    7 -> 4, leaving 640 * 4 = 2560 features. These are concatenated after the
    80-dim ``x`` and passed through five linear layers, each followed by an
    in-place leaky ReLU (including the last). Returns ``{'values': (B, 1)}``.
    """

    def __init__(self):
        super().__init__()
        self.in_planes = 80
        # Stem: (B, 40, 108) -> (B, 80, 54)
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(80)
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # (B, 80, 27)
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # (B, 160, 14)
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # (B, 320, 7)
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # (B, 640, 4)
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``.

        Updates ``self.in_planes`` so the next stage starts from this width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def get_onnx_params(self):
        """
        Return ONNX export settings (example ``args``, input/output names and
        dynamic axes); the keys presumably feed ``torch.onnx.export``.

        The zero example inputs are created on the device of the model's
        parameters. A hard-coded 'cuda:0' fails on CPU-only hosts and mismatches
        models on the CPU or another GPU. Axis 0 of both inputs is the dynamic
        "legal_actions" axis.
        """
        device = next(self.parameters()).device
        return {
            'args': (
                torch.zeros((1, 40, 108), dtype=torch.float32, device=device),
                torch.zeros((1, 80), dtype=torch.float32, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
            },
        }

    def forward(self, z, x):
        out = F.relu(self.bn1(self.conv1(z)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = torch.cat([x, out.flatten(1, 2)], dim=-1)
        # NOTE(review): the final value also goes through leaky_relu_; kept
        # as-is because trained checkpoints presumably depend on it.
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4, self.linear5):
            out = F.leaky_relu_(layer(out))
        return dict(values=out)
class BidModel(nn.Module):
    """
    MLP registered under 'bidding' in the model dicts and wrappers.

    Consumes a 208-feature ``x`` through five leaky-ReLU layers (512 wide) and
    a final linear layer, returning ``{'values': (B, 1)}``. ``z`` is accepted
    for interface compatibility with the other networks but is unused.
    """

    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(208, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        hidden = x
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            hidden = F.leaky_relu(layer(hidden))
        return dict(values=self.dense6(hidden))
# Model dict is only used in evaluation but not training
model_dict = {
    'landlord': LandlordLstmModel,
    'landlord_up': FarmerLstmModel,
    'landlord_front': FarmerLstmModel,
    'landlord_down': FarmerLstmModel,
}

# Legacy LSTM networks (narrower ``x`` inputs than the current LSTM ones).
model_dict_legacy = {
    'landlord': LandlordLstmModelLegacy,
    'landlord_up': FarmerLstmModelLegacy,
    'landlord_front': FarmerLstmModelLegacy,
    'landlord_down': FarmerLstmModelLegacy,
}

# Legacy ResNet networks plus the bidding MLP.
model_dict_new_legacy = {
    'landlord': GeneralModelLegacy,
    'landlord_up': GeneralModelLegacy,
    'landlord_front': GeneralModelLegacy,
    'landlord_down': GeneralModelLegacy,
    'bidding': BidModel,
}

# Current ResNet networks plus the bidding MLP.
model_dict_new = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
    'bidding': BidModel,
}

# NOTE(review): despite its name, this maps every position to the ResNet-based
# GeneralModel and has no 'bidding' entry — confirm this is intended.
model_dict_lstm = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
}
class General_Model:
    """
    Wrapper bundling ``GeneralModel1`` networks for 'landlord', 'landlord_up',
    'landlord_front' and 'landlord_down' plus a ``BidModel`` for 'bidding'.

    Also forwards share_memory/eval/parameters to the wrapped models.
    """

    _KEYS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding')

    def __init__(self, device=0):
        # "cpu" is used verbatim; any other value becomes 'cuda:<device>'.
        if device != "cpu":
            device = 'cuda:' + str(device)
        self.models = {
            key: GeneralModel1().to(torch.device(device))
            for key in self._KEYS[:-1]
        }
        self.models['bidding'] = BidModel().to(torch.device(device))

    def forward(self, position, z, x, return_value=False, flags=None, debug=False):
        """
        Score candidate rows with the network for ``position``.

        Returns ``{'values': ...}`` when ``return_value`` is set. Otherwise it
        returns ``{'action': index}``: a uniformly random row with probability
        ``flags.exp_epsilon`` (when that is > 0), else the argmax over dim 0.
        """
        values = self.models[position].forward(z, x)['values']
        if return_value:
            return dict(values=values)
        explore = (flags is not None and flags.exp_epsilon > 0
                   and np.random.rand() < flags.exp_epsilon)
        if explore:
            action = torch.randint(values.shape[0], (1,))[0]
        else:
            action = torch.argmax(values, dim=0)[0]
        return dict(action=action)

    def share_memory(self):
        for key in self._KEYS:
            self.models[key].share_memory()

    def eval(self):
        for key in self._KEYS:
            self.models[key].eval()

    def parameters(self, position):
        return self.models[position].parameters()

    def get_model(self, position):
        return self.models[position]

    def get_models(self):
        return self.models
class OldModel:
    """
    Wrapper bundling the LSTM networks: ``LandlordLstmModel`` for 'landlord',
    ``FarmerLstmModel`` for 'landlord_up', 'landlord_front' and
    'landlord_down', and ``BidModel`` for 'bidding'.

    Any position can be switched to an ONNX Runtime session with
    :meth:`set_onnx_model`; positions without a session run in PyTorch.
    """

    _KEYS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding')

    def __init__(self, device=0, flags=None):
        # ``flags`` is accepted but not used by this wrapper.
        self.models = {}
        if not device == "cpu":
            device = 'cuda:' + str(device)
        self.models['landlord'] = LandlordLstmModel().to(torch.device(device))
        self.models['landlord_up'] = FarmerLstmModel().to(torch.device(device))
        self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device))
        self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
        self.models['bidding'] = BidModel().to(torch.device(device))
        # None means "run this position in PyTorch".
        self.onnx_models = {key: None for key in self._KEYS}

    def set_onnx_model(self, position, model_path):
        """Serve ``position`` from the ONNX file at ``model_path``.

        Relative paths are resolved against the current working directory.
        """
        # onnxruntime.datasets.get_example joins its argument onto onnxruntime's
        # own datasets directory, so a relative path would be looked up there;
        # with an absolute path it only checks that the file exists.
        # Providers are listed explicitly because ORT GPU builds (>= 1.9)
        # refuse to create a session without them.
        self.onnx_models[position] = onnxruntime.InferenceSession(
            get_example(os.path.abspath(model_path)),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    def get_onnx_params(self, position):
        """Return the ONNX export settings of the PyTorch model at ``position``."""
        return self.models[position].get_onnx_params()

    def forward(self, position, z, x, return_value=False, flags=None):
        """
        Score candidate rows with the network (ONNX or PyTorch) for ``position``.

        Returns ``{'values': ...}`` when ``return_value`` is set. Otherwise it
        returns ``{'action': index}``: a random row with probability
        ``flags.exp_epsilon`` (when > 0), else the argmax over dim 0.
        """
        model = self.onnx_models[position]
        if model is None:
            model = self.models[position]
            values = model.forward(z, x)['values']
        else:
            onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
            values = torch.tensor(onnx_out[0])
        if return_value:
            return dict(values=values)
        if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
            action = torch.randint(values.shape[0], (1,))[0]
        else:
            action = torch.argmax(values, dim=0)[0]
        return dict(action=action)

    def share_memory(self):
        # Skip any entry that has no PyTorch model.
        for key in self._KEYS:
            if self.models[key] is not None:
                self.models[key].share_memory()

    def eval(self):
        for key in self._KEYS:
            if self.models[key] is not None:
                self.models[key].eval()

    def parameters(self, position):
        return self.models[position].parameters()

    def get_model(self, position):
        return self.models[position]

    def get_models(self):
        return self.models
class Model:
    """
    Wrapper bundling one ``GeneralModel`` per position ('landlord',
    'landlord_up', 'landlord_front', 'landlord_down') and a ``BidModel`` for
    'bidding', with share_memory/eval/parameters helpers.

    When ``flags.enable_onnx`` is truthy, the four position networks are not
    built in PyTorch (their entries are ``None``). Inference for them then
    goes through ONNX Runtime sessions loaded by :meth:`set_onnx_model`; this
    happens lazily on the first ``forward`` that needs one. Bidding always
    runs in PyTorch.
    """

    _POSITIONS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')

    def __init__(self, device=0, flags=None):
        """
        Args:
            device: "cpu", or any other value, which is formatted as
                'cuda:<device>'.
            flags: optional config object. ``enable_onnx`` is read here;
                ``savedir``/``xpid`` are read when ONNX sessions are loaded.
        """
        self.flags = flags
        if not device == "cpu":
            device = 'cuda:' + str(device)
        self.models = {}
        if self._onnx_enabled():
            self.models['bidding'] = BidModel().to(torch.device(device))
            for position in self._POSITIONS:
                self.models[position] = None
        else:
            for position in self._POSITIONS:
                self.models[position] = GeneralModel().to(torch.device(device))
            self.models['bidding'] = BidModel().to(torch.device(device))
        # None means "no ONNX session loaded for this key".
        self.onnx_models = {position: None for position in self._POSITIONS + ('bidding',)}

    def _onnx_enabled(self):
        """Return True when flags were supplied and request ONNX inference."""
        return self.flags is not None and bool(self.flags.enable_onnx)

    def set_onnx_model(self):
        """Load ``<savedir>/<xpid>/model_<position>.onnx`` for every position."""
        for position in self._POSITIONS:
            model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.savedir, self.flags.xpid, position))
            self.onnx_models[position] = onnxruntime.InferenceSession(
                get_example(model_path),
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        # Bidding is never served from ONNX.
        self.onnx_models['bidding'] = None

    def get_onnx_params(self, position):
        """Return the ONNX export settings of the PyTorch model at ``position``."""
        return self.models[position].get_onnx_params()

    def forward(self, position, z, x, return_value=False, flags=None, debug=False):
        """
        Score candidate rows with the network (ONNX or PyTorch) for ``position``.

        ``flags`` here is the per-call exploration config, distinct from
        ``self.flags``. Returns ``{'values': ...}`` when ``return_value`` is
        set. Otherwise it returns ``{'action': index}``: a random row with
        probability ``flags.exp_epsilon`` (when > 0), else the argmax over dim 0.
        """
        onnx_model = self.onnx_models.get(position)
        if onnx_model is None and self.models.get(position) is None and self._onnx_enabled():
            # No PyTorch network for this position: load the ONNX sessions now.
            self.set_onnx_model()
            onnx_model = self.onnx_models.get(position)
        if onnx_model is None:
            values = self.models[position].forward(z, x)['values']
        else:
            onnx_out = onnx_model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)})
            values = torch.tensor(onnx_out[0])
        if return_value:
            return dict(values=values)
        if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
            action = torch.randint(values.shape[0], (1,))[0]
        else:
            action = torch.argmax(values, dim=0)[0]
        return dict(action=action)

    def _torch_models(self):
        """Yield the PyTorch models that exist (position entries may be None)."""
        for key in self._POSITIONS + ('bidding',):
            model = self.models.get(key)
            if model is not None:
                yield model

    def share_memory(self):
        for model in self._torch_models():
            model.share_memory()

    def eval(self):
        for model in self._torch_models():
            model.eval()

    def parameters(self, position):
        return self.models[position].parameters()

    def get_model(self, position):
        return self.models[position]

    def get_models(self):
        return self.models