新增evaluation server管理baseline
This commit is contained in:
parent
3b43a6bc6f
commit
15753ac45e
|
@ -41,6 +41,10 @@ parser.add_argument('--enable_onnx', action='store_true',
|
||||||
help='Use onnx model for train')
|
help='Use onnx model for train')
|
||||||
parser.add_argument('--onnx_model_path', default='douzero_checkpoints',
|
parser.add_argument('--onnx_model_path', default='douzero_checkpoints',
|
||||||
help='Root dir where onnx temp model will be saved')
|
help='Root dir where onnx temp model will be saved')
|
||||||
|
parser.add_argument('--enable_upload', action='store_true',
|
||||||
|
help='Should the cpkt model will be upload to server')
|
||||||
|
parser.add_argument('--upload_url', default='https://dou.zaneyork.cn:8443/model/upload',
|
||||||
|
help='The cpkt model will be upload to')
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
parser.add_argument('--total_frames', default=100000000000, type=int,
|
parser.add_argument('--total_frames', default=100000000000, type=int,
|
||||||
|
|
|
@ -17,6 +17,7 @@ from .models import Model, OldModel
|
||||||
from .utils import get_batch, log, create_env, create_optimizers, act
|
from .utils import get_batch, log, create_env, create_optimizers, act
|
||||||
import psutil
|
import psutil
|
||||||
import shutil
|
import shutil
|
||||||
|
import requests
|
||||||
|
|
||||||
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}
|
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}
|
||||||
|
|
||||||
|
@ -244,6 +245,21 @@ def train(flags):
|
||||||
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
model_weights_dir = os.path.expandvars(os.path.expanduser(
|
||||||
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
|
||||||
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
|
||||||
|
if flags.enable_upload:
|
||||||
|
if flags.lite_model:
|
||||||
|
type = 'lite_'
|
||||||
|
else:
|
||||||
|
type = ''
|
||||||
|
if flags.old_model:
|
||||||
|
type += 'vanilla'
|
||||||
|
else:
|
||||||
|
type += 'resnet'
|
||||||
|
requests.post(flags.upload_url, data={
|
||||||
|
'type': type,
|
||||||
|
'position': position,
|
||||||
|
'frame': frames
|
||||||
|
}, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))})
|
||||||
|
os.remove(model_weights_dir)
|
||||||
shutil.move(checkpointpath + '.new', checkpointpath)
|
shutil.move(checkpointpath + '.new', checkpointpath)
|
||||||
|
|
||||||
fps_log = []
|
fps_log = []
|
||||||
|
@ -297,6 +313,7 @@ def train(flags):
|
||||||
proc['actor'] = actor
|
proc['actor'] = actor
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
flags.enable_upload = False
|
||||||
checkpoint(frames)
|
checkpoint(frames)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .orm import Baseline, Battle, Model
|
||||||
|
from evaluate import evaluate
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Seat names; tick() iterates over them to build the '<seat>_path' keys of the first baseline.
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
|
|
||||||
|
def battle_logic(baseline : Baseline, battle : Battle):
    """Play one pending challenger against ``baseline`` and persist the verdict.

    The challenger checkpoint replaces the baseline's checkpoint at
    ``battle.challenger_position``; the other three seats keep the baseline's
    checkpoints. A first evaluation runs on ``eval_data_200.pkl``. Only when
    the challenger's side exceeds the baseline's stored ``*_wp`` value by more
    than 0.15 is a second evaluation run on ``eval_data_800.pkl``; both results
    are then blended with weights 1:4 and the same margin is checked again.

    On success a new ``Baseline`` row with ``rank + 1`` is created from the
    mixed line-up. On failure the challenger checkpoint (and its ``.onnx``
    sibling, if one exists) is removed from disk. In both cases the battle row
    receives the challenger side's ``wp`` / ``adp`` figures and a status of
    1 (challenge succeeded) or 2 (challenge failed).

    Args:
        baseline: The baseline row to compete against.
        battle: The pending battle row describing the challenger.
    """
    seats = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
    lineup = {seat + '_path': str(getattr(baseline, seat + '_path')) for seat in seats}
    # Swap the challenger into its seat.
    lineup[battle.challenger_position + "_path"] = str(battle.challenger_path)

    def run_eval(eval_data):
        # Same positional call for both rounds; only the data file differs.
        return evaluate(lineup['landlord_path'],
                        lineup['landlord_up_path'],
                        lineup['landlord_front_path'],
                        lineup['landlord_down_path'],
                        eval_data,
                        1,
                        False,
                        'New')

    def clears_margin(landlord_score, farmer_score):
        # Compare only the side the challenger actually plays on.
        if battle.challenger_position == 'landlord':
            return landlord_score - 0.15 > baseline.landlord_wp
        return farmer_score - 0.15 > baseline.farmer_wp

    landlord_wp, farmer_wp, landlord_adp, farmer_adp = run_eval('eval_data_200.pkl')

    challenge_success = False
    if clears_margin(landlord_wp, farmer_wp):
        landlord_wp_2, farmer_wp_2, landlord_adp_2, farmer_adp_2 = run_eval('eval_data_800.pkl')
        # Blend the two rounds, weighting the second one four times as much.
        landlord_wp = (landlord_wp + landlord_wp_2 * 4.0) / 5
        farmer_wp = (farmer_wp + farmer_wp_2 * 4.0) / 5
        landlord_adp = (landlord_adp + landlord_adp_2 * 4.0) / 5
        farmer_adp = (farmer_adp + farmer_adp_2 * 4.0) / 5
        if clears_margin(landlord_wp, farmer_wp):
            challenge_success = True

    if challenge_success:
        # Promote the mixed line-up to a new baseline one rank higher.
        lineup.update(
            rank=baseline.rank + 1,
            landlord_wp=landlord_wp,
            landlord_adp=landlord_adp,
            farmer_wp=farmer_wp,
            farmer_adp=farmer_adp,
            create_time=datetime.now(),
        )
        Baseline.create(**lineup)
    else:
        # Losing challengers are deleted from disk.
        challenger_file = str(battle.challenger_path)
        onnx_path = challenger_file + '.onnx'
        if os.path.exists(onnx_path):
            os.remove(onnx_path)
        os.remove(challenger_file)

    challenger_is_landlord = battle.challenger_position == 'landlord'
    battle.challenger_wp = landlord_wp if challenger_is_landlord else farmer_wp
    battle.challenger_adp = landlord_adp if challenger_is_landlord else farmer_adp
    battle.status = 1 if challenge_success else 2
    battle.save()
