baseline更新逻辑调整

This commit is contained in:
zhiyang7 2021-12-27 15:43:01 +08:00
parent 627e89ef31
commit d68d39718c
2 changed files with 21 additions and 18 deletions

View File

@ -2,10 +2,25 @@ import os
import traceback import traceback
from .orm import Baseline, Battle, Model from .orm import Baseline, Battle, Model
from evaluate import evaluate from ..evaluation.simulation import evaluate
from ..evaluation.deep_agent import DeepAgent
from datetime import datetime from datetime import datetime
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'}
position_idx = {'landlord': 0, 'landlord_down': 1, 'landlord_front': 2, 'landlord_up': 3}
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
baseline_players = [None, None, None, None]
if len(baselines) >= 1:
baseline = baselines[0]
try:
baseline_players[0] = DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True)
baseline_players[1] = DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True)
baseline_players[2] = DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True)
baseline_players[3] = DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True)
except:
pass
def battle_logic(baseline : Baseline, battle : Battle): def battle_logic(baseline : Baseline, battle : Battle):
eval_data_first = 'eval_data_200.pkl' eval_data_first = 'eval_data_200.pkl'
@ -77,6 +92,8 @@ def battle_logic(baseline : Baseline, battle : Battle):
if os.path.exists(onnx_path): if os.path.exists(onnx_path):
os.remove(onnx_path) os.remove(onnx_path)
os.remove(str(battle.challenger_path)) os.remove(str(battle.challenger_path))
else:
baseline_players[position_idx[battle.challenger_position]] = DeepAgent(battle.challenger_position, str(battle.challenger_path), use_onnx=True)
def tick(): def tick():
try: try:
@ -91,6 +108,7 @@ def tick():
models = Model.select().where(Model.position == position).order_by(Model.create_time.desc()).limit(1) models = Model.select().where(Model.position == position).order_by(Model.create_time.desc()).limit(1)
if(len(models) > 0): if(len(models) > 0):
baseline['%s_path' % position] = models[0].path baseline['%s_path' % position] = models[0].path
baseline_players[position_idx[position]] = DeepAgent(position, str(models[0].path), use_onnx=True)
if len(baseline.keys()) == 4: if len(baseline.keys()) == 4:
baseline['rank'] = 0 baseline['rank'] = 0
baseline['landlord_wp'] = 0 baseline['landlord_wp'] = 0

View File

@ -4,14 +4,13 @@ import os
from peewee import JOIN from peewee import JOIN
from douzero.server.orm import Model, Battle, Baseline from douzero.server.orm import Model, Battle, Baseline
from douzero.server.battle import tick from douzero.server.battle import tick, baseline_players, positions, idx_position
from flask import Flask, jsonify, request from flask import Flask, jsonify, request
from flask_cors import CORS from flask_cors import CORS
from datetime import datetime from datetime import datetime
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from douzero.env.game import get_legal_card_play_actions from douzero.env.game import get_legal_card_play_actions
from douzero.evaluation.deep_agent import DeepAgent
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
@ -19,8 +18,6 @@ CORS(app)
Model.create_table() Model.create_table()
Battle.create_table() Battle.create_table()
Baseline.create_table() Baseline.create_table()
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'}
threadpool = ThreadPoolExecutor(1) threadpool = ThreadPoolExecutor(1)
@ -33,18 +30,6 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30} 'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
if len(baselines) >= 1:
baseline = baselines[0]
try:
players = [
DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True),
DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True),
DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True),
DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True)
]
except:
pass
@app.route('/upload', methods=['POST']) @app.route('/upload', methods=['POST'])
def upload(): def upload():
@ -211,7 +196,7 @@ def predict():
info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move) info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
# Prediction # Prediction
actions, actions_confidence = players[player_position].act(info_set, True) actions, actions_confidence = baseline_players[player_position].act(info_set, True)
actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions] actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions]
result = {} result = {}
win_rates = {} win_rates = {}