"""Torch value networks and the wrappers that bundle one network per seat.

The module defines three families of networks (LSTM based, their "lite" and
"legacy" variants, and 1-D residual convolutional ones), lookup tables that map
a seat name to a network class, and two wrapper classes (`OldModel`, `Model`)
that hold one network per seat and can run inference either through PyTorch or
through ONNX Runtime sessions.
"""

import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

try:
    import onnxruntime
    from onnxruntime.datasets import get_example
except ImportError:
    # ONNX Runtime is only needed by `set_onnx_model`; keeping the import
    # optional lets the pure PyTorch path work where it is not installed.
    onnxruntime = None
    get_example = None


# Seat names used as keys everywhere in this module (landlord first).
_POSITIONS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')

# `forward_logic` evaluates batches of at least _CHUNK_THRESHOLD rows in
# roughly _CHUNK_SIZE-row pieces instead of in one call.
_CHUNK_THRESHOLD = 80
_CHUNK_SIZE = 40

_WIDE_HIDDEN = (1024, 1024, 768, 512, 256)
_LITE_HIDDEN = (768, 768, 768, 768, 768)


class _LstmValueModel(nn.Module):
    """LSTM encoder followed by a six-layer MLP that outputs one value per row.

    Subclasses only override the class attributes below. Submodule names
    (`lstm`, `dense1` .. `dense6`) and their creation order are fixed so that
    `state_dict` keys stay stable across all variants.
    """

    z_features = 432                # feature size of each step of `z`
    x_features = 0                  # feature size of `x`
    hidden_sizes = _WIDE_HIDDEN     # output widths of dense1 .. dense5
    input_axis_name = "batch_size"  # ONNX label for axis 0 of both inputs

    def __init__(self):
        super().__init__()
        widths = (self.x_features + 128,) + tuple(self.hidden_sizes) + (1,)
        self.lstm = nn.LSTM(self.z_features, 128, batch_first=True)
        self.dense1 = nn.Linear(widths[0], widths[1])
        self.dense2 = nn.Linear(widths[1], widths[2])
        self.dense3 = nn.Linear(widths[2], widths[3])
        self.dense4 = nn.Linear(widths[3], widths[4])
        self.dense5 = nn.Linear(widths[4], widths[5])
        self.dense6 = nn.Linear(widths[5], widths[6])

    def get_onnx_params(self, device=None):
        """Return example inputs and axis metadata describing this network.

        The keys (`args`, `input_names`, `output_names`, `dynamic_axes`)
        presumably feed `torch.onnx.export`; confirm against the caller.

        Args:
            device: Device on which the random example tensors are created.

        Returns:
            dict with a `(z, x)` tuple of example tensors and the names and
            dynamic-axis labels of the inputs and the output.
        """
        return {
            'args': (
                torch.randn(1, 5, self.z_features, requires_grad=True, device=device),
                torch.randn(1, self.x_features, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: self.input_axis_name},
                'x_batch': {0: self.input_axis_name},
                'values': {0: "batch_size"},
            },
        }

    def forward(self, z, x):
        """Score each row of the batch.

        Args:
            z: Tensor of shape (batch, steps, z_features) fed to the LSTM.
            x: Tensor of shape (batch, x_features).

        Returns:
            dict with key `values`: tensor of shape (batch, 1).
        """
        lstm_out, _ = self.lstm(z)
        # Only the LSTM output at the last step is used.
        x = torch.cat([lstm_out[:, -1, :], x], dim=-1)
        for dense in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            x = torch.relu(dense(x))
        return dict(values=self.dense6(x))


class LandlordLstmModel(_LstmValueModel):
    """LSTM value network with 887 `x` features."""

    x_features = 887


class FarmerLstmModel(_LstmValueModel):
    """LSTM value network with 1219 `x` features."""

    x_features = 1219
    input_axis_name = "legal_actions"


class LandlordLstmModelLite(_LstmValueModel):
    """Lite LSTM value network: 276 `z` features, 590 `x` features."""

    z_features = 276
    x_features = 590
    hidden_sizes = _LITE_HIDDEN


