添加上传逻辑

This commit is contained in:
zhiyang7 2022-01-05 09:54:42 +08:00
parent 5276c4e3c6
commit de217c48ab
2 changed files with 9 additions and 6 deletions

View File

@ -274,11 +274,14 @@ def train(flags):
type += 'unified' type += 'unified'
else: else:
type += 'resnet' type += 'resnet'
requests.post(flags.upload_url, data={ try:
'type': type, requests.post(flags.upload_url, data={
'position': position, 'type': type,
'frame': frames 'position': position,
}, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))}) 'frame': frames
}, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))})
except:
print("模型上传失败")
os.remove(model_weights_dir) os.remove(model_weights_dir)
shutil.move(checkpointpath + '.new', checkpointpath) shutil.move(checkpointpath + '.new', checkpointpath)

View File

@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
@app.route('/upload', methods=['POST']) @app.route('/upload', methods=['POST'])
def upload(): def upload():
type = request.form.get('type') type = request.form.get('type')
if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']: if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']:
return jsonify({'status': -1, 'message': 'illegal type'}) return jsonify({'status': -1, 'message': 'illegal type'})
position = request.form.get("position") position = request.form.get("position")
if position not in positions: if position not in positions: