diff --git a/douzero/dmc/arguments.py b/douzero/dmc/arguments.py
index b7b94e6..05a7e2a 100644
--- a/douzero/dmc/arguments.py
+++ b/douzero/dmc/arguments.py
@@ -41,6 +41,10 @@ parser.add_argument('--enable_onnx', action='store_true',
                     help='Use onnx model for train')
 parser.add_argument('--onnx_model_path', default='douzero_checkpoints',
                     help='Root dir where onnx temp model will be saved')
+parser.add_argument('--enable_upload', action='store_true',
+                    help='Should the cpkt model will be upload to server')
+parser.add_argument('--upload_url', default='https://dou.zaneyork.cn:8443/model/upload',
+                    help='The cpkt model will be upload to')
 
 # Hyperparameters
 parser.add_argument('--total_frames', default=100000000000, type=int,
diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py
index 1bdecc1..1a309f8 100644
--- a/douzero/dmc/dmc.py
+++ b/douzero/dmc/dmc.py
@@ -17,6 +17,7 @@ from .models import Model, OldModel
 from .utils import get_batch, log, create_env, create_optimizers, act
 import psutil
 import shutil
+import requests
 
 mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']}
 
@@ -244,6 +245,24 @@ def train(flags):
             model_weights_dir = os.path.expandvars(os.path.expanduser(
                 '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
             torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
+            if flags.enable_upload:
+                if flags.lite_model:
+                    model_type = 'lite_'
+                else:
+                    model_type = ''
+                if flags.old_model:
+                    model_type += 'vanilla'
+                else:
+                    model_type += 'resnet'
+                # Close the file before deleting it: an open handle leaks the
+                # descriptor and blocks os.remove() on Windows.
+                with open(model_weights_dir, 'rb') as model_file:
+                    requests.post(flags.upload_url, data={
+                        'type': model_type,
+                        'position': position,
+                        'frame': frames
+                    }, files = {'model_file':('model.ckpt', model_file)})
+                os.remove(model_weights_dir)
         shutil.move(checkpointpath + '.new', checkpointpath)
 
     fps_log = []
@@ -297,6 +316,8 @@ def train(flags):
                     proc['actor'] = actor
 
     except KeyboardInterrupt:
+        # Final checkpoint on Ctrl-C: keep the weights local, skip the upload.
+        flags.enable_upload = False
         checkpoint(frames)
         return
     else:
diff --git a/douzero/server/__init__.py b/douzero/server/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/douzero/server/battle.py b/douzero/server/battle.py
new file mode 100644
index 0000000..fad1e5a
--- /dev/null
+++ b/douzero/server/battle.py
@@ -0,0 +1,143 @@
+"""Challenger-vs-baseline evaluation of uploaded model checkpoints."""
+import logging
+import os
+from datetime import datetime
+
+from .orm import Baseline, Battle, Model
+from evaluate import evaluate
+
+log = logging.getLogger(__name__)
+
+positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
+
+# Data files handed to evaluate(): the first run, then the confirmation run.
+EVAL_DATA_FIRST = 'eval_data_200.pkl'
+EVAL_DATA_SECOND = 'eval_data_800.pkl'
+# Weight of the confirmation run relative to the first run when blending.
+SECOND_EVAL_WEIGHT = 4.0
+# Margin by which the challenger's win rate must exceed the baseline's.
+WIN_MARGIN = 0.15
+
+
+def _run_eval(paths, eval_data):
+    """Call evaluate() on the four model paths and return its result."""
+    return evaluate(paths['landlord_path'],
+                    paths['landlord_up_path'],
+                    paths['landlord_front_path'],
+                    paths['landlord_down_path'],
+                    eval_data,
+                    1,
+                    False,
+                    'New')
+
+
+def _blend(first, second):
+    """Average two result tuples, weighting ``second`` 4:1 over ``first``."""
+    return tuple((a + b * SECOND_EVAL_WEIGHT) / (1 + SECOND_EVAL_WEIGHT)
+                 for a, b in zip(first, second))
+
+
+def battle_logic(baseline : Baseline, battle : Battle):
+    """Evaluate one battle against ``baseline`` and record the outcome.
+
+    The challenger takes the baseline's seat at its own position. It must beat
+    the baseline win rate of its side (landlord, or farmer for every other
+    position) by ``WIN_MARGIN`` on the first run and again on the blend of
+    both runs. On success a ``Baseline`` row with ``rank + 1`` is created;
+    otherwise the challenger checkpoint and its ``.onnx`` sibling (if any)
+    are deleted. ``battle`` is saved with status 1 (success) or 2 (failure)
+    and the challenger side's win rate and ADP.
+    """
+    challenger_baseline = {
+        'landlord_path': str(baseline.landlord_path),
+        'landlord_up_path': str(baseline.landlord_up_path),
+        'landlord_front_path': str(baseline.landlord_front_path),
+        'landlord_down_path': str(baseline.landlord_down_path),
+    }
+    challenger_baseline[battle.challenger_position + "_path"] = str(battle.challenger_path)
+
+    is_landlord = battle.challenger_position == 'landlord'
+    # Slots of the challenger side's win rate and ADP in the result tuple
+    # (landlord_wp, farmer_wp, landlord_adp, farmer_adp).
+    wp_idx, adp_idx = (0, 2) if is_landlord else (1, 3)
+    baseline_wp = baseline.landlord_wp if is_landlord else baseline.farmer_wp
+
+    results = tuple(_run_eval(challenger_baseline, EVAL_DATA_FIRST))
+    challenge_success = False
+    if results[wp_idx] - WIN_MARGIN > baseline_wp:
+        # Passed the first run; confirm on the second data file.
+        results = _blend(results, _run_eval(challenger_baseline, EVAL_DATA_SECOND))
+        challenge_success = results[wp_idx] - WIN_MARGIN > baseline_wp
+    landlord_wp, farmer_wp, landlord_adp, farmer_adp = results
+
+    if challenge_success:
+        challenger_baseline['rank'] = baseline.rank + 1
+        challenger_baseline['landlord_wp'] = landlord_wp
+        challenger_baseline['landlord_adp'] = landlord_adp
+        challenger_baseline['farmer_wp'] = farmer_wp
+        challenger_baseline['farmer_adp'] = farmer_adp
+        challenger_baseline['create_time'] = datetime.now()
+        Baseline.create(**challenger_baseline)
+    else:
+        onnx_path = str(battle.challenger_path) + '.onnx'
+        if os.path.exists(onnx_path):
+            os.remove(onnx_path)
+        os.remove(str(battle.challenger_path))
+    battle.challenger_wp = results[wp_idx]
+    battle.challenger_adp = results[adp_idx]
+    battle.status = 1 if challenge_success else 2
+    battle.save()
+
+
+def _latest_baseline():
+    """Return the highest-ranked ``Baseline`` row, or None if none exists."""
+    rows = list(Baseline.select().order_by(Baseline.rank.desc()).limit(1))
+    return rows[0] if rows else None
+
+
+def _bootstrap_baseline():
+    """Create the rank-0 baseline from the newest model of every position.
+
+    Returns the stored row, or None when some position has no model yet.
+    """
+    baseline = {}
+    for position in positions:
+        models = list(Model.select().where(Model.position == position)
+                      .order_by(Model.create_time.desc()).limit(1))
+        if models:
+            baseline['%s_path' % position] = models[0].path
+    if len(baseline) != len(positions):
+        return None
+    baseline['rank'] = 0
+    baseline['landlord_wp'] = 0
+    baseline['farmer_wp'] = 0
+    baseline['landlord_adp'] = 0
+    baseline['farmer_adp'] = 0
+    baseline['create_time'] = datetime.now()
+    Baseline.create(**baseline)
+    return _latest_baseline()
+
+
+def tick():
+    """Process every pending battle (status 0), newest id first.
+
+    Each battle runs against the highest-ranked baseline. If there is no
+    baseline yet, one is bootstrapped from the newest model per position; if
+    that is not possible the battle is saved with status 3.
+    """
+    battles = Battle.select().where(Battle.status == 0).order_by(Battle.id.desc())
+    for battle in battles:
+        try:
+            baseline = _latest_baseline()
+            if baseline is None:
+                baseline = _bootstrap_baseline()
+            if baseline is None:
+                battle.status = 3
+                battle.save()
+            else:
+                battle_logic(baseline, battle)
+        except Exception:
+            # tick() is submitted to a thread pool whose future is never read,
+            # so an uncaught error would vanish and also skip the remaining
+            # battles. Log it and move on; the battle stays at status 0.
+            log.exception('battle %s failed', battle.id)
diff --git a/douzero/server/orm.py b/douzero/server/orm.py
new file mode 100644
index 0000000..8b2874d
--- /dev/null
+++ b/douzero/server/orm.py
@@ -0,0 +1,72 @@
+# -*- coding:utf8 -*-
+"""Peewee models for the evaluation server, stored in SQLite ``model.db``."""
+from peewee import (
+    CharField,
+    DateTimeField,
+    DecimalField,
+    ForeignKeyField,
+    IntegerField,
+    PrimaryKeyField,
+    SqliteDatabase,
+)
+from peewee import Model as PeeweeModel
+
+db = SqliteDatabase('model.db')
+
+
+class BaseModel(PeeweeModel):
+    """Binds every table below to ``db``."""
+
+    class Meta:
+        database = db
+
+
+class Model(BaseModel):
+    """An uploaded checkpoint, keyed by its file path."""
+
+    path = CharField(primary_key=True)
+    position = CharField(null = False)
+    type = CharField(null = False)
+    frame = IntegerField(null = False)
+    create_time = DateTimeField(null=False)
+
+    class Meta:
+        order_by = ('type',)
+        db_table = 'model'
+
+
+class Battle(BaseModel):
+    """One challenger evaluation; ``status`` 0 means not yet processed."""
+
+    id = PrimaryKeyField()
+    challenger_path = CharField(null = False)
+    challenger_position = CharField(null = False)
+    status = IntegerField(null = False)
+    challenger_wp = DecimalField(null=True)
+    challenger_adp = DecimalField(null=True)
+
+    class Meta:
+        order_by = ('id',)
+        db_table = 'battle'
+
+
+class Baseline(BaseModel):
+    """A set of four models (one per position) with its recorded results."""
+
+    id = PrimaryKeyField()
+    # Each foreign key needs its own related_name: they all point at Model,
+    # and a shared name makes the reverse accessors collide.
+    landlord_path = ForeignKeyField(Model, to_field='path', related_name='landlord_baselines')
+    landlord_up_path = ForeignKeyField(Model, to_field='path', related_name='landlord_up_baselines')
+    landlord_front_path = ForeignKeyField(Model, to_field='path', related_name='landlord_front_baselines')
+    landlord_down_path = ForeignKeyField(Model, to_field='path', related_name='landlord_down_baselines')
+    rank = IntegerField(null = False)
+    landlord_wp = DecimalField(null=False)
+    farmer_wp = DecimalField(null=False)
+    landlord_adp = DecimalField(null=False)
+    farmer_adp = DecimalField(null=False)
+    create_time = DateTimeField(null=False)
+
+    class Meta:
+        order_by = ('id',)
+        db_table = 'baseline'
diff --git a/evaluate_server.py b/evaluate_server.py
new file mode 100644
index 0000000..9ed1355
--- /dev/null
+++ b/evaluate_server.py
@@ -0,0 +1,71 @@
+"""Flask backend that stores uploaded checkpoints and queues battles."""
+import os
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+
+from douzero.evaluation.simulation import evaluate
+
+from douzero.server.orm import Model, Battle, Baseline
+from douzero.server.battle import tick
+from flask import Flask, jsonify, request
+from flask_cors import CORS
+
+app = Flask(__name__)
+CORS(app)
+
+Model.create_table()
+Battle.create_table()
+Baseline.create_table()
+positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
+# NOTE(review): the trainer's upload code sends 'resnet' / 'vanilla' when
+# flags.lite_model is off; neither is listed here, so those uploads are
+# rejected with status -1. Confirm which side is intended.
+MODEL_TYPES = ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']
+MODEL_DIR = 'baselines_server'
+
+# A single worker, so submitted ticks run one at a time.
+threadpool = ThreadPoolExecutor(1)
+
+
+@app.route('/upload', methods=['POST'])
+def upload():
+    """Store an uploaded checkpoint and queue a battle for it.
+
+    Form fields: ``type``, ``position``, ``frame``; file field ``model_file``.
+    Responds with JSON ``status`` 0 on success or a negative status naming
+    the rejected field. A checkpoint whose path is already recorded is not
+    stored again and still gets the success response.
+    """
+    model_type = request.form.get('type')
+    if model_type not in MODEL_TYPES:
+        return jsonify({'status': -1, 'message': 'illegal type'})
+    position = request.form.get("position")
+    if position not in positions:
+        return jsonify({'status': -2, 'message': 'illegal position'})
+    try:
+        frame = int(request.form.get("frame"))
+    except (TypeError, ValueError):
+        # Missing or non-numeric frame: answer like the other checks instead
+        # of failing the request with an unhandled exception.
+        return jsonify({'status': -4, 'message': 'illegal frame'})
+    model_file = request.files.get('model_file')
+    if model_file is None:
+        return jsonify({'status': -3, 'message': 'illegal model_file'})
+    path = "%s/%s_%s_%d.ckpt" % (MODEL_DIR, model_type, position, frame)
+    model = Model.get_or_none(Model.path == path)
+    if model is None:
+        # save() does not create the target directory.
+        os.makedirs(MODEL_DIR, exist_ok=True)
+        model_file.save(path)
+        Model.create(path=path, position=position,type=model_type,frame=frame,create_time=datetime.now())
+        Battle.create(challenger_path=path, challenger_position=position, status=0)
+        threadpool.submit(tick)
+    return jsonify({'status': 0, 'message': 'success', 'result': ''})
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description='DouZero evaluation backend')
+    parser.add_argument('--debug', action='store_true')
+    args = parser.parse_args()
+    app.run(debug=args.debug, host="0.0.0.0")
diff --git a/requirements-server.txt b/requirements-server.txt
new file mode 100644
index 0000000..e16acbb
--- /dev/null
+++ b/requirements-server.txt
@@ -0,0 +1,3 @@
+flask==1.1
+flask-cors
+peewee
diff --git a/requirements.txt b/requirements.txt
index cab6fd3..cff6e50 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ rlcard
 psutil
 onnx
 onnxruntime-gpu==1.7
+requests