新增evaluation server管理baseline
This commit is contained in:
parent
3b43a6bc6f
commit
15753ac45e
|
@ -41,6 +41,10 @@ parser.add_argument('--enable_onnx', action='store_true',
|
|||
help='Use onnx model for train')
|
||||
parser.add_argument('--onnx_model_path', default='douzero_checkpoints',
|
||||
help='Root dir where onnx temp model will be saved')
|
||||
parser.add_argument('--enable_upload', action='store_true',
|
||||
help='Should the cpkt model will be upload to server')
|
||||
parser.add_argument('--upload_url', default='https://dou.zaneyork.cn:8443/model/upload',
|
||||
help='The cpkt model will be upload to')
|
||||
|
||||
# Hyperparameters
|
||||
parser.add_argument('--total_frames', default=100000000000, type=int,
|
||||
|
|
|
@ -17,6 +17,7 @@ from .models import Model, OldModel
|
|||
from .utils import get_batch, log, create_env, create_optimizers, act
|
||||
import psutil
|
||||
import shutil
|
||||
import requests
|
||||
|
||||
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}
|
||||
|
||||
|
@ -244,6 +245,21 @@ def train(flags):
|
|||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||
if flags.enable_upload:
|
||||
if flags.lite_model:
|
||||
type = 'lite_'
|
||||
else:
|
||||
type = ''
|
||||
if flags.old_model:
|
||||
type += 'vanilla'
|
||||
else:
|
||||
type += 'resnet'
|
||||
requests.post(flags.upload_url, data={
|
||||
'type': type,
|
||||
'position': position,
|
||||
'frame': frames
|
||||
}, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))})
|
||||
os.remove(model_weights_dir)
|
||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||
|
||||
fps_log = []
|
||||
|
@ -297,6 +313,7 @@ def train(flags):
|
|||
proc['actor'] = actor
|
||||
|
||||
except KeyboardInterrupt:
|
||||
flags.enable_upload = False
|
||||
checkpoint(frames)
|
||||
return
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
import os
|
||||
|
||||
from .orm import Baseline, Battle, Model
|
||||
from evaluate import evaluate
|
||||
from datetime import datetime
|
||||
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
|
||||
def battle_logic(baseline : Baseline, battle : Battle):
    """Play a challenger model against the given baseline and record the result.

    The challenger replaces the baseline model at ``battle.challenger_position``
    and the resulting four-model line-up is passed to ``evaluate``. A first run
    uses ``eval_data_200.pkl``; only if the challenger's side beats the
    baseline's stored win value by more than 0.15 is a second run on
    ``eval_data_800.pkl`` performed, and the two results are blended with
    weights 1:4 before the same threshold is applied again.

    On success a new ``Baseline`` row is created with ``rank`` one higher than
    the given baseline. On failure the challenger checkpoint (and its ``.onnx``
    sibling, if present) is deleted from disk. In both cases the battle row is
    updated with the challenger's metrics and a status of 1 (success) or
    2 (failure).

    Args:
        baseline: Baseline row the challenger has to beat.
        battle: Pending battle row describing the challenger.
    """
    # NOTE(review): the *_wp / *_adp values are whatever ``evaluate`` returns;
    # presumably win percentage and average score difference -- confirm.
    seats = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
    challenger_baseline = {}
    for seat in seats:
        challenger_baseline[seat + '_path'] = str(getattr(baseline, seat + '_path'))
    # Swap the challenger in for the seat it is competing for.
    challenger_baseline[battle.challenger_position + "_path"] = str(battle.challenger_path)

    as_landlord = battle.challenger_position == 'landlord'

    def _run_eval(eval_data):
        # One evaluation of the current line-up on the given data file.
        return evaluate(challenger_baseline['landlord_path'],
                        challenger_baseline['landlord_up_path'],
                        challenger_baseline['landlord_front_path'],
                        challenger_baseline['landlord_down_path'],
                        eval_data,
                        1,
                        False,
                        'New')

    def _beats_baseline(l_wp, f_wp):
        # The challenger's own side must exceed the stored value by the margin.
        if as_landlord:
            return l_wp - 0.15 > baseline.landlord_wp
        return f_wp - 0.15 > baseline.farmer_wp

    landlord_wp, farmer_wp, landlord_adp, farmer_adp = _run_eval('eval_data_200.pkl')

    challenge_success = False
    if _beats_baseline(landlord_wp, farmer_wp):
        # Promising first result: confirm on the second data set and blend 1:4.
        landlord_wp_2, farmer_wp_2, landlord_adp_2, farmer_adp_2 = \
            _run_eval('eval_data_800.pkl')
        landlord_wp = (landlord_wp + landlord_wp_2 * 4.0) / 5
        farmer_wp = (farmer_wp + farmer_wp_2 * 4.0) / 5
        landlord_adp = (landlord_adp + landlord_adp_2 * 4.0) / 5
        farmer_adp = (farmer_adp + farmer_adp_2 * 4.0) / 5
        if _beats_baseline(landlord_wp, farmer_wp):
            challenge_success = True

    if challenge_success:
        # Promote the challenger line-up to a new, higher-ranked baseline.
        challenger_baseline['rank'] = baseline.rank + 1
        challenger_baseline['landlord_wp'] = landlord_wp
        challenger_baseline['landlord_adp'] = landlord_adp
        challenger_baseline['farmer_wp'] = farmer_wp
        challenger_baseline['farmer_adp'] = farmer_adp
        challenger_baseline['create_time'] = datetime.now()
        Baseline.create(**challenger_baseline)
    else:
        # Losing challengers are removed from disk.
        # NOTE(review): indentation was reconstructed; confirm the checkpoint
        # removal is meant to be unconditional within this branch.
        challenger_file = str(battle.challenger_path)
        onnx_path = challenger_file + '.onnx'
        if os.path.exists(onnx_path):
            os.remove(onnx_path)
        os.remove(challenger_file)

    if as_landlord:
        battle.challenger_wp = landlord_wp
        battle.challenger_adp = landlord_adp
    else:
        battle.challenger_wp = farmer_wp
        battle.challenger_adp = farmer_adp
    battle.status = 1 if challenge_success else 2
    battle.save()
|
||||
|
||||
def tick():
    """Run every pending battle (status 0), highest battle id first.

    For each battle the highest-ranked baseline is looked up and handed to
    ``battle_logic``. When no baseline exists yet, an initial one (rank 0, all
    metrics 0) is assembled from the most recently created model of each
    position; if some position has no model at all, the battle is marked with
    status 3 instead of being played.
    """
    pending = Battle.select().where(Battle.status == 0).order_by(Battle.id.desc())
    for battle in pending:
        top = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
        if len(top) != 0:
            battle_logic(top[0], battle)
            continue

        # No baseline yet: try to seed one from the newest model per position.
        seed = {}
        for position in positions:
            newest = (Model.select()
                      .where(Model.position == position)
                      .order_by(Model.create_time.desc())
                      .limit(1))
            if len(newest) > 0:
                seed['%s_path' % position] = newest[0].path

        if len(seed) != 4:
            # At least one position has no model; the battle cannot be played.
            battle.status = 3
            battle.save()
            continue

        seed['rank'] = 0
        seed['landlord_wp'] = 0
        seed['farmer_wp'] = 0
        seed['landlord_adp'] = 0
        seed['farmer_adp'] = 0
        seed['create_time'] = datetime.now()
        Baseline.create(**seed)
        # Re-read so battle_logic receives the stored row.
        top = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
        battle_logic(top[0], battle)
|
|
@ -0,0 +1,44 @@
|
|||
# -*- coding:utf8 -*-
|
||||
from peewee import *
|
||||
db = SqliteDatabase('model.db')
|
||||
|
||||
class BaseModel(Model):
    """Shared base class that binds every table in this module to ``db``."""

    class Meta:
        database = db
|
||||
|
||||
class Model(BaseModel):
    """One uploaded checkpoint, keyed by its file path.

    Note: this class deliberately reuses the name ``Model`` and therefore
    shadows peewee's ``Model`` for everything defined after it in this module.
    """

    # Location of the checkpoint file; doubles as the primary key.
    path = CharField(primary_key=True)
    # Seat the checkpoint plays.
    position = CharField(null=False)
    # Model type label supplied at upload time.
    type = CharField(null=False)
    # Frame count supplied at upload time.
    frame = IntegerField(null=False)
    create_time = DateTimeField(null=False)

    class Meta:
        order_by = ('type',)
        db_table = 'model'
|
||||
|
||||
class Battle(BaseModel):
    """A challenge of one uploaded checkpoint against the current baseline."""

    id = PrimaryKeyField()
    # Checkpoint file of the challenger and the seat it competes for.
    challenger_path = CharField(null=False)
    challenger_position = CharField(null=False)
    # 0 = pending, 1 = challenger won, 2 = challenger lost,
    # 3 = no baseline could be assembled for the battle.
    status = IntegerField(null=False)
    # Challenger metrics; empty until the battle has been evaluated.
    challenger_wp = DecimalField(null=True)
    challenger_adp = DecimalField(null=True)

    class Meta:
        order_by = ('id',)
        db_table = 'battle'
|
||||
|
||||
class Baseline(BaseModel):
    """A ranked line-up of four checkpoints, one per seat, with its metrics.

    Every ``*_path`` column is a foreign key to ``Model.path``. Each foreign
    key declares its own ``related_name``: sharing one name across several
    foreign keys to the same model makes the reverse accessors collide
    (rejected or silently overwritten, depending on the peewee version).
    """

    id = PrimaryKeyField()
    landlord_path = ForeignKeyField(
        Model, to_field='path', related_name='landlord_baselines')
    landlord_up_path = ForeignKeyField(
        Model, to_field='path', related_name='landlord_up_baselines')
    landlord_front_path = ForeignKeyField(
        Model, to_field='path', related_name='landlord_front_baselines')
    landlord_down_path = ForeignKeyField(
        Model, to_field='path', related_name='landlord_down_baselines')
    # Higher rank means a newer baseline; the highest rank is the current one.
    rank = IntegerField(null=False)
    # Metrics recorded when this line-up became the baseline.
    landlord_wp = DecimalField(null=False)
    farmer_wp = DecimalField(null=False)
    landlord_adp = DecimalField(null=False)
    farmer_adp = DecimalField(null=False)
    create_time = DateTimeField(null=False)

    class Meta:
        order_by = ('id',)
        db_table = 'baseline'
|
|
@ -0,0 +1,46 @@
|
|||
from douzero.evaluation.simulation import evaluate
|
||||
|
||||
from douzero.server.orm import Model, Battle, Baseline
|
||||
from douzero.server.battle import tick
|
||||
from flask import Flask, jsonify, request
|
||||
from flask_cors import CORS
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# Flask application with cross-origin requests enabled.
app = Flask(__name__)
CORS(app)

# Make sure the backing tables exist as soon as the module is imported.
Model.create_table()
Battle.create_table()
Baseline.create_table()

positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']

# A single worker thread, so submitted battle runs execute one after another.
threadpool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@app.route('/upload', methods=['POST'])
def upload():
    """Accept an uploaded checkpoint and queue a battle for it.

    Expected form fields: ``type``, ``position``, ``frame`` and the file part
    ``model_file``. The checkpoint is stored under ``baselines_server/`` with a
    name derived from those fields; if a model with that path is not yet
    registered, the file is saved, ``Model`` and ``Battle`` rows are created
    and a ``tick`` run is submitted to the worker pool.

    Returns:
        A JSON response whose ``status`` is 0 on success, -1 for an unknown
        type, -2 for an unknown position, -3 for a missing file and -4 for a
        missing or non-integer frame.
    """
    # NOTE(review): the trainer-side upload code builds 'resnet' / 'vanilla'
    # for non-lite models, which this whitelist rejects -- confirm intended.
    model_type = request.form.get('type')
    if model_type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']:
        return jsonify({'status': -1, 'message': 'illegal type'})
    position = request.form.get("position")
    if position not in positions:
        return jsonify({'status': -2, 'message': 'illegal position'})
    # The field comes from the client: a missing or non-numeric value must be
    # reported like the other validation failures instead of raising.
    try:
        frame = int(request.form.get("frame"))
    except (TypeError, ValueError):
        return jsonify({'status': -4, 'message': 'illegal frame'})
    model_file = request.files.get('model_file')
    if model_file is None:
        return jsonify({'status': -3, 'message': 'illegal model_file'})
    # Every component is validated above, so the path stays in the target dir.
    path = "baselines_server/%s_%s_%d.ckpt" % (model_type, position, frame)
    model = Model.get_or_none(Model.path == path)
    if model is None:
        model_file.save(path)
        Model.create(path=path, position=position, type=model_type, frame=frame,
                     create_time=datetime.now())
        Battle.create(challenger_path=path, challenger_position=position, status=0)
        # Evaluate in the background so the request returns immediately.
        threadpool.submit(tick)
    return jsonify({'status': 0, 'message': 'success', 'result': ''})
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: serve on all interfaces; --debug enables Flask debug mode.
    import argparse

    parser = argparse.ArgumentParser(description='DouZero evaluation backend')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    app.run(host="0.0.0.0", debug=args.debug)
|
|
@ -0,0 +1,3 @@
|
|||
flask==1.1
|
||||
flask-cors
|
||||
peewee
|
|
@ -5,3 +5,4 @@ rlcard
|
|||
psutil
|
||||
onnx
|
||||
onnxruntime-gpu==1.7
|
||||
requests
|
||||
|
|
Loading…
Reference in New Issue