参数调整

This commit is contained in:
ZaneYork 2021-12-12 14:01:40 +08:00
parent 64bf792d15
commit 3fec4a6bc1
4 changed files with 119 additions and 52 deletions

3
.gitignore vendored
View File

@ -3,3 +3,6 @@ baselines*/
douzero_checkpoints/ douzero_checkpoints/
.vscode/ .vscode/
*.pkl *.pkl
venv/
.idea/
*.egg-info/

View File

@ -13,7 +13,7 @@ class LandlordLstmModel(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True) self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(860 + 128, 1024) self.dense1 = nn.Linear(887 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024) self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768) self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512) self.dense4 = nn.Linear(768, 512)
@ -48,7 +48,7 @@ class FarmerLstmModel(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True) self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(1192 + 128, 1024) self.dense1 = nn.Linear(1219 + 128, 1024)
self.dense2 = nn.Linear(1024, 1024) self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(1024, 768) self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(768, 512) self.dense4 = nn.Linear(768, 512)
@ -79,16 +79,16 @@ class FarmerLstmModel(nn.Module):
action = torch.argmax(x,dim=0)[0] action = torch.argmax(x,dim=0)[0]
return dict(action=action) return dict(action=action)
class LandlordLstmNewModel(nn.Module): class LandlordLstmModelLegacy(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True) self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(846 + 128, 512) self.dense1 = nn.Linear(860 + 128, 1024)
self.dense2 = nn.Linear(512, 512) self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(512, 512) self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(512, 512) self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 512) self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(512, 1) self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None): def forward(self, z, x, return_value=False, flags=None):
lstm_out, (h_n, _) = self.lstm(z) lstm_out, (h_n, _) = self.lstm(z)
@ -114,16 +114,16 @@ class LandlordLstmNewModel(nn.Module):
action = torch.argmax(x,dim=0)[0] action = torch.argmax(x,dim=0)[0]
return dict(action=action) return dict(action=action)
class FarmerLstmNewModel(nn.Module): class FarmerLstmModelLegacy(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.lstm = nn.LSTM(432, 128, batch_first=True) self.lstm = nn.LSTM(432, 128, batch_first=True)
self.dense1 = nn.Linear(1178 + 128, 512) self.dense1 = nn.Linear(1192 + 128, 1024)
self.dense2 = nn.Linear(512, 512) self.dense2 = nn.Linear(1024, 1024)
self.dense3 = nn.Linear(512, 512) self.dense3 = nn.Linear(1024, 768)
self.dense4 = nn.Linear(512, 512) self.dense4 = nn.Linear(768, 512)
self.dense5 = nn.Linear(512, 512) self.dense5 = nn.Linear(512, 256)
self.dense6 = nn.Linear(512, 1) self.dense6 = nn.Linear(256, 1)
def forward(self, z, x, return_value=False, flags=None): def forward(self, z, x, return_value=False, flags=None):
lstm_out, (h_n, _) = self.lstm(z) lstm_out, (h_n, _) = self.lstm(z)
@ -249,11 +249,10 @@ class BasicBlock(nn.Module):
return out return out
class GeneralModel(nn.Module): class GeneralModelLegacy(nn.Module):
def __init__(self, use_legacy = False): def __init__(self):
super().__init__() super().__init__()
self.in_planes = 80 self.in_planes = 80
self.use_legacy = use_legacy
#input 1*108*41 #input 1*108*41
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,), self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*108*80 stride=(2,), padding=1, bias=False) #1*108*80
@ -265,10 +264,7 @@ class GeneralModel(nn.Module):
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320 self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640 self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
if self.use_legacy: self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 24 * 4, 2048)
else:
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
self.linear2 = nn.Linear(2048, 1024) self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512) self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256) self.linear4 = nn.Linear(512, 256)
@ -289,10 +285,58 @@ class GeneralModel(nn.Module):
out = self.layer3(out) out = self.layer3(out)
out = self.layer4(out) out = self.layer4(out)
out = out.flatten(1,2) out = out.flatten(1,2)
if self.use_legacy: out = torch.cat([x,x,x,x,out], dim=-1)
out = torch.cat([x,x,x,x,out], dim=-1) out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
if return_value:
return dict(values=out)
else: else:
out = torch.cat([x,out], dim=-1) if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(out.shape[0], (1,))[0]
else:
action = torch.argmax(out,dim=0)[0]
return dict(action=action, max_value=torch.max(out))
class GeneralModel(nn.Module):
def __init__(self):
super().__init__()
self.in_planes = 80
#input 1*108*41
self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*108*80
self.bn1 = nn.BatchNorm1d(80)
self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*27*80
self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*14*160
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*7*320
self.layer4 = self._make_layer(BasicBlock, 640, 2, stride=2)#1*4*640
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear1 = nn.Linear(640 * BasicBlock.expansion * 4 + 80, 2048)
self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256)
self.linear5 = nn.Linear(256, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, z, x, return_value=False, flags=None, debug=False):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.flatten(1,2)
out = torch.cat([x,out], dim=-1)
out = F.leaky_relu_(self.linear1(out)) out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out)) out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out)) out = F.leaky_relu_(self.linear3(out))
@ -355,6 +399,17 @@ model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_front'] = FarmerLstmModel model_dict['landlord_front'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel model_dict['landlord_down'] = FarmerLstmModel
model_dict_legacy = {}
model_dict_legacy['landlord'] = LandlordLstmModelLegacy
model_dict_legacy['landlord_up'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_front'] = FarmerLstmModelLegacy
model_dict_legacy['landlord_down'] = FarmerLstmModelLegacy
model_dict_new_legacy = {}
model_dict_new_legacy['landlord'] = GeneralModelLegacy
model_dict_new_legacy['landlord_up'] = GeneralModelLegacy
model_dict_new_legacy['landlord_front'] = GeneralModelLegacy
model_dict_new_legacy['landlord_down'] = GeneralModelLegacy
model_dict_new_legacy['bidding'] = BidModel
model_dict_new = {} model_dict_new = {}
model_dict_new['landlord'] = GeneralModel model_dict_new['landlord'] = GeneralModel
model_dict_new['landlord_up'] = GeneralModel model_dict_new['landlord_up'] = GeneralModel
@ -458,15 +513,15 @@ class Model:
The wrapper for the three models. We also wrap several The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc. interfaces such as share_memory, eval, etc.
""" """
def __init__(self, device=0, use_legacy = False): def __init__(self, device=0):
self.models = {} self.models = {}
if not device == "cpu": if not device == "cpu":
device = 'cuda:' + str(device) device = 'cuda:' + str(device)
# model = GeneralModel().to(torch.device(device)) # model = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel(use_legacy).to(torch.device(device)) self.models['landlord'] = GeneralModel().to(torch.device(device))
self.models['landlord_up'] = GeneralModel(use_legacy).to(torch.device(device)) self.models['landlord_up'] = GeneralModel().to(torch.device(device))
self.models['landlord_front'] = GeneralModel(use_legacy).to(torch.device(device)) self.models['landlord_front'] = GeneralModel().to(torch.device(device))
self.models['landlord_down'] = GeneralModel(use_legacy).to(torch.device(device)) self.models['landlord_down'] = GeneralModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device)) self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False): def forward(self, position, z, x, training=False, flags=None, debug=False):

