新增evaluation server管理baseline

This commit is contained in:
ZaneYork 2021-12-25 19:06:34 +08:00
parent 3b43a6bc6f
commit 15753ac45e
8 changed files with 214 additions and 0 deletions

View File

@ -41,6 +41,10 @@ parser.add_argument('--enable_onnx', action='store_true',
help='Use onnx model for train')
parser.add_argument('--onnx_model_path', default='douzero_checkpoints',
help='Root dir where onnx temp model will be saved')
parser.add_argument('--enable_upload', action='store_true',
help='Should the cpkt model will be upload to server')
parser.add_argument('--upload_url', default='https://dou.zaneyork.cn:8443/model/upload',
help='The cpkt model will be upload to')
# Hyperparameters
parser.add_argument('--total_frames', default=100000000000, type=int,

View File

@ -17,6 +17,7 @@ from .models import Model, OldModel
from .utils import get_batch, log, create_env, create_optimizers, act
import psutil
import shutil
import requests
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}
@ -244,6 +245,21 @@ def train(flags):
model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
if flags.enable_upload:
if flags.lite_model:
type = 'lite_'
else:
type = ''
if flags.old_model:
type += 'vanilla'
else:
type += 'resnet'
requests.post(flags.upload_url, data={
'type': type,
'position': position,
'frame': frames
}, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))})
os.remove(model_weights_dir)
shutil.move(checkpointpath + '.new', checkpointpath)
fps_log = []
@ -297,6 +313,7 @@ def train(flags):
proc['actor'] = actor
except KeyboardInterrupt:
flags.enable_upload = False
checkpoint(frames)
return
else:

View File

99
douzero/server/battle.py Normal file
View File

@ -0,0 +1,99 @@
import os
from .orm import Baseline, Battle, Model
from evaluate import evaluate
from datetime import datetime
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
def battle_logic(baseline : Baseline, battle : Battle):
eval_data_first = 'eval_data_200.pkl'
eval_data_second = 'eval_data_800.pkl'
challenger_baseline = {
'landlord_path': str(baseline.landlord_path),
'landlord_up_path': str(baseline.landlord_up_path),
'landlord_front_path': str(baseline.landlord_front_path),
'landlord_down_path': str(baseline.landlord_down_path),
}
challenger_baseline[battle.challenger_position + "_path"] = str(battle.challenger_path)
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
evaluate(challenger_baseline['landlord_path'],
challenger_baseline['landlord_up_path'],
challenger_baseline['landlord_front_path'],
challenger_baseline['landlord_down_path'],
eval_data_first,
1,
False,
'New')
def _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp):
landlord_wp_2, farmer_wp_2, landlord_adp_2, farmer_adp_2 = \
evaluate(challenger_baseline['landlord_path'],
challenger_baseline['landlord_up_path'],
challenger_baseline['landlord_front_path'],
challenger_baseline['landlord_down_path'],
eval_data_second,
1,
False,
'New')
return (landlord_wp + landlord_wp_2 * 4.0) / 5, \
(farmer_wp + farmer_wp_2 * 4.0) / 5, \
(landlord_adp + landlord_adp_2 * 4.0) / 5, \
(farmer_adp + farmer_adp_2 * 4.0) / 5
challenge_success = False
if battle.challenger_position == 'landlord':
if landlord_wp - 0.15 > baseline.landlord_wp:
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
_second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
if landlord_wp - 0.15 > baseline.landlord_wp:
challenge_success = True
else:
if farmer_wp - 0.15 > baseline.farmer_wp:
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
_second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
if farmer_wp - 0.15 > baseline.farmer_wp:
challenge_success = True
if challenge_success:
challenger_baseline['rank'] = baseline.rank + 1
challenger_baseline['landlord_wp'] = landlord_wp
challenger_baseline['landlord_adp'] = landlord_adp
challenger_baseline['farmer_wp'] = farmer_wp
challenger_baseline['farmer_adp'] = farmer_adp
challenger_baseline['create_time'] = datetime.now()
Baseline.create(**challenger_baseline)
else:
onnx_path = str(battle.challenger_path) + '.onnx'
if os.path.exists(onnx_path):
os.remove(onnx_path)
os.remove(str(battle.challenger_path))
battle.challenger_wp = landlord_wp if battle.challenger_position == 'landlord' else farmer_wp
battle.challenger_adp = landlord_adp if battle.challenger_position == 'landlord' else farmer_adp
battle.status = 1 if challenge_success else 2
battle.save()
def tick():
battles = Battle.select().where(Battle.status == 0).order_by(Battle.id.desc())
for battle in battles:
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
if len(baselines) == 0:
baseline = {}
for position in positions:
models = Model.select().where(Model.position == position).order_by(Model.create_time.desc()).limit(1)
if(len(models) > 0):
baseline['%s_path' % position] = models[0].path
if len(baseline.keys()) == 4:
baseline['rank'] = 0
baseline['landlord_wp'] = 0
baseline['farmer_wp'] = 0
baseline['landlord_adp'] = 0
baseline['farmer_adp'] = 0
baseline['create_time'] = datetime.now()
Baseline.create(**baseline)
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
battle_logic(baselines[0], battle)
else:
battle.status = 3
battle.save()
else:
battle_logic(baselines[0], battle)

