""" This file includes the torch models. We wrap the three models into one class for convenience. """ import numpy as np import torch import onnxruntime from onnxruntime.datasets import get_example from torch import nn import torch.nn.functional as F def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() class LandlordLstmModel(nn.Module): def __init__(self): super().__init__() self.lstm = nn.LSTM(432, 128, batch_first=True) self.dense1 = nn.Linear(887 + 128, 1024) self.dense2 = nn.Linear(1024, 1024) self.dense3 = nn.Linear(1024, 768) self.dense4 = nn.Linear(768, 512) self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) def get_onnx_params(self): return { 'args': ( torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32), torch.tensor(np.zeros((1, 887)), dtype=torch.float32), ), 'input_names': ['z_batch','x_batch'], 'output_names': ['values'], 'dynamic_axes': { 'z_batch': { 0: "legal_actions" }, 'x_batch': { 0: "legal_actions" } } } def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) x = self.dense1(x) x = torch.relu(x) x = self.dense2(x) x = torch.relu(x) x = self.dense3(x) x = torch.relu(x) x = self.dense4(x) x = torch.relu(x) x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) return dict(values=x) class FarmerLstmModel(nn.Module): def __init__(self): super().__init__() self.lstm = nn.LSTM(432, 128, batch_first=True) self.dense1 = nn.Linear(1219 + 128, 1024) self.dense2 = nn.Linear(1024, 1024) self.dense3 = nn.Linear(1024, 768) self.dense4 = nn.Linear(768, 512) self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) def get_onnx_params(self): return { 'args': ( torch.tensor(np.zeros((1, 5, 432)), dtype=torch.float32), torch.tensor(np.zeros((1, 1219)), dtype=torch.float32) ), 'input_names': ['z_batch','x_batch'], 'output_names': ['values'], 'dynamic_axes': { 'z_batch': { 0: "legal_actions" }, 'x_batch': { 0: "legal_actions" } } } def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) x = self.dense1(x) x = torch.relu(x) x = self.dense2(x) x = torch.relu(x) x = self.dense3(x) x = torch.relu(x) x = self.dense4(x) x = torch.relu(x) x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) return dict(values=x) class LandlordLstmModelLegacy(nn.Module): def __init__(self): super().__init__() self.lstm = nn.LSTM(432, 128, batch_first=True) self.dense1 = nn.Linear(860 + 128, 1024) self.dense2 = nn.Linear(1024, 1024) self.dense3 = nn.Linear(1024, 768) self.dense4 = nn.Linear(768, 512) self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) x = self.dense1(x) x = torch.relu(x) x = self.dense2(x) x = torch.relu(x) x = self.dense3(x) x = torch.relu(x) x = self.dense4(x) x = torch.relu(x) x = self.dense5(x) x = torch.relu(x) x = self.dense6(x) return dict(values=x) class FarmerLstmModelLegacy(nn.Module): def __init__(self): super().__init__() self.lstm = nn.LSTM(432, 128, batch_first=True) self.dense1 = nn.Linear(1192 + 128, 1024) self.dense2 = nn.Linear(1024, 1024) self.dense3 = nn.Linear(1024, 768) self.dense4 = nn.Linear(768, 512) self.dense5 = nn.Linear(512, 256) self.dense6 = nn.Linear(256, 1) def forward(self, z, x): lstm_out, (h_n, _) = self.lstm(z) lstm_out = lstm_out[:,-1,:] x = torch.cat([lstm_out,x], dim=-1) x = self.dense1(x) x = torch.relu(x) 
        x = self.dense2(x)
        x = torch.relu(x)
        x = self.dense3(x)
        x = torch.relu(x)
        x = self.dense4(x)
        x = torch.relu(x)
        x = self.dense5(x)
        x = torch.relu(x)
        x = self.dense6(x)
        return dict(values=x)


class GeneralModel1(nn.Module):
    def __init__(self):
        super().__init__()
        # input z: B * 32 * 57
        self.conv_z_1 = torch.nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(1, 57)),  # -> B * 64 * 32 * 1
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
        )
        # squeeze(-1) then max-pool in forward -> B * 64 * 16
        self.conv_z_2 = torch.nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=(5,), padding=2),  # -> B * 128 * 16
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(128),
        )
        self.conv_z_3 = torch.nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=(3,), padding=1),  # -> B * 256 * 8
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
        )
        self.conv_z_4 = torch.nn.Sequential(
            nn.Conv1d(256, 512, kernel_size=(3,), padding=1),  # -> B * 512 * 4
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
        )
        self.dense1 = nn.Linear(519 + 1024, 1024)
        self.dense2 = nn.Linear(1024, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        z = z.unsqueeze(1)
        z = self.conv_z_1(z)
        z = z.squeeze(-1)
        z = torch.max_pool1d(z, 2)
        z = self.conv_z_2(z)
        z = torch.max_pool1d(z, 2)
        z = self.conv_z_3(z)
        z = torch.max_pool1d(z, 2)
        z = self.conv_z_4(z)
        z = torch.max_pool1d(z, 2)  # -> B * 512 * 2
        z = z.flatten(1, 2)         # -> B * 1024
        x = torch.cat([z, x], dim=-1)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        x = torch.relu(x)
        x = self.dense3(x)
        x = torch.relu(x)
        x = self.dense4(x)
        x = torch.relu(x)
        x = self.dense5(x)
        x = torch.relu(x)
        x = self.dense6(x)
        return dict(values=x)


# Residual block as used in ResNet-18/34, built from two kernel-size-3
# convolutions (1-D here).
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=(3,),
                               stride=(stride,), padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=(3,),
                               stride=(1,), padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.shortcut = nn.Sequential()
        # The shortcut must match the main path in both length and channel
        # depth; if it does not, project it with a 1x1 conv + BN.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion * planes,
                          kernel_size=(1,), stride=(stride,), bias=False),
                nn.BatchNorm1d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class GeneralModelLegacy(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_planes = 80
        # input z: B * 40 * 108
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)            # -> B * 80 * 54
        self.bn1 = nn.BatchNorm1d(80)
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # -> B * 80 * 27
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # -> B * 160 * 14
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # -> B * 320 * 7
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # -> B * 640 * 4
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, z, x):
        out = F.relu(self.bn1(self.conv1(z)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.flatten(1, 2)
        # x (24 features) is tiled four times to match linear1's input width.
        out = torch.cat([x, x, x, x, out], dim=-1)
        out = F.leaky_relu_(self.linear1(out))
        out = F.leaky_relu_(self.linear2(out))
        out = F.leaky_relu_(self.linear3(out))
        out = F.leaky_relu_(self.linear4(out))
        out = F.leaky_relu_(self.linear5(out))
        return dict(values=out)
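
# --- Example (not part of the original file): a hedged shape check for the
# ResNet-style trunk, using the shapes noted in the comments above. eval()
# keeps BatchNorm on its running statistics for this dummy batch.
def _demo_resnet_legacy_forward():
    model = GeneralModelLegacy().eval()
    z = torch.zeros((2, 40, 108), dtype=torch.float32)  # 2 candidate actions
    x = torch.zeros((2, 24), dtype=torch.float32)       # tiled 4x in forward
    with torch.no_grad():
        values = model(z, x)['values']
    assert values.shape == (2, 1)
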
class GeneralModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_planes = 80
        # input z: B * 40 * 108
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)            # -> B * 80 * 54
        self.bn1 = nn.BatchNorm1d(80)
        self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)   # -> B * 80 * 27
        self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)  # -> B * 160 * 14
        self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)  # -> B * 320 * 7
        self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)  # -> B * 640 * 4
        self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
        self.linear2 = nn.Linear(2048, 1024)
        self.linear3 = nn.Linear(1024, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, 1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def get_onnx_params(self):
        return {
            'args': (
                torch.tensor(np.zeros((1, 40, 108)), dtype=torch.float32),
                torch.tensor(np.zeros((1, 80)), dtype=torch.float32),
            ),
            'input_names': ['z_batch', 'x_batch'],
            'output_names': ['values'],
            'dynamic_axes': {
                'z_batch': {0: "legal_actions"},
                'x_batch': {0: "legal_actions"},
            }
        }

    def forward(self, z, x):
        out = F.relu(self.bn1(self.conv1(z)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.flatten(1, 2)
        out = torch.cat([x, out], dim=-1)
        out = F.leaky_relu_(self.linear1(out))
        out = F.leaky_relu_(self.linear2(out))
        out = F.leaky_relu_(self.linear3(out))
        out = F.leaky_relu_(self.linear4(out))
        out = F.leaky_relu_(self.linear5(out))
        return dict(values=out)


class BidModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(208, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 512)
        self.dense4 = nn.Linear(512, 512)
        self.dense5 = nn.Linear(512, 512)
        self.dense6 = nn.Linear(512, 1)

    def forward(self, z, x):
        # z is unused; it is accepted only so the bid model shares the same
        # forward interface as the play models.
        x = self.dense1(x)
        x = F.leaky_relu(x)
        x = self.dense2(x)
        x = F.leaky_relu(x)
        x = self.dense3(x)
        x = F.leaky_relu(x)
        x = self.dense4(x)
        x = F.leaky_relu(x)
        x = self.dense5(x)
        x = F.leaky_relu(x)
        x = self.dense6(x)
        return dict(values=x)
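
# --- Example (not part of the original file): a sketch of how
# get_onnx_params() can be consumed. The dict maps directly onto
# torch.onnx.export keyword arguments: a dummy one-action batch plus a
# dynamic batch axis named "legal_actions". The output path is illustrative.
def _export_to_onnx(model, path="model.onnx"):
    model.eval()
    torch.onnx.export(model, f=path, **model.get_onnx_params())

# e.g. _export_to_onnx(GeneralModel(), "landlord.onnx")
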
# The model dicts are only used in evaluation, not in training.
model_dict = {}
model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel

model_dict_legacy = {}
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy

model_dict_new_legacy = {}
model_dict_new_legacy['landlord'] = GeneralModelLegacy
model_dict_new_legacy['landlord_up'] = GeneralModelLegacy
model_dict_new_legacy['landlord_front'] = GeneralModelLegacy
model_dict_new_legacy['landlord_down'] = GeneralModelLegacy
model_dict_new_legacy['bidding'] = BidModel

model_dict_new = {}
model_dict_new['landlord'] = GeneralModel
model_dict_new['landlord_up'] = GeneralModel
model_dict_new['landlord_front'] = GeneralModel
model_dict_new['landlord_down'] = GeneralModel
model_dict_new['bidding'] = BidModel

model_dict_lstm = {}
model_dict_lstm['landlord'] = GeneralModel
model_dict_lstm['landlord_up'] = GeneralModel
model_dict_lstm['landlord_front'] = GeneralModel
model_dict_lstm['landlord_down'] = GeneralModel


class General_Model:
    """
    The wrapper for the per-position models. We also wrap several interfaces
    such as share_memory, eval, etc.
    """
    def __init__(self, device=0):
        self.models = {}
        if device != "cpu":
            device = 'cuda:' + str(device)
        self.models['landlord'] = GeneralModel1().to(torch.device(device))
        self.models['landlord_up'] = GeneralModel1().to(torch.device(device))
        self.models['landlord_front'] = GeneralModel1().to(torch.device(device))
        self.models['landlord_down'] = GeneralModel1().to(torch.device(device))
        self.models['bidding'] = BidModel().to(torch.device(device))

    def forward(self, position, z, x, return_value=False, flags=None, debug=False):
        model = self.models[position]
        values = model.forward(z, x)['values']
        if return_value:
            return dict(values=values)
        if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
            action = torch.randint(values.shape[0], (1,))[0]
        else:
            action = torch.argmax(values, dim=0)[0]
        return dict(action=action)

    def share_memory(self):
        self.models['landlord'].share_memory()
        self.models['landlord_up'].share_memory()
        self.models['landlord_front'].share_memory()
        self.models['landlord_down'].share_memory()
        self.models['bidding'].share_memory()

    def eval(self):
        self.models['landlord'].eval()
        self.models['landlord_up'].eval()
        self.models['landlord_front'].eval()
        self.models['landlord_down'].eval()
        self.models['bidding'].eval()

    def parameters(self, position):
        return self.models[position].parameters()

    def get_model(self, position):
        return self.models[position]

    def get_models(self):
        return self.models
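
# --- Example (not part of the original file): evaluation code can pick the
# class for a seat from one of the dicts above and instantiate it on demand.
def _build_eval_models(device="cpu"):
    return {pos: cls().to(torch.device(device)) for pos, cls in model_dict_new.items()}
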
""" def __init__(self, device=0): self.models = {} if not device == "cpu": device = 'cuda:' + str(device) self.models['landlord'] = LandlordLstmModel().to(torch.device(device)) self.models['landlord_up'] = FarmerLstmModel().to(torch.device(device)) self.models['landlord_front'] = FarmerLstmModel().to(torch.device(device)) self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device)) self.onnx_models = { 'landlord': None, 'landlord_up': None, 'landlord_front': None, 'landlord_down': None, 'bidding': None } def set_onnx_model(self, position, model_path): self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path)) def get_onnx_params(self, position): self.models[position].get_onnx_params() def forward(self, position, z, x, return_value=False, flags=None): model = self.onnx_models[position] if model is None: model = self.models[position] values = model.forward(z, x)['values'] else: onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)}) values = torch.tensor(onnx_out[0]) if return_value: return dict(values=values) else: if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: action = torch.randint(values.shape[0], (1,))[0] else: action = torch.argmax(values,dim=0)[0] return dict(action=action) def share_memory(self): self.models['landlord'].share_memory() self.models['landlord_up'].share_memory() self.models['landlord_front'].share_memory() self.models['landlord_down'].share_memory() self.models['bidding'].share_memory() def eval(self): self.models['landlord'].eval() self.models['landlord_up'].eval() self.models['landlord_front'].eval() self.models['landlord_down'].eval() self.models['bidding'].eval() def parameters(self, position): return self.models[position].parameters() def get_model(self, position): return self.models[position] def get_models(self): return self.models class Model: """ The wrapper for the three models. We also wrap several interfaces such as share_memory, eval, etc. 
""" def __init__(self, device=0): self.models = {} if not device == "cpu": device = 'cuda:' + str(device) # model = GeneralModel().to(torch.device(device)) self.models['landlord'] = GeneralModel().to(torch.device(device)) self.models['landlord_up'] = GeneralModel().to(torch.device(device)) self.models['landlord_front'] = GeneralModel().to(torch.device(device)) self.models['landlord_down'] = GeneralModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device)) self.onnx_models = { 'landlord': None, 'landlord_up': None, 'landlord_front': None, 'landlord_down': None, 'bidding': None } self.models['bidding'] = BidModel().to(torch.device(device)) def set_onnx_model(self, position, model_path): self.onnx_models[position] = onnxruntime.InferenceSession(get_example(model_path)) def get_onnx_params(self, position): self.models[position].get_onnx_params() def forward(self, position, z, x, return_value=False, flags=None, debug=False): model = self.onnx_models[position] if model is None: model = self.models[position] values = model.forward(z, x)['values'] else: onnx_out = model.run(None, {'z_batch': to_numpy(z), 'x_batch': to_numpy(x)}) values = torch.tensor(onnx_out[0]) if return_value: return dict(values=values) else: if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon: action = torch.randint(values.shape[0], (1,))[0] else: action = torch.argmax(values,dim=0)[0] return dict(action=action) def share_memory(self): self.models['landlord'].share_memory() self.models['landlord_up'].share_memory() self.models['landlord_front'].share_memory() self.models['landlord_down'].share_memory() self.models['bidding'].share_memory() def eval(self): self.models['landlord'].eval() self.models['landlord_up'].eval() self.models['landlord_front'].eval() self.models['landlord_down'].eval() self.models['bidding'].eval() def parameters(self, position): return self.models[position].parameters() def get_model(self, position): return self.models[position] def get_models(self): return self.models