|
||||||
|
|
||||||
|
def tick():
    """Process every battle whose status is 0, highest id first.

    For each pending battle the highest-ranked baseline is looked up and the
    battle is handed to ``battle_logic``. When no baseline exists yet, one is
    seeded at rank 0 from the most recently created model of each position;
    if fewer than four positions have a model, the battle is marked with
    status 3 instead of being played.
    """
    pending = Battle.select().where(Battle.status == 0).order_by(Battle.id.desc())
    for battle in pending:
        top = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
        if len(top) != 0:
            battle_logic(top[0], battle)
            continue

        # No baseline yet: collect the newest model for every position.
        seed = {}
        for position in positions:
            latest = Model.select().where(Model.position == position).order_by(Model.create_time.desc()).limit(1)
            if len(latest) > 0:
                seed['%s_path' % position] = latest[0].path

        if len(seed) != 4:
            # At least one position has no model; this battle cannot be played.
            battle.status = 3
            battle.save()
            continue

        seed.update(
            rank=0,
            landlord_wp=0,
            farmer_wp=0,
            landlord_adp=0,
            farmer_adp=0,
            create_time=datetime.now(),
        )
        Baseline.create(**seed)
        top = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
        battle_logic(top[0], battle)
|
|
@ -0,0 +1,44 @@
|
||||||
|
# -*- coding:utf8 -*-
|
||||||
|
from peewee import *
|
||||||
|
# SQLite database (file 'model.db', a relative path) that BaseModel.Meta binds the tables to.
db = SqliteDatabase('model.db')
|
||||||
|
|
||||||
|
# Shared base class for the tables in this module. `Model` at this point is
# the name provided by `from peewee import *`; the project's own `Model`
# table is only defined further down.
class BaseModel(Model):
    class Meta:
        # Bind every subclass to the module-level database handle.
        database = db
|
||||||
|
|
||||||
|
class Model(BaseModel):
    """Row of the ``model`` table, keyed by the ``path`` column.

    Note that this class rebinds the module-level name ``Model`` that
    ``from peewee import *`` provided.
    """

    path = CharField(primary_key=True)
    position = CharField(null=False)
    type = CharField(null=False)
    frame = IntegerField(null=False)
    create_time = DateTimeField(null=False)

    class Meta:
        db_table = 'model'
        order_by = ('type',)
|
||||||
|
|
||||||
|
class Battle(BaseModel):
    """Row of the ``battle`` table: one challenger checkpoint and its result.

    ``challenger_wp`` and ``challenger_adp`` are nullable, so a row can be
    created before any result is known.
    """

    id = PrimaryKeyField()
    challenger_path = CharField(null=False)
    challenger_position = CharField(null=False)
    status = IntegerField(null=False)
    challenger_wp = DecimalField(null=True)
    challenger_adp = DecimalField(null=True)

    class Meta:
        db_table = 'battle'
        order_by = ('id',)
|
||||||
|
|
||||||
|
class Baseline(BaseModel):
    """Row of the ``baseline`` table: four checkpoints plus their scores.

    Each ``*_path`` column is a foreign key to ``Model.path`` (``Model`` here
    is the checkpoint table defined in this module). ``rank`` orders the
    baselines, and the ``*_wp`` / ``*_adp`` columns hold the figures stored
    with each one.
    """

    id = PrimaryKeyField()
    # Every foreign key needs its own reverse-accessor name on Model. Sharing
    # one name across all four makes the accessors collide, which peewee
    # either rejects or resolves by letting the last declaration win.
    landlord_path = ForeignKeyField(Model, to_field='path', related_name="baselines_as_landlord")
    landlord_up_path = ForeignKeyField(Model, to_field='path', related_name="baselines_as_landlord_up")
    landlord_front_path = ForeignKeyField(Model, to_field='path', related_name="baselines_as_landlord_front")
    landlord_down_path = ForeignKeyField(Model, to_field='path', related_name="baselines_as_landlord_down")
    rank = IntegerField(null=False)
    landlord_wp = DecimalField(null=False)
    farmer_wp = DecimalField(null=False)
    landlord_adp = DecimalField(null=False)
    farmer_adp = DecimalField(null=False)
    create_time = DateTimeField(null=False)

    class Meta:
        order_by = ('id',)
        db_table = 'baseline'
|
|
@ -0,0 +1,46 @@
|
||||||
|
from douzero.evaluation.simulation import evaluate
|
||||||
|
|
||||||
|
from douzero.server.orm import Model, Battle, Baseline
|
||||||
|
from douzero.server.battle import tick
|
||||||
|
from flask import Flask, jsonify, request
|
||||||
|
from flask_cors import CORS
|
||||||
|
from datetime import datetime
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
# Flask application, wrapped with flask-cors.
app = Flask(__name__)
CORS(app)

# Create the three tables at import time.
for _table in (Model, Battle, Baseline):
    _table.create_table()

# Position names accepted by the upload endpoint.
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']

# Executor with a single worker; upload() submits tick() jobs to it.
threadpool = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
|
@app.route('/upload', methods=['POST'])
def upload():
    """Accept a checkpoint upload, register it, and queue it for evaluation.

    Reads the form fields ``type``, ``position`` and ``frame`` and the file
    part ``model_file``. The checkpoint is stored as
    ``baselines_server/<type>_<position>_<frame>.ckpt``. If no ``Model`` row
    with that path exists yet, the file is saved, a ``Model`` row and a
    pending ``Battle`` row (status 0) are created, and ``tick`` is submitted
    to the worker pool.

    Returns:
        A JSON response whose ``status`` is
        0 on success (also when the checkpoint was already registered),
        -1 for an unknown ``type``,
        -2 for an unknown ``position``,
        -3 when ``model_file`` is missing,
        -4 when ``frame`` is missing or not an integer.
    """
    import os  # local import: this module does not import os at the top

    # NOTE(review): confirm this whitelist matches what uploaders actually
    # send (e.g. plain 'resnet' / 'vanilla' vs. 'legacy_vanilla').
    model_type = request.form.get('type')
    if model_type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']:
        return jsonify({'status': -1, 'message': 'illegal type'})

    position = request.form.get("position")
    if position not in positions:
        return jsonify({'status': -2, 'message': 'illegal position'})

    # A missing or non-numeric frame used to raise and produce an HTTP 500;
    # report it like every other validation failure instead.
    try:
        frame = int(request.form.get("frame"))
    except (TypeError, ValueError):
        return jsonify({'status': -4, 'message': 'illegal frame'})

    model_file = request.files.get('model_file')
    if model_file is None:
        return jsonify({'status': -3, 'message': 'illegal model_file'})

    # All three components are validated above, so the path stays inside
    # the baselines_server directory.
    path = "baselines_server/%s_%s_%d.ckpt" % (model_type, position, frame)
    if Model.get_or_none(Model.path == path) is None:
        # The target directory may not exist on a fresh deployment.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        model_file.save(path)
        Model.create(path=path, position=position, type=model_type, frame=frame,
                     create_time=datetime.now())
        Battle.create(challenger_path=path, challenger_position=position, status=0)
        threadpool.submit(tick)
    return jsonify({'status': 0, 'message': 'success', 'result': ''})
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Script entry point: parse the optional --debug flag and start the
    # Flask server bound to 0.0.0.0.
    import argparse

    cli = argparse.ArgumentParser(description='DouZero evaluation backend')
    cli.add_argument('--debug', action='store_true')
    options = cli.parse_args()

    app.run(host="0.0.0.0", debug=options.debug)
|
|
@ -0,0 +1,3 @@
|
||||||
|
flask==1.1
|
||||||
|
flask-cors
|
||||||
|
peewee
|
|
@ -5,3 +5,4 @@ rlcard
|
||||||
psutil
|
psutil
|
||||||
onnx
|
onnx
|
||||||
onnxruntime-gpu==1.7
|
onnxruntime-gpu==1.7
|
||||||
|
requests
|
||||||
|
|
Loading…
Reference in New Issue