44
douzero/server/orm.py Normal file
View File

@ -0,0 +1,44 @@
# -*- coding:utf8 -*-
from peewee import *
db = SqliteDatabase('model.db')
class BaseModel(Model):
class Meta:
database = db
class Model(BaseModel):
path = CharField(primary_key=True)
position = CharField(null = False)
type = CharField(null = False)
frame = IntegerField(null = False)
create_time = DateTimeField(null=False)
class Meta:
order_by = ('type',)
db_table = 'model'
class Battle(BaseModel):
id = PrimaryKeyField()
challenger_path = CharField(null = False)
challenger_position = CharField(null = False)
status = IntegerField(null = False)
challenger_wp = DecimalField(null=True)
challenger_adp = DecimalField(null=True)
class Meta:
order_by = ('id',)
db_table = 'battle'
class Baseline(BaseModel):
id = PrimaryKeyField()
landlord_path = ForeignKeyField(Model, to_field='path',related_name = "model")
landlord_up_path = ForeignKeyField(Model, to_field='path',related_name = "model")
landlord_front_path = ForeignKeyField(Model, to_field='path',related_name = "model")
landlord_down_path = ForeignKeyField(Model, to_field='path',related_name = "model")
rank = IntegerField(null = False)
landlord_wp = DecimalField(null=False)
farmer_wp = DecimalField(null=False)
landlord_adp = DecimalField(null=False)
farmer_adp = DecimalField(null=False)
create_time = DateTimeField(null=False)
class Meta:
order_by = ('id',)
db_table = 'baseline'

46
evaluate_server.py Normal file
View File

@ -0,0 +1,46 @@
from douzero.evaluation.simulation import evaluate
from douzero.server.orm import Model, Battle, Baseline
from douzero.server.battle import tick
from flask import Flask, jsonify, request
from flask_cors import CORS
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
app = Flask(__name__)
CORS(app)
Model.create_table()
Battle.create_table()
Baseline.create_table()
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
threadpool = ThreadPoolExecutor(1)
@app.route('/upload', methods=['POST'])
def upload():
type = request.form.get('type')
if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']:
return jsonify({'status': -1, 'message': 'illegal type'})
position = request.form.get("position")
if position not in positions:
return jsonify({'status': -2, 'message': 'illegal position'})
frame = int(request.form.get("frame"))
model_file = request.files.get('model_file')
if model_file is None:
return jsonify({'status': -3, 'message': 'illegal model_file'})
path = "baselines_server/%s_%s_%d.ckpt" % (type, position, frame)
model = Model.get_or_none(Model.path == path)
if model is None:
model_file.save(path)
Model.create(path=path, position=position,type=type,frame=frame,create_time=datetime.now())
Battle.create(challenger_path=path, challenger_position=position, status=0)
threadpool.submit(tick)
return jsonify({'status': 0, 'message': 'success', 'result': ''})
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='DouZero evaluation backend')
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
app.run(debug=args.debug, host="0.0.0.0")

3
requirements-server.txt Normal file
View File

@ -0,0 +1,3 @@
flask==1.1
flask-cors
peewee

View File

@ -5,3 +5,4 @@ rlcard
psutil
onnx
onnxruntime-gpu==1.7
requests