From de217c48abf72824ecd3801e1b2415db06c6bb10 Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Wed, 5 Jan 2022 09:54:42 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=B8=8A=E4=BC=A0=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/dmc.py | 13 ++++++++----- evaluate_server.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index efafb12..2d9aa0c 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -274,11 +274,14 @@ def train(flags): type += 'unified' else: type += 'resnet' - requests.post(flags.upload_url, data={ - 'type': type, - 'position': position, - 'frame': frames - }, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))}) + try: + requests.post(flags.upload_url, data={ + 'type': type, + 'position': position, + 'frame': frames + }, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))}) + except: + print("模型上传失败") os.remove(model_weights_dir) shutil.move(checkpointpath + '.new', checkpointpath) diff --git a/evaluate_server.py b/evaluate_server.py index 708b79f..aff29c1 100644 --- a/evaluate_server.py +++ b/evaluate_server.py @@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, @app.route('/upload', methods=['POST']) def upload(): type = request.form.get('type') - if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']: + if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']: return jsonify({'status': -1, 'message': 'illegal type'}) position = request.form.get("position") if position not in positions: