调整模型

This commit is contained in:
ZaneYork 2022-01-08 16:09:34 +08:00
parent b55f470db9
commit 390393edbd
4 changed files with 11 additions and 9 deletions

View File

@ -275,7 +275,7 @@ def train(flags):
if flags.old_model:
type += 'vanilla'
elif flags.unified_model:
type += 'unified'
type += 'unified_v2'
else:
type += 'resnet'
try:

View File

@ -415,13 +415,13 @@ class UnifiedModelLite(nn.Module):
self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60
self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.lstm = nn.LSTM(276, 128, batch_first=True)
self.lstm = nn.LSTM(276, 224, batch_first=True)
self.linear1 = nn.Linear(120 * BasicBlock.expansion * 5 + 128, 2048)
self.linear1 = nn.Linear((120 * BasicBlock.expansion * 5 + 224) * 2, 2048)
self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256)
self.linear5 = nn.Linear(256, 1)
self.linear3 = nn.Linear(1024, 1024)
self.linear4 = nn.Linear(1024, 512)
self.linear5 = nn.Linear(512, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
@ -458,9 +458,11 @@ class UnifiedModelLite(nn.Module):
out = self.layer2(out)
out = self.layer3(out)
out = out.flatten(1,2)
is_landlord = z[0][0][0]
lstm_out, (h_n, _) = self.lstm(x)
lstm_out = lstm_out[:,-1,:]
out = torch.cat([lstm_out,out], dim=1)
out = torch.cat([out * is_landlord, out * (1 - is_landlord)], dim=1)
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))

View File

@ -47,7 +47,7 @@ def battle_logic(flags, baseline : Baseline, battle : Battle):
challenger_baseline['landlord_front_path'],
challenger_baseline['landlord_down_path'],
eval_data_first,
2,
4,
False,
'New')
def _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp):
@ -57,7 +57,7 @@ def battle_logic(flags, baseline : Baseline, battle : Battle):
challenger_baseline['landlord_front_path'],
challenger_baseline['landlord_down_path'],
eval_data_second,
2,
4,
False,
'New')
return (landlord_wp + landlord_wp_2 * 4.0) / 5, \

View File

@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
@app.route('/upload', methods=['POST'])
def upload():
type = request.form.get('type')
if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']:
if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified', 'lite_unified_v2']:
return jsonify({'status': -1, 'message': 'illegal type'})
position = request.form.get("position")
if position != 'uni' and position not in positions: