From d68d39718c03f639db3c67d891a39ca601185cf4 Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Mon, 27 Dec 2021 15:43:01 +0800 Subject: [PATCH] =?UTF-8?q?baseline=E6=9B=B4=E6=96=B0=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/server/battle.py | 20 +++++++++++++++++++- evaluate_server.py | 19 ++----------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/douzero/server/battle.py b/douzero/server/battle.py index c34340c..70c720d 100644 --- a/douzero/server/battle.py +++ b/douzero/server/battle.py @@ -2,10 +2,25 @@ import os import traceback from .orm import Baseline, Battle, Model -from evaluate import evaluate +from ..evaluation.simulation import evaluate +from ..evaluation.deep_agent import DeepAgent from datetime import datetime positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] +idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'} +position_idx = {'landlord': 0, 'landlord_down': 1, 'landlord_front': 2, 'landlord_up': 3} + +baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1) +baseline_players = [None, None, None, None] +if len(baselines) >= 1: + baseline = baselines[0] + try: + baseline_players[0] = DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True) + baseline_players[1] = DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True) + baseline_players[2] = DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True) + baseline_players[3] = DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True) + except: + pass def battle_logic(baseline : Baseline, battle : Battle): eval_data_first = 'eval_data_200.pkl' @@ -77,6 +92,8 @@ def battle_logic(baseline : Baseline, battle : Battle): if os.path.exists(onnx_path): os.remove(onnx_path) os.remove(str(battle.challenger_path)) + else: + baseline_players[position_idx[battle.challenger_position]] = DeepAgent(battle.challenger_position, str(battle.challenger_path), use_onnx=True) def tick(): try: @@ -91,6 +108,7 @@ def tick(): models = Model.select().where(Model.position == position).order_by(Model.create_time.desc()).limit(1) if(len(models) > 0): baseline['%s_path' % position] = models[0].path + baseline_players[position_idx[position]] = DeepAgent(position, str(models[0].path), use_onnx=True) if len(baseline.keys()) == 4: baseline['rank'] = 0 baseline['landlord_wp'] = 0 diff --git a/evaluate_server.py b/evaluate_server.py index 4460880..68f3159 100644 --- a/evaluate_server.py +++ b/evaluate_server.py @@ -4,14 +4,13 @@ import os from peewee import JOIN from douzero.server.orm import Model, Battle, Baseline -from douzero.server.battle import tick +from douzero.server.battle import tick, baseline_players, positions, idx_position from flask import Flask, jsonify, request from flask_cors import CORS from datetime import datetime from concurrent.futures import ThreadPoolExecutor from douzero.env.game import get_legal_card_play_actions -from douzero.evaluation.deep_agent import DeepAgent app = Flask(__name__) CORS(app) @@ -19,8 +18,6 @@ CORS(app) Model.create_table() Battle.create_table() Baseline.create_table() -positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] -idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'} threadpool = ThreadPoolExecutor(1) @@ -33,18 +30,6 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, 'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30} -baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1) -if len(baselines) >= 1: - baseline = baselines[0] - try: - players = [ - DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True), - DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True), - DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True), - DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True) - ] - except: - pass @app.route('/upload', methods=['POST']) def upload(): @@ -211,7 +196,7 @@ def predict(): info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move) # Prediction - actions, actions_confidence = players[player_position].act(info_set, True) + actions, actions_confidence = baseline_players[player_position].act(info_set, True) actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions] result = {} win_rates = {}