baseline更新逻辑调整

This commit is contained in:
zhiyang7 2021-12-27 15:43:01 +08:00
parent 627e89ef31
commit d68d39718c
2 changed files with 21 additions and 18 deletions

View File

@ -2,10 +2,25 @@ import os
import traceback
from .orm import Baseline, Battle, Model
from evaluate import evaluate
from ..evaluation.simulation import evaluate
from ..evaluation.deep_agent import DeepAgent
from datetime import datetime
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'}
position_idx = {'landlord': 0, 'landlord_down': 1, 'landlord_front': 2, 'landlord_up': 3}
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
baseline_players = [None, None, None, None]
if len(baselines) >= 1:
baseline = baselines[0]
try:
baseline_players[0] = DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True)
baseline_players[1] = DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True)
baseline_players[2] = DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True)
baseline_players[3] = DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True)
except:
pass
def battle_logic(baseline : Baseline, battle : Battle):
eval_data_first = 'eval_data_200.pkl'
@ -77,6 +92,8 @@ def battle_logic(baseline : Baseline, battle : Battle):
if os.path.exists(onnx_path):
os.remove(onnx_path)
os.remove(str(battle.challenger_path))
else:
baseline_players[position_idx[battle.challenger_position]] = DeepAgent(battle.challenger_position, str(battle.challenger_path), use_onnx=True)
def tick():
try:
@ -91,6 +108,7 @@ def tick():
models = Model.select().where(Model.position == position).order_by(Model.create_time.desc()).limit(1)
if(len(models) > 0):
baseline['%s_path' % position] = models[0].path
baseline_players[position_idx[position]] = DeepAgent(position, str(models[0].path), use_onnx=True)
if len(baseline.keys()) == 4:
baseline['rank'] = 0
baseline['landlord_wp'] = 0

View File

@ -4,14 +4,13 @@ import os
from peewee import JOIN
from douzero.server.orm import Model, Battle, Baseline
from douzero.server.battle import tick
from douzero.server.battle import tick, baseline_players, positions, idx_position
from flask import Flask, jsonify, request
from flask_cors import CORS
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from douzero.env.game import get_legal_card_play_actions
from douzero.evaluation.deep_agent import DeepAgent
app = Flask(__name__)
CORS(app)
@ -19,8 +18,6 @@ CORS(app)
Model.create_table()
Battle.create_table()
Baseline.create_table()
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'}
threadpool = ThreadPoolExecutor(1)
@ -33,18 +30,6 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
if len(baselines) >= 1:
baseline = baselines[0]
try:
players = [
DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True),
DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True),
DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True),
DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True)
]
except:
pass
@app.route('/upload', methods=['POST'])
def upload():
@ -211,7 +196,7 @@ def predict():
info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
# Prediction
actions, actions_confidence = players[player_position].act(info_set, True)
actions, actions_confidence = baseline_players[player_position].act(info_set, True)
actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions]
result = {}
win_rates = {}