class FarmerLstmModelLite(_LstmValueModel):
    """Lite LSTM value network: 276 `z` features, 798 `x` features."""

    z_features = 276
    x_features = 798
    hidden_sizes = _LITE_HIDDEN
    input_axis_name = "legal_actions"


class LandlordLstmModelLegacy(_LstmValueModel):
    """Legacy LSTM value network with 860 `x` features."""

    x_features = 860


class FarmerLstmModelLegacy(_LstmValueModel):
    """Legacy LSTM value network with 1192 `x` features."""

    x_features = 1192
    input_axis_name = "legal_actions"


class BasicBlock(nn.Module):
    """Residual block as used in ResNet-18/34, here with two size-3 1-D convolutions."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=(3,),
                               stride=(stride,), padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=(3,),
                               stride=(1,), padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.shortcut = nn.Sequential()
        # The residual branch and the shortcut must agree in length and in
        # channel count; when they would not, project the input with a
        # 1x1 convolution + batch norm.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion * planes,
                          kernel_size=(1,), stride=(stride,), bias=False),
                nn.BatchNorm1d(self.expansion * planes),
            )

    def forward(self, x):
        """Return `relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))`."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class _ResidualValueModel(nn.Module):
    """Helpers shared by the residual convolutional value networks."""

    z_length = 0  # length of the last axis of the example `z` input

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride`.

        Updates `self.in_planes` so the next stage starts from this stage's
        output channel count.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def get_onnx_params(self, device=None):
        """Return example inputs and axis metadata describing this network.

        Same layout as the LSTM models: `args`, `input_names`,
        `output_names` and `dynamic_axes`.
        """
        return {
            'args': (
                torch.randn(1, 40, self.z_length, requires_grad=True, device=device),
                torch.randn(1, 56, requires_grad=True, device=device),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: "batch_size"},
                'x_batch': {0: "batch_size"},
                'values': {0: "batch_size"},
            },
        }


class GeneralModelLite(_ResidualValueModel):
    """Three-stage residual network over `z` followed by a four-layer MLP.

    Shape comments assume `z` is (N, 40, 69), the example input size.
    """

    z_length = 69

    def __init__(self):
        super().__init__()
        self.in_planes = 80
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)   # (N, 80, 35)
        self.bn1 = nn.BatchNorm1d(80)
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # (N, 80, 18)
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # (N, 160, 9)
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # (N, 320, 5)
        self.linear1 = nn.Linear(320 * BasicBlock.expansion * 5 + 56, 1536)
        self.linear2 = nn.Linear(1536, 768)
        self.linear3 = nn.Linear(768, 384)
        self.linear4 = nn.Linear(384, 1)

    def forward(self, z, x):
        """Score each row; returns dict with `values` of shape (batch, 1).

        Note that leaky ReLU is applied to the final linear layer as well.
        """
        out = F.relu(self.bn1(self.conv1(z)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.flatten(1, 2)
        out = torch.cat([x, out], dim=-1)
        out = F.leaky_relu_(self.linear1(out))
        out = F.leaky_relu_(self.linear2(out))
        out = F.leaky_relu_(self.linear3(out))
        out = F.leaky_relu_(self.linear4(out))
        return dict(values=out)


class GeneralModel(_ResidualValueModel):
    """Four-stage residual network over `z` followed by a five-layer MLP.

    Shape comments assume `z` is (N, 40, 108), the example input size.
    """

    z_length = 108

    def __init__(self):
        super().__init__()
        self.in_planes = 80
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)   # (N, 80, 54)
        self.bn1 = nn.BatchNorm1d(80)
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # (N, 80, 27)
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # (N, 160, 14)
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # (N, 320, 7)
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # (N, 640, 4)
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 56, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def forward(self, z, x):
        """Score each row; returns dict with `values` of shape (batch, 1).

        Note that leaky ReLU is applied to the final linear layer as well.
        """
        out = F.relu(self.bn1(self.conv1(z)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.flatten(1, 2)
        out = torch.cat([x, out], dim=-1)
        out = F.leaky_relu_(self.linear1(out))
        out = F.leaky_relu_(self.linear2(out))
        out = F.leaky_relu_(self.linear3(out))
        out = F.leaky_relu_(self.linear4(out))
        out = F.leaky_relu_(self.linear5(out))
        return dict(values=out)


# Model dicts are only used in evaluation but not training.
model_dict = {
    'landlord': LandlordLstmModel,
    'landlord_up': FarmerLstmModel,
    'landlord_front': FarmerLstmModel,
    'landlord_down': FarmerLstmModel,
}
model_dict_lite = {
    'landlord': LandlordLstmModelLite,
    'landlord_up': FarmerLstmModelLite,
    'landlord_front': FarmerLstmModelLite,
    'landlord_down': FarmerLstmModelLite,
}
model_dict_legacy = {
    'landlord': LandlordLstmModelLegacy,
    'landlord_up': FarmerLstmModelLegacy,
    'landlord_front': FarmerLstmModelLegacy,
    'landlord_down': FarmerLstmModelLegacy,
}
model_dict_new = {
    'landlord': GeneralModel,
    'landlord_up': GeneralModel,
    'landlord_front': GeneralModel,
    'landlord_down': GeneralModel,
}
model_dict_new_lite = {
    'landlord': GeneralModelLite,
    'landlord_up': GeneralModelLite,
    'landlord_front': GeneralModelLite,
    'landlord_down': GeneralModelLite,
}


def _split_batch(batch, sections):
    """Split `batch` along axis 0 into `sections` near-equal chunks.

    Tensors are sliced directly instead of being handed to NumPy; the chunk
    sizes follow the `np.array_split` layout (the first `len % sections`
    chunks hold one extra row). Anything else goes through `np.array_split`.
    """
    if not torch.is_tensor(batch):
        return np.array_split(batch, sections)
    base, extra = divmod(len(batch), sections)
    chunks = []
    start = 0
    for index in range(sections):
        stop = start + base + (1 if index < extra else 0)
        chunks.append(batch[start:stop])
        start = stop
    return chunks


def forward_logic(self_model, position, z, x, device='cpu', return_value=False, flags=None):
    """Score a batch of candidate rows and either return the scores or pick one.

    Args:
        self_model: Wrapper exposing `models` and `onnx_models`, both keyed
            by seat name.
        position: Seat name selecting which network/session to use.
        z: Batch fed to the network as `z_batch`; `len(z)` is the number of
            rows to score.
        x: Batch fed to the network as `x_batch`, with the same number of rows.
        device: Device for the tensors built from `z` and `x` (torch path only).
        return_value: When true, return the scores instead of an action index.
        flags: Optional object with `enable_onnx` and `exp_epsilon`
            attributes. `None` (or a missing attribute) means "use the torch
            model" and "no random exploration" respectively.

    Returns:
        `{'values': ndarray}` when `return_value` is true, otherwise
        `{'action': index}` where `index` is the arg-max row, or a uniformly
        random row with probability `flags.exp_epsilon`.
    """
    use_onnx = bool(getattr(flags, 'enable_onnx', False))
    legal_count = len(z)

    if use_onnx:
        session = self_model.onnx_models[position]

        def evaluate(z_part, x_part):
            return session.run(None, {'z_batch': z_part, 'x_batch': x_part})[0]
    else:
        z = torch.tensor(z, device=device)
        x = torch.tensor(x, device=device)
        model = self_model.models[position]

        def evaluate(z_part, x_part):
            return model.forward(z_part, x_part)['values']

    if legal_count >= _CHUNK_THRESHOLD:
        # Large batches are scored piecewise and gathered into one
        # (legal_count, 1) float64 array.
        sections = legal_count // _CHUNK_SIZE
        values = np.empty((legal_count, 1))
        offset = 0
        for z_part, x_part in zip(_split_batch(z, sections), _split_batch(x, sections)):
            scores = evaluate(z_part, x_part)
            if torch.is_tensor(scores):
                scores = scores.cpu().detach().numpy()
            end = offset + len(z_part)
            values[offset:end] = scores
            offset = end
    else:
        values = evaluate(z, x)

    if return_value:
        if torch.is_tensor(values):
            values = values.cpu().detach().numpy()
        return dict(values=values)

    epsilon = getattr(flags, 'exp_epsilon', 0)
    # The random draw only happens when exploration is enabled, so the global
    # NumPy RNG is left untouched otherwise.
    explore = epsilon > 0 and np.random.rand() < epsilon
    if torch.is_tensor(values):
        if explore:
            picked = torch.randint(values.shape[0], (1,))[0]
        else:
            picked = torch.argmax(values, dim=0)[0]
        action = picked.cpu().detach().numpy()
    elif explore:
        action = np.random.randint(0, values.shape[0], (1,))[0]
    else:
        action = np.argmax(values, axis=0)[0]
    return dict(action=action)


class _SeatModelBundle:
    """One network per seat plus optional ONNX Runtime sessions.

    Subclasses implement `_build_torch_models` to choose the network classes.
    """

    def __init__(self, device=0, flags=None, lite_model=False):
        """Create the per-seat networks.

        Args:
            device: `"cpu"`, or a CUDA device index (turned into `cuda:<index>`).
            flags: Optional settings object. When `flags.enable_onnx` is
                true no torch network is built and every entry of `models`
                is `None`.
            lite_model: Build the lite network variants.
        """
        self.models = {}
        self.flags = flags
        if device != "cpu":
            device = 'cuda:' + str(device)
        self.device = torch.device(device)
        if getattr(flags, 'enable_onnx', False):
            for position in _POSITIONS:
                self.models[position] = None
        else:
            self._build_torch_models(lite_model)
        self.onnx_models = dict.fromkeys(_POSITIONS)

    def _build_torch_models(self, lite_model):
        """Fill `self.models` with networks placed on `self.device`."""
        raise NotImplementedError

    def set_onnx_model(self, device='cpu'):
        """Load `model_<seat>.onnx` for every seat into `onnx_models`.

        Files are looked up under `<flags.onnx_model_path>/<flags.xpid>/`.

        Args:
            device: `'cpu'` selects the CPU execution provider, anything
                else the CUDA execution provider.

        Raises:
            ImportError: If onnxruntime is not installed.
        """
        if onnxruntime is None:
            raise ImportError("onnxruntime is required to load ONNX models")
        provider = 'CPUExecutionProvider' if device == 'cpu' else 'CUDAExecutionProvider'
        for position in _POSITIONS:
            model_path = os.path.abspath(
                '%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid, position))
            self.onnx_models[position] = onnxruntime.InferenceSession(
                get_example(model_path), providers=[provider])

    def get_onnx_params(self, position):
        """Return the ONNX export description of the network at `position`."""
        return self.models[position].get_onnx_params(self.device)

    def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
        """Run `forward_logic` for `position`; see that function for details."""
        return forward_logic(self, position, z, x, device, return_value, flags)

    def share_memory(self):
        """Call `share_memory()` on every network (no-op in ONNX mode)."""
        if self.models['landlord'] is not None:
            for position in _POSITIONS:
                self.models[position].share_memory()

    def eval(self):
        """Switch every network to evaluation mode (no-op in ONNX mode)."""
        if self.models['landlord'] is not None:
            for position in _POSITIONS:
                self.models[position].eval()

    def parameters(self, position):
        """Return the parameters of the network at `position`."""
        return self.models[position].parameters()

    def get_model(self, position):
        """Return the network at `position`."""
        return self.models[position]

    def get_models(self):
        """Return the seat-name -> network dict."""
        return self.models


class OldModel(_SeatModelBundle):
    """
    The wrapper for the per-seat LSTM models. We also wrap several
    interfaces such as share_memory, eval, etc.
    """

    def _build_torch_models(self, lite_model):
        farmer_cls = FarmerLstmModelLite if lite_model else FarmerLstmModel
        landlord_cls = LandlordLstmModelLite if lite_model else LandlordLstmModel
        # The three non-landlord seats are created before the landlord; the
        # order fixes both dict ordering and RNG-dependent initialisation.
        for position in _POSITIONS[1:]:
            self.models[position] = farmer_cls().to(self.device)
        self.models['landlord'] = landlord_cls().to(self.device)


class Model(_SeatModelBundle):
    """
    The wrapper for the per-seat residual models. We also wrap several
    interfaces such as share_memory, eval, etc.
    """

    def _build_torch_models(self, lite_model):
        model_cls = GeneralModelLite if lite_model else GeneralModel
        for position in _POSITIONS:
            self.models[position] = model_cls().to(self.device)