diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index efafb12..2d9aa0c 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -274,11 +274,14 @@ def train(flags): type += 'unified' else: type += 'resnet' - requests.post(flags.upload_url, data={ - 'type': type, - 'position': position, - 'frame': frames - }, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))}) + try: + requests.post(flags.upload_url, data={ + 'type': type, + 'position': position, + 'frame': frames + }, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))}) + except: + print("模型上传失败") os.remove(model_weights_dir) shutil.move(checkpointpath + '.new', checkpointpath) diff --git a/evaluate_server.py b/evaluate_server.py index 708b79f..aff29c1 100644 --- a/evaluate_server.py +++ b/evaluate_server.py @@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, @app.route('/upload', methods=['POST']) def upload(): type = request.form.get('type') - if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']: + if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']: return jsonify({'status': -1, 'message': 'illegal type'}) position = request.form.get("position") if position not in positions: