76 lines
3.0 KiB
Python
76 lines
3.0 KiB
Python
from douzero.server.orm import Model, Battle, Baseline
|
|
from douzero.server.battle import tick
|
|
from flask import Flask, jsonify, request
|
|
from flask_cors import CORS
|
|
from datetime import datetime
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
# Flask application object; CORS(app) applies flask-cors to every route.
app = Flask(__name__)
CORS(app)

# Create the backing tables at import time, in the same order as before.
for _table in (Model, Battle, Baseline):
    _table.create_table()

# Seat names accepted by the /upload endpoint.
positions = [
    'landlord',
    'landlord_up',
    'landlord_front',
    'landlord_down',
]

# Single worker: ticks submitted through this pool run one at a time.
threadpool = ThreadPoolExecutor(max_workers=1)
|
|
|
|
@app.route('/upload', methods=['POST'])
def upload():
    """Register an uploaded model checkpoint and queue it for evaluation.

    Form fields:
        type: one of ``lite_resnet``, ``lite_vanilla``, ``legacy_vanilla``.
        position: one of the module-level ``positions``.
        frame: integer, used as part of the stored file name.
        model_file: the checkpoint file (multipart upload).

    Returns:
        A JSON object. ``status`` is 0 on success; otherwise it is negative
        and ``message`` names the rejected field (-1 type, -2 position,
        -3 model_file, -4 frame). Uploading a checkpoint whose path is
        already registered is reported as success without storing it again.
    """
    model_type = request.form.get('type')
    if model_type not in ('lite_resnet', 'lite_vanilla', 'legacy_vanilla'):
        return jsonify({'status': -1, 'message': 'illegal type'})

    position = request.form.get("position")
    if position not in positions:
        return jsonify({'status': -2, 'message': 'illegal position'})

    # A missing or non-numeric frame used to raise out of int() and surface
    # as an HTTP 500; report it like the other validation failures instead.
    try:
        frame = int(request.form.get("frame"))
    except (TypeError, ValueError):
        return jsonify({'status': -4, 'message': 'illegal frame'})

    model_file = request.files.get('model_file')
    if model_file is None:
        return jsonify({'status': -3, 'message': 'illegal model_file'})

    # type and position are whitelisted above and frame is an int, so the
    # file name is built only from validated components.
    path = "baselines_server/%s_%s_%d.ckpt" % (model_type, position, frame)

    # NOTE(review): everything below is assumed to apply only to checkpoints
    # that are not registered yet -- confirm against the intended behavior.
    if Model.get_or_none(Model.path == path) is None:
        model_file.save(path)
        Model.create(
            path=path,
            position=position,
            type=model_type,
            frame=frame,
            create_time=datetime.now(),
        )
        Battle.create(challenger_path=path, challenger_position=position, status=0)
        # Run the battle tick in the background so the upload returns promptly.
        threadpool.submit(tick)

    return jsonify({'status': 0, 'message': 'success', 'result': ''})
|
|
|
|
@app.route('/battle_tick', methods=['POST'])
def battle_tick():
    """Run one battle tick on demand and report success.

    ``tick`` is called directly in the request handler (not through the
    thread pool), so the response is sent only after it returns.
    """
    tick()
    response = {'status': 0, 'message': 'success', 'result': ''}
    return jsonify(response)
|
|
|
|
@app.route('/metrics', methods=['GET'])
def metrics():
    """Report battle results grouped by the three highest-ranked baselines.

    Query parameters:
        type: model type to filter on (compared against ``Model.type``).

    Each baseline owns the time window that starts at its ``create_time`` and
    ends at the ``create_time`` of the previously visited baseline (or now,
    for the first one). Models of the requested type created inside that
    window are listed under the baseline's rank.
    NOTE(review): this assumes a higher rank means a more recently created
    baseline -- confirm.

    Returns:
        JSON ``{'status': 0, 'message': 'success', 'result': {...}}`` where
        ``result`` maps baseline rank -> model frame -> a dict with
        ``position``, ``wp`` and ``adp`` (the latter two formatted as strings
        with four decimals).
    """
    model_type = request.args.get('type')
    top_baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(3)

    window_end = datetime.now()
    result = {}
    for baseline in top_baselines:
        # The upper bound was previously written as ``>= end_time``, which
        # matched nothing for the first baseline (end_time was "now") and
        # made the lower bound redundant afterwards; it must be ``<=``.
        models = Model.select(Model.frame, Model.path).where(
            Model.type == model_type,
            Model.create_time >= baseline.create_time,
            Model.create_time <= window_end,
        ).order_by(Model.create_time.asc())
        # The next (lower-ranked) baseline's window ends where this one starts.
        window_end = baseline.create_time

        per_frame = {}
        for model in models:
            # Only battles with status > 0 and != 3 are reported; presumably
            # these are the finished, valid ones -- verify against the
            # status codes used by the battle module.
            battle = Battle.get_or_none(
                Battle.challenger_path == model.path,
                Battle.status > 0,
                Battle.status != 3,
            )
            if battle is None:
                continue
            per_frame[model.frame] = {
                'position': battle.challenger_position,
                'wp': '%.4f' % float(battle.challenger_wp),
                'adp': '%.4f' % float(battle.challenger_adp),
            }
        result[baseline.rank] = per_frame

    return jsonify({'status': 0, 'message': 'success', 'result': result})
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Command-line entry point; --debug is the only option.
    cli = argparse.ArgumentParser(description='DouZero evaluation backend')
    cli.add_argument('--debug', action='store_true')
    options = cli.parse_args()

    # Listens on all interfaces.
    # NOTE(review): with --debug this also exposes Flask's debug mode to the
    # network -- confirm it is only used on trusted hosts.
    app.run(host="0.0.0.0", debug=options.debug)
|