View File

@ -3,13 +3,19 @@ import numpy as np
from douzero.env.env import get_obs from douzero.env.env import get_obs
def _load_model(position, model_path, model_type): def _load_model(position, model_path, model_type, use_legacy):
from douzero.dmc.models import model_dict_new, model_dict from douzero.dmc.models import model_dict_new, model_dict, model_dict_new_legacy, model_dict_legacy
model = None model = None
if model_type == "general": if model_type == "general":
model = model_dict_new[position]() if use_legacy:
model = model_dict_new_legacy[position]()
else:
model = model_dict_new[position]()
else: else:
model = model_dict[position]() if use_legacy:
model = model_dict_legacy[position]()
else:
model = model_dict[position]()
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
if torch.cuda.is_available(): if torch.cuda.is_available():
pretrained = torch.load(model_path, map_location='cuda:0') pretrained = torch.load(model_path, map_location='cuda:0')
@ -27,8 +33,9 @@ def _load_model(position, model_path, model_type):
class DeepAgent: class DeepAgent:
def __init__(self, position, model_path): def __init__(self, position, model_path):
self.use_legacy = True if "legacy" in model_path else False
self.model_type = "general" if "resnet" in model_path else "old" self.model_type = "general" if "resnet" in model_path else "old"
self.model = _load_model(position, model_path, self.model_type) self.model = _load_model(position, model_path, self.model_type, self.use_legacy)
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'} 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
@ -36,7 +43,7 @@ class DeepAgent:
if len(infoset.legal_actions) == 1: if len(infoset.legal_actions) == 1:
return infoset.legal_actions[0] return infoset.legal_actions[0]
obs = get_obs(infoset, self.model_type == "general") obs = get_obs(infoset, self.model_type == "general", self.use_legacy)
z_batch = torch.from_numpy(obs['z_batch']).float() z_batch = torch.from_numpy(obs['z_batch']).float()
x_batch = torch.from_numpy(obs['x_batch']).float() x_batch = torch.from_numpy(obs['x_batch']).float()

View File

