# Douzero_Resnet/evaluate_server.py (80 lines, 3.1 KiB, Python)

import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

from flask import Flask, jsonify, request
from flask_cors import CORS

from douzero.server.orm import Model, Battle, Baseline
from douzero.server.battle import tick
# Flask application serving the evaluation API. CORS is enabled so a browser
# front-end served from another origin can call these endpoints.
app = Flask(__name__)
CORS(app)

# Make sure the ORM tables exist before any request touches them.
# NOTE(review): presumably a no-op for tables that already exist (peewee's
# create_table defaults to IF NOT EXISTS) — confirm against douzero.server.orm.
for _table in (Model, Battle, Baseline):
    _table.create_table()
del _table

# Seat names accepted by the /upload endpoint.
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']

# A single worker thread, so background tick() jobs run one at a time in the
# order they were submitted.
threadpool = ThreadPoolExecutor(max_workers=1)
def _log_tick_failure(future):
    """Done-callback for background ``tick`` jobs: log any exception they raised.

    ``threadpool.submit`` stores a job's exception on its Future. Because
    ``upload`` does not keep that Future, the error would otherwise never be
    reported anywhere.
    """
    if future.cancelled():
        return
    error = future.exception()
    if error is not None:
        app.logger.error('background battle tick failed', exc_info=error)


@app.route('/upload', methods=['POST'])
def upload():
    """Register an uploaded checkpoint as a challenger and queue a battle tick.

    Form fields:
        type: model family; one of 'lite_resnet', 'lite_vanilla', 'legacy_vanilla'.
        position: seat the checkpoint plays; must be one of ``positions``.
        frame: integer frame number, embedded in the stored file name.
    Files:
        model_file: the checkpoint payload.

    Returns:
        A JSON envelope ``{'status', 'message'[, 'result']}``. Status 0 means
        accepted. Negative codes report a rejected field: -1 type, -2 position,
        -3 model_file, -4 frame.

    Re-uploading a path that is already registered does not overwrite the file
    or create new rows. It still reports success and still queues a tick.
    """
    model_type = request.form.get('type')
    if model_type not in ('lite_resnet', 'lite_vanilla', 'legacy_vanilla'):
        return jsonify({'status': -1, 'message': 'illegal type'})
    position = request.form.get('position')
    if position not in positions:
        return jsonify({'status': -2, 'message': 'illegal position'})
    # A missing (None) or non-integer frame used to raise here and surface as an
    # HTTP 500. Report it through the same JSON error envelope as other fields.
    try:
        frame = int(request.form.get('frame'))
    except (TypeError, ValueError):
        return jsonify({'status': -4, 'message': 'illegal frame'})
    model_file = request.files.get('model_file')
    if model_file is None:
        return jsonify({'status': -3, 'message': 'illegal model_file'})

    # Every component is whitelisted or an int, so the client cannot steer the
    # write outside save_dir. Keep the forward-slash form: this string is stored
    # as Model.path / Battle.challenger_path and later matched by equality.
    save_dir = 'baselines_server'
    path = '%s/%s_%s_%d.ckpt' % (save_dir, model_type, position, frame)
    if Model.get_or_none(Model.path == path) is None:
        # save() fails if the directory is missing (e.g. on a fresh checkout).
        os.makedirs(save_dir, exist_ok=True)
        model_file.save(path)
        Model.create(path=path, position=position, type=model_type,
                     frame=frame, create_time=datetime.now())
        Battle.create(challenger_path=path, challenger_position=position, status=0)

    # Queue the tick on the single-worker pool and surface any failure in the log.
    threadpool.submit(tick).add_done_callback(_log_tick_failure)
    return jsonify({'status': 0, 'message': 'success', 'result': ''})
@app.route('/battle_tick', methods=['POST'])
def battle_tick():
    """Run one battle tick synchronously in the request thread, then acknowledge.

    ``/upload`` queues ``tick`` on the background ``threadpool``. This endpoint
    instead calls it directly and blocks until it returns. Any exception it
    raises propagates to Flask.

    NOTE(review): this call bypasses the single-worker pool, so it may overlap a
    queued tick — confirm ``tick`` tolerates concurrent invocation.
    """
    tick()
    response_body = dict(status=0, message='success', result='')
    return jsonify(response_body)
@app.route('/metrics', methods=['GET'])
def metrics():
    """Report challenger wp/adp for models of one type, grouped by baseline.

    Query args:
        type: model family, compared against ``Model.type``.

    Steps:
        1. Take the (at most) three baselines with the highest ``rank``, visited
           highest first.
        2. For each baseline, collect the models of the requested type whose
           ``create_time`` lies between that baseline's ``create_time`` and an
           upper bound. The upper bound is "now" for the first baseline, then
           the previously visited baseline's ``create_time``.
        3. For each such model that has a qualifying Battle, record the
           challenger's ``wp`` and ``adp`` under the battle's challenger
           position, keyed by the model's frame.

    Returns:
        ``{'status': 0, 'message': 'success', 'result': {rank: {position:
        {frame: {'wp': str, 'adp': str}}}}}``, with wp/adp formatted to four
        decimals.
    """
    model_type = request.args.get('type')
    # NOTE(review): a missing ``type`` is passed through as None; the filter then
    # presumably matches no models and every seat map comes back empty — confirm.
    baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(3)
    # Upper bound of the create_time window. It starts at "now" and slides back
    # to each visited baseline's create_time.
    window_end = datetime.now()
    result = {}
    for baseline in baselines:
        # One frame -> {wp, adp} map per seat (the same seat names /upload accepts).
        baseline_metric = {position: {} for position in positions}
        models = Model.select(Model.frame, Model.path).where(
            Model.type == model_type,
            Model.create_time >= baseline.create_time,
            Model.create_time <= window_end,
        ).order_by(Model.create_time.asc())
        # The bound above is captured in the query expression. This only moves
        # the window for the next (lower-ranked) baseline.
        window_end = baseline.create_time
        for model in models:
            # NOTE(review): Battle status codes are defined elsewhere; > 0 and
            # != 3 is assumed to mean "finished with usable results" — confirm.
            battle = Battle.get_or_none(
                Battle.challenger_path == model.path,
                Battle.status > 0,
                Battle.status != 3,
            )
            if battle is None:
                continue
            baseline_metric[str(battle.challenger_position)][model.frame] = {
                'wp': '%.4f' % float(battle.challenger_wp),
                'adp': '%.4f' % float(battle.challenger_adp),
            }
        result[baseline.rank] = baseline_metric
    return jsonify({'status': 0, 'message': 'success', 'result': result})
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='DouZero evaluation backend')
    parser.add_argument('--debug', action='store_true',
                        help='run Flask in debug mode')
    parser.add_argument('--host', default='0.0.0.0',
                        help='interface to bind (default: 0.0.0.0, all interfaces)')
    parser.add_argument('--port', type=int, default=None,
                        help="port to listen on (default: Flask's own default)")
    args = parser.parse_args()
    # WARNING: debug mode enables the Werkzeug interactive debugger, which can
    # execute arbitrary code. Combined with a public --host, it is reachable from
    # the network, so use --host 127.0.0.1 when debugging.
    app.run(debug=args.debug, host=args.host, port=args.port)