diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 872b1d7..6df37a5 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -275,7 +275,7 @@ def train(flags): if flags.old_model: type += 'vanilla' elif flags.unified_model: - type += 'unified' + type += 'unified_v2' else: type += 'resnet' try: diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index a6e5f65..919431a 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -415,13 +415,13 @@ class UnifiedModelLite(nn.Module): self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60 self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120 # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.lstm = nn.LSTM(276, 128, batch_first=True) + self.lstm = nn.LSTM(276, 224, batch_first=True) - self.linear1 = nn.Linear(120 * BasicBlock.expansion * 5 + 128, 2048) + self.linear1 = nn.Linear((120 * BasicBlock.expansion * 5 + 224) * 2, 2048) self.linear2 = nn.Linear(2048, 1024) - self.linear3 = nn.Linear(1024, 512) - self.linear4 = nn.Linear(512, 256) - self.linear5 = nn.Linear(256, 1) + self.linear3 = nn.Linear(1024, 1024) + self.linear4 = nn.Linear(1024, 512) + self.linear5 = nn.Linear(512, 1) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) @@ -458,9 +458,11 @@ class UnifiedModelLite(nn.Module): out = self.layer2(out) out = self.layer3(out) out = out.flatten(1,2) + is_landlord = z[0][0][0] lstm_out, (h_n, _) = self.lstm(x) lstm_out = lstm_out[:,-1,:] out = torch.cat([lstm_out,out], dim=1) + out = torch.cat([out * is_landlord, out * (1 - is_landlord)], dim=1) out = F.leaky_relu_(self.linear1(out)) out = F.leaky_relu_(self.linear2(out)) out = F.leaky_relu_(self.linear3(out)) diff --git a/douzero/server/battle.py b/douzero/server/battle.py index 27cb8a7..180b3bd 100644 --- a/douzero/server/battle.py +++ b/douzero/server/battle.py @@ -47,7 +47,7 @@ def battle_logic(flags, baseline : Baseline, battle : Battle): challenger_baseline['landlord_front_path'], challenger_baseline['landlord_down_path'], eval_data_first, - 2, + 4, False, 'New') def _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp): @@ -57,7 +57,7 @@ def battle_logic(flags, baseline : Baseline, battle : Battle): challenger_baseline['landlord_front_path'], challenger_baseline['landlord_down_path'], eval_data_second, - 2, + 4, False, 'New') return (landlord_wp + landlord_wp_2 * 4.0) / 5, \ diff --git a/evaluate_server.py b/evaluate_server.py index 5d771c4..294a594 100644 --- a/evaluate_server.py +++ b/evaluate_server.py @@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, @app.route('/upload', methods=['POST']) def upload(): type = request.form.get('type') - if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']: + if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified', 'lite_unified_v2']: return jsonify({'status': -1, 'message': 'illegal type'}) position = request.form.get("position") if position != 'uni' and position not in positions: