# Douzero_Resnet/douzero/dmc/models.py
"""
This file includes the torch models. We wrap the three
models into one class for convenience.
"""
import os
import numpy as np
import torch
import onnxruntime
from onnxruntime.datasets import get_example
from torch import nn
import torch.nn.functional as F
class LandlordLstmModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(887 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 432, requires_grad=True, device=device),
torch.randn(1, 887, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
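# Note on the LSTM family below (an inferred summary, not in the original):
# z is a (batch, 5, feature) sequence of recent move encodings fed through
# the LSTM, and x is a flat per-position state vector whose width differs by
# position and encoding variant (887/1219 here, 590/798 for the Lite models,
# 860/1192 for the Legacy models). The variants differ only in these widths.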
class FarmerLstmModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(1219 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 432, requires_grad=True, device=device),
torch.randn(1, 1219, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class LandlordLstmModelLite(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(276, 128, batch_first=True)
self.dense1 = nn.Linear(590 + 128, 768)
self.dense2 = nn.Linear(768, 768)
self.dense3 = nn.Linear(768, 768)
self.dense4 = nn.Linear(768, 768)
self.dense5 = nn.Linear(768, 768)
self.dense6 = nn.Linear(768, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 276, requires_grad=True, device=device),
torch.randn(1, 590, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class FarmerLstmModelLite(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(276, 128, batch_first=True)
self.dense1 = nn.Linear(798 + 128, 768)
self.dense2 = nn.Linear(768, 768)
self.dense3 = nn.Linear(768, 768)
self.dense4 = nn.Linear(768, 768)
self.dense5 = nn.Linear(768, 768)
self.dense6 = nn.Linear(768, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 276, requires_grad=True, device=device),
torch.randn(1, 798, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class LandlordLstmModelLegacy(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(860 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 432, requires_grad=True, device=device),
torch.randn(1, 860, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
class FarmerLstmModelLegacy(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(1192 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(256, 1)
def get_onnx_params(self, device):
return {
'args': (
torch.randn(1, 5, 432, requires_grad=True, device=device),
torch.randn(1, 1192, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "legal_actions"
},
'x_batch': {
0: "legal_actions"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
return dict(values=x)
# Residual block used by ResNet-18/34: two 3x3 convolutions.
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=(3,),
stride=(stride,), padding=1, bias=False)
self.bn1 = nn.BatchNorm1d(planes)
self.conv2 = nn.Conv1d(planes, planes, kernel_size=(3,),
stride=(1,), padding=1, bias=False)
self.bn2 = nn.BatchNorm1d(planes)
self.shortcut = nn.Sequential()
        # The shortcut output must match the main branch in both length and
        # channel depth; if it does not, project it with a 1x1 conv + BN.
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv1d(in_planes, self.expansion * planes,
kernel_size=(1,), stride=(stride,), bias=False),
nn.BatchNorm1d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
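# Shape bookkeeping for the ResNet trunks below (derived from the standard
# Conv1d formula, L_out = floor((L_in + 2*padding - kernel) / stride) + 1):
# with kernel 3, padding 1, stride 2 each stage roughly halves the length,
# e.g. 69 -> 35 -> 18 -> 9 -> 5 for the Lite trunk and
# 108 -> 54 -> 27 -> 14 -> 7 -> 4 for the full trunk, which is where the
# "* 5" and "* 4" factors in the first linear layers come from.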
class GeneralModelLite(nn.Module):
def __init__(self):
super().__init__()
self.in_planes = 80
        # input: 1*69*40
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*35*80
self.bn1 = nn.BatchNorm1d(80)
self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*18*80
self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*9*160
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*5*320
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
self.linear2 = nn.Linear(1536, 768)
self.linear3 = nn.Linear(768, 384)
self.linear4 = nn.Linear(384, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def get_onnx_params(self, device=None):
return {
'args': (
torch.randn(1, 40, 69, requires_grad=True, device=device),
torch.randn(1, 56, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = out.flatten(1,2)
out = torch.cat([x,out], dim=-1)
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
return dict(values=out)
class UnifiedModelLite(nn.Module):
def __init__(self):
super().__init__()
self.in_planes = 30
#input 1*69*15
self.conv1 = nn.Conv1d(15, 30, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*35*30
self.bn1 = nn.BatchNorm1d(30)
self.layer1 = self._make_layer(BasicBlock, 30, 2, stride=2)#1*18*30
self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60
self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.lstm = nn.LSTM(276, 224, batch_first=True)
self.linear1 = nn.Linear((120 * BasicBlock.expansion * 5 + 224) * 2, 2048)
self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 1024)
self.linear4 = nn.Linear(1024, 512)
self.linear5 = nn.Linear(512, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def get_onnx_params(self, device=None):
return {
'args': (
torch.randn(1, 15, 69, requires_grad=True, device=device),
torch.randn(1, 24, 276, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = out.flatten(1,2)
        # The first entry of z appears to encode whether the acting player is
        # the landlord; use it to gate the shared features into either the
        # landlord half or the farmer half of the input to linear1.
        is_landlord = z[0][0][0]
        lstm_out, (h_n, _) = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]
        out = torch.cat([lstm_out, out], dim=1)
        out = torch.cat([out * is_landlord, out * (1 - is_landlord)], dim=1)
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
return dict(values=out)
class GeneralModel(nn.Module):
def __init__(self):
super().__init__()
self.in_planes = 80
        # input: 1*108*40
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)  # 1*54*80
self.bn1 = nn.BatchNorm1d(80)
self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*27*80
self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*14*160
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 56, 2048)
self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256)
self.linear5 = nn.Linear(256, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def get_onnx_params(self, device=None):
return {
'args': (
torch.randn(1, 40, 108, requires_grad=True, device=device),
torch.randn(1, 56, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.flatten(1,2)
out = torch.cat([x,out], dim=-1)
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
return dict(values=out)
# The model dicts below are only used in evaluation, not training.
model_dict = {}
model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel
model_dict_lite = {}
model_dict_lite['landlord'] = LandlordLstmModelLite
model_dict_lite['landlord_up'] = FarmerLstmModelLite
model_dict_lite['landlord_front'] = FarmerLstmModelLite
model_dict_lite['landlord_down'] = FarmerLstmModelLite
model_dict_legacy = {}
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
model_dict_new = {}
model_dict_new['landlord'] = GeneralModel
model_dict_new['landlord_up'] = GeneralModel
model_dict_new['landlord_front'] = GeneralModel
model_dict_new['landlord_down'] = GeneralModel
model_dict_new_lite = {}
model_dict_new_lite['landlord'] = GeneralModelLite
model_dict_new_lite['landlord_up'] = GeneralModelLite
model_dict_new_lite['landlord_front'] = GeneralModelLite
model_dict_new_lite['landlord_down'] = GeneralModelLite
model_dict_uni_lite = {}
model_dict_uni_lite['landlord'] = UnifiedModelLite
model_dict_uni_lite['landlord_up'] = UnifiedModelLite
model_dict_uni_lite['landlord_front'] = UnifiedModelLite
model_dict_uni_lite['landlord_down'] = UnifiedModelLite
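# Illustrative sketch (not part of the original pipeline): the dict returned
# by each model's get_onnx_params() is shaped to be splatted into
# torch.onnx.export. The output path below is a hypothetical example.
def _export_onnx_example(model, path='model_landlord.onnx', device='cpu'):
    params = model.get_onnx_params(torch.device(device))
    torch.onnx.export(
        model,                           # the nn.Module to trace
        params['args'],                  # dummy inputs defining the shapes
        path,                            # destination .onnx file
        input_names=params['input_names'],
        output_names=params['output_names'],
        dynamic_axes=params['dynamic_axes'],  # mark dim 0 as dynamic
    )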
def forward_logic(self_model, position, z, x, device='cpu', return_value=False, flags=None):
    # Shared forward pass for every wrapper class below. When the batch of
    # legal actions is large (>= 80 rows) it is split into chunks of roughly
    # 40 rows to bound peak memory during inference.
    legal_count = len(z)
    if not flags.enable_onnx:
        z = torch.tensor(z, device=device)
        x = torch.tensor(x, device=device)
    if legal_count >= 80:
        partition_count = legal_count // 40
        sub_z = np.array_split(z, partition_count)
        sub_x = np.array_split(x, partition_count)
if flags.enable_onnx:
model = self_model.onnx_models[position]
if legal_count >= 80:
values = np.ndarray((legal_count, 1))
j = 0
for i in range(partition_count):
onnx_out = model.run(None, {'z_batch': sub_z[i], 'x_batch': sub_x[i]})
values[j:j+len(sub_z[i])] = onnx_out[0]
j += len(sub_z[i])
else:
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = onnx_out[0]
else:
if legal_count >= 80:
values = np.ndarray((legal_count, 1))
j = 0
for i in range(partition_count):
model = self_model.models[position]
model_out = model.forward(sub_z[i], sub_x[i])['values']
values[j:j+len(sub_z[i])] = model_out.cpu().detach().numpy()
j += len(sub_z[i])
else:
model = self_model.models[position]
values = model.forward(z, x)['values']
    if return_value:
        return dict(values=values.cpu().detach().numpy() if torch.is_tensor(values) else values)
    else:
        # Epsilon-greedy exploration: with probability exp_epsilon pick a
        # uniformly random legal action, otherwise take the argmax of the
        # predicted values.
        if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
            if torch.is_tensor(values):
                action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
            else:
                action = np.random.randint(0, values.shape[0], (1,))[0]
        else:
            if torch.is_tensor(values):
                action = torch.argmax(values, dim=0)[0].cpu().detach().numpy()
            else:
                action = np.argmax(values, axis=0)[0]
        return dict(action=action)
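# Illustrative usage (hypothetical helper; shapes assume the landlord LSTM
# variant): every wrapper below routes forward() through forward_logic, so
# selecting an action for one decision point looks like this.
def _select_action_example(wrapper, position, z_batch, x_batch, flags):
    # z_batch/x_batch hold one row per legal action, e.g. shapes
    # (num_legal, 5, 432) and (num_legal, 887) for the landlord position.
    return wrapper.forward(position, z_batch, x_batch, flags=flags)['action']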
class OldModel:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0, flags=None, lite_model = False):
self.models = {}
self.onnx_models = {}
self.flags = flags
if not device == "cpu":
device = 'cuda:' + str(device)
self.device = torch.device(device)
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if flags is not None and flags.enable_onnx:
for position in positions:
self.models[position] = None
else:
for position in positions[1:]:
self.models[position] = FarmerLstmModelLite().to(self.device) if lite_model else FarmerLstmModel().to(self.device)
self.models['landlord'] = LandlordLstmModelLite().to(self.device) if lite_model else LandlordLstmModel().to(self.device)
self.onnx_models = {
'landlord': None,
'landlord_up': None,
'landlord_front': None,
'landlord_down': None
}
    def set_onnx_model(self, device='cpu'):
        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
        for position in positions:
            model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
            # get_example() joins its argument onto onnxruntime's bundled data
            # directory; since os.path.join returns the second argument
            # unchanged when it is absolute, the absolute path built above
            # passes through as-is.
            if device == 'cpu':
                self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
            else:
                self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
    def get_onnx_params(self, position):
        return self.models[position].get_onnx_params(self.device)
def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
return forward_logic(self, position, z, x, device, return_value, flags)
def share_memory(self):
if self.models['landlord'] is not None:
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_front'].share_memory()
self.models['landlord_down'].share_memory()
def eval(self):
if self.models['landlord'] is not None:
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_front'].eval()
self.models['landlord_down'].eval()
def parameters(self, position):
return self.models[position].parameters()
def get_model(self, position):
return self.models[position]
def get_models(self):
return self.models
class Model:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0, flags=None, lite_model = False):
self.models = {}
self.onnx_models = {}
self.flags = flags
if not device == "cpu":
device = 'cuda:' + str(device)
self.device = torch.device(device)
# model = GeneralModel().to(torch.device(device))
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if flags is not None and flags.enable_onnx:
for position in positions:
self.models[position] = None
else:
for position in positions:
if lite_model:
self.models[position] = GeneralModelLite().to(self.device)
else:
self.models[position] = GeneralModel().to(self.device)
self.onnx_models = {
'landlord': None,
'landlord_up': None,
'landlord_front': None,
'landlord_down': None
}
def set_onnx_model(self, device='cpu'):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
for position in positions:
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
if device == 'cpu':
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
else:
self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
    def get_onnx_params(self, position):
        return self.models[position].get_onnx_params(self.device)
def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
return forward_logic(self, position, z, x, device, return_value, flags)
def share_memory(self):
if self.models['landlord'] is not None:
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_front'].share_memory()
self.models['landlord_down'].share_memory()
def eval(self):
if self.models['landlord'] is not None:
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_front'].eval()
self.models['landlord_down'].eval()
def parameters(self, position):
return self.models[position].parameters()
def get_model(self, position):
return self.models[position]
def get_models(self):
return self.models
class UnifiedModel:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0, flags=None, lite_model = False):
self.onnx_models = {}
self.model = None
self.models = {}
self.flags = flags
if not device == "cpu":
device = 'cuda:' + str(device)
self.device = torch.device(device)
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
        if flags is not None and flags.enable_onnx:
            self.model = None
        else:
            if lite_model:
                self.model = UnifiedModelLite().to(self.device)
            else:
                # No full-size unified variant exists in this file, so the
                # non-lite path falls back to GeneralModel.
                self.model = GeneralModel().to(self.device)
            # forward_logic looks models up by position, so map every
            # position to the single shared instance.
            for position in positions:
                self.models[position] = self.model
self.onnx_model = None
def set_onnx_model(self, device='cpu'):
model_path = os.path.abspath('%s/%s/model_uni.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if device == 'cpu':
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
else:
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
for position in positions:
self.onnx_models[position] = self.onnx_model
    def get_onnx_params(self, position):
        return self.model.get_onnx_params(self.device)
def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
return forward_logic(self, position, z, x, device, return_value, flags)
def share_memory(self):
if self.model is not None:
self.model.share_memory()
def eval(self):
if self.model is not None:
self.model.eval()
def parameters(self, position):
return self.model.parameters()
def get_model(self, position):
return self.model
def get_models(self):
return {
'uni' : self.model
}
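if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original file): run a
    # forward pass through GeneralModel with random inputs matching the
    # shapes declared in its get_onnx_params().
    _model = GeneralModel()
    _model.eval()
    with torch.no_grad():
        _z = torch.randn(4, 40, 108)  # 4 legal actions, 40 channels, length 108
        _x = torch.randn(4, 56)
        print(_model(_z, _x)['values'].shape)  # torch.Size([4, 1])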