@ -4,7 +4,7 @@ import argparse
from douzero.evaluation.simulation import evaluate from douzero.evaluation.simulation import evaluate
def make_evaluate(args, t, frame, adp_frame, folder_a = 'baselines', folder_b = 'baselines'): def make_evaluate(args, t, frame, adp_frame, folder_a = 'baselines', folder_b = 'baselines', prefix_a = '', prefix_b = ''):
if t == 0: if t == 0:
args.landlord = 'random' args.landlord = 'random'
args.landlord_up = 'random' args.landlord_up = 'random'
@ -24,16 +24,16 @@ def make_evaluate(args, t, frame, adp_frame, folder_a = 'baselines', folder_b =
args.landlord_down = '%s/resnet_landlord_down_%i.ckpt' % (folder_a, frame) args.landlord_down = '%s/resnet_landlord_down_%i.ckpt' % (folder_a, frame)
print('random vs %i' % frame) print('random vs %i' % frame)
elif t == 3: elif t == 3:
args.landlord = '%s/resnet_landlord_%i.ckpt' % (folder_a, frame) args.landlord = '%s/%sresnet_landlord_%i.ckpt' % (folder_a, prefix_a, frame)
args.landlord_up = '%s/resnet_landlord_up_%i.ckpt' % (folder_b, adp_frame) args.landlord_up = '%s/%sresnet_landlord_up_%i.ckpt' % (folder_b, prefix_b, adp_frame)
args.landlord_front = '%s/resnet_landlord_front_%i.ckpt' % (folder_b, adp_frame) args.landlord_front = '%s/%sresnet_landlord_front_%i.ckpt' % (folder_b, prefix_b, adp_frame)
args.landlord_down = '%s/resnet_landlord_down_%i.ckpt' % (folder_b, adp_frame) args.landlord_down = '%s/%sresnet_landlord_down_%i.ckpt' % (folder_b, prefix_b, adp_frame)
print('%i vs %i' % (frame, adp_frame)) print('%i vs %i' % (frame, adp_frame))
elif t == 4: elif t == 4:
args.landlord = '%s/resnet_landlord_%i.ckpt' % (folder_b, adp_frame) args.landlord = '%s/%sresnet_landlord_%i.ckpt' % (folder_b, prefix_b, adp_frame)
args.landlord_up = '%s/resnet_landlord_up_%i.ckpt' % (folder_a, frame) args.landlord_up = '%s/%sresnet_landlord_up_%i.ckpt' % (folder_a, prefix_a, frame)
args.landlord_front = '%s/resnet_landlord_front_%i.ckpt' % (folder_a, frame) args.landlord_front = '%s/%sresnet_landlord_front_%i.ckpt' % (folder_a, prefix_a, frame)
args.landlord_down = '%s/resnet_landlord_down_%i.ckpt' % (folder_a, frame) args.landlord_down = '%s/%sresnet_landlord_down_%i.ckpt' % (folder_a, prefix_a, frame)
print('%i vs %i' % (adp_frame, frame)) print('%i vs %i' % (adp_frame, frame))
evaluate(args.landlord, evaluate(args.landlord,
@ -108,8 +108,10 @@ if __name__ == '__main__':
# [14102400, 4968800, 'baselines', 'baselines'], # [14102400, 4968800, 'baselines', 'baselines'],
# [14102400, 13252000, 'baselines', 'baselines2'], # [14102400, 13252000, 'baselines', 'baselines2'],
# [14102400, 15096800, 'baselines', 'baselines2'], # [14102400, 15096800, 'baselines', 'baselines2'],
[34828000, 40132800, 'baselines2', 'baselines2'], # [34828000, 40132800, 'baselines2', 'baselines2'],
# [14102400, None, 'baselines', 'baselines'], # [14102400, None, 'baselines', 'baselines'],
[19918400, 19918400, 'baselines', 'baselines', 'legacy_', 'legacy_'],
[9161600, 19918400, 'baselines', 'baselines', '', 'legacy_'],
] ]
for vs in reversed(eval_list): for vs in reversed(eval_list):
@ -119,11 +121,11 @@ if __name__ == '__main__':
folder_b = vs[3] folder_b = vs[3]
if adp_frame is None: if adp_frame is None:
if frame is None: if frame is None:
make_evaluate(args, 0, None, None) make_evaluate(args, 0, None, None, folder_a , folder_b, vs[4], vs[5])
else: else:
make_evaluate(args, 1, frame, None) make_evaluate(args, 1, frame, None, folder_a , folder_b, vs[4], vs[5])
make_evaluate(args, 2, frame, None) make_evaluate(args, 2, frame, None, folder_a , folder_b, vs[4], vs[5])
else: else:
make_evaluate(args, 3, frame, adp_frame, folder_a , folder_b) make_evaluate(args, 3, frame, adp_frame, folder_a , folder_b, vs[4], vs[5])
if frame != adp_frame: if frame != adp_frame:
make_evaluate(args, 4, frame, adp_frame, folder_a, folder_b) make_evaluate(args, 4, frame, adp_frame, folder_a, folder_b, vs[4], vs[5])