调整模型
This commit is contained in:
parent
b55f470db9
commit
390393edbd
|
@ -275,7 +275,7 @@ def train(flags):
|
||||||
if flags.old_model:
|
if flags.old_model:
|
||||||
type += 'vanilla'
|
type += 'vanilla'
|
||||||
elif flags.unified_model:
|
elif flags.unified_model:
|
||||||
type += 'unified'
|
type += 'unified_v2'
|
||||||
else:
|
else:
|
||||||
type += 'resnet'
|
type += 'resnet'
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -415,13 +415,13 @@ class UnifiedModelLite(nn.Module):
|
||||||
self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60
|
self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60
|
||||||
self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120
|
self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120
|
||||||
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||||
self.lstm = nn.LSTM(276, 128, batch_first=True)
|
self.lstm = nn.LSTM(276, 224, batch_first=True)
|
||||||
|
|
||||||
self.linear1 = nn.Linear(120 * BasicBlock.expansion * 5 + 128, 2048)
|
self.linear1 = nn.Linear((120 * BasicBlock.expansion * 5 + 224) * 2, 2048)
|
||||||
self.linear2 = nn.Linear(2048, 1024)
|
self.linear2 = nn.Linear(2048, 1024)
|
||||||
self.linear3 = nn.Linear(1024, 512)
|
self.linear3 = nn.Linear(1024, 1024)
|
||||||
self.linear4 = nn.Linear(512, 256)
|
self.linear4 = nn.Linear(1024, 512)
|
||||||
self.linear5 = nn.Linear(256, 1)
|
self.linear5 = nn.Linear(512, 1)
|
||||||
|
|
||||||
def _make_layer(self, block, planes, num_blocks, stride):
|
def _make_layer(self, block, planes, num_blocks, stride):
|
||||||
strides = [stride] + [1] * (num_blocks - 1)
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
@ -458,9 +458,11 @@ class UnifiedModelLite(nn.Module):
|
||||||
out = self.layer2(out)
|
out = self.layer2(out)
|
||||||
out = self.layer3(out)
|
out = self.layer3(out)
|
||||||
out = out.flatten(1,2)
|
out = out.flatten(1,2)
|
||||||
|
is_landlord = z[0][0][0]
|
||||||
lstm_out, (h_n, _) = self.lstm(x)
|
lstm_out, (h_n, _) = self.lstm(x)
|
||||||
lstm_out = lstm_out[:,-1,:]
|
lstm_out = lstm_out[:,-1,:]
|
||||||
out = torch.cat([lstm_out,out], dim=1)
|
out = torch.cat([lstm_out,out], dim=1)
|
||||||
|
out = torch.cat([out * is_landlord, out * (1 - is_landlord)], dim=1)
|
||||||
out = F.leaky_relu_(self.linear1(out))
|
out = F.leaky_relu_(self.linear1(out))
|
||||||
out = F.leaky_relu_(self.linear2(out))
|
out = F.leaky_relu_(self.linear2(out))
|
||||||
out = F.leaky_relu_(self.linear3(out))
|
out = F.leaky_relu_(self.linear3(out))
|
||||||
|
|
|
@ -47,7 +47,7 @@ def battle_logic(flags, baseline : Baseline, battle : Battle):
|
||||||
challenger_baseline['landlord_front_path'],
|
challenger_baseline['landlord_front_path'],
|
||||||
challenger_baseline['landlord_down_path'],
|
challenger_baseline['landlord_down_path'],
|
||||||
eval_data_first,
|
eval_data_first,
|
||||||
2,
|
4,
|
||||||
False,
|
False,
|
||||||
'New')
|
'New')
|
||||||
def _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp):
|
def _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp):
|
||||||
|
@ -57,7 +57,7 @@ def battle_logic(flags, baseline : Baseline, battle : Battle):
|
||||||
challenger_baseline['landlord_front_path'],
|
challenger_baseline['landlord_front_path'],
|
||||||
challenger_baseline['landlord_down_path'],
|
challenger_baseline['landlord_down_path'],
|
||||||
eval_data_second,
|
eval_data_second,
|
||||||
2,
|
4,
|
||||||
False,
|
False,
|
||||||
'New')
|
'New')
|
||||||
return (landlord_wp + landlord_wp_2 * 4.0) / 5, \
|
return (landlord_wp + landlord_wp_2 * 4.0) / 5, \
|
||||||
|
|
|
@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||||
@app.route('/upload', methods=['POST'])
|
@app.route('/upload', methods=['POST'])
|
||||||
def upload():
|
def upload():
|
||||||
type = request.form.get('type')
|
type = request.form.get('type')
|
||||||
if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']:
|
if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified', 'lite_unified_v2']:
|
||||||
return jsonify({'status': -1, 'message': 'illegal type'})
|
return jsonify({'status': -1, 'message': 'illegal type'})
|
||||||
position = request.form.get("position")
|
position = request.form.get("position")
|
||||||
if position != 'uni' and position not in positions:
|
if position != 'uni' and position not in positions:
|
||||||
|
|
Loading…
Reference in New Issue