添加web调用支持

This commit is contained in:
zhiyang7 2022-02-10 10:13:41 +08:00
parent 2760188f5a
commit 0c90abe6b9
2 changed files with 40 additions and 12 deletions

Binary file not shown.

52
main.py
View File

@ -3,8 +3,15 @@ import getopt
import pandas as pd
import re
import math
import json
from flask import Flask, jsonify, request
from flask_cors import CORS
from pypinyin import pinyin, lazy_pinyin, Style
app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
CORS(app)
def trim_space(s):
return s.replace(' ', '').replace(',', '').replace('', '')
@ -62,6 +69,13 @@ def main(argv):
parameter = arg
elif opt in ("-n", "--num"):
num = int(arg)
all_idiom, group = filter_logic(mode, parameter)
if all_idiom is not None and len(group) > 0:
print_max_group(all_idiom, group, num)
else:
print('未找到匹配项')
def filter_logic(mode, parameter):
if os.path.exists("all_idiom.csv"):
all_idiom = pd.read_csv('all_idiom.csv')
else:
@ -76,7 +90,7 @@ def main(argv):
all_idiom['pinyin_rt'] = all_idiom.apply(lambda x: ''.join(map(lambda y: str(len(y)), re.split('[ ,]',x['pinyin_r']))), axis=1)
groups = all_idiom.groupby(by='pinyin_rt')
group = groups.get_group(parameter).copy()
print_max_group(all_idiom, group, num)
return all_idiom, group
elif mode == '1':
parameter_rst = parameter.split(';', 1)
if len(parameter_rst) > 1:
@ -105,16 +119,15 @@ def main(argv):
hits = parameter.split(',')[1]
parameter = trim_space(parameter.split(',')[0])
elif len(group) <= 0:
print('未找到匹配项')
return
break
else:
break
print_max_group(all_idiom, group, num)
return all_idiom, group
elif mode == '2':
all_idiom = all_idiom[all_idiom['word'].str.len() == 4]
all_idiom['pinyin_tone'] = all_idiom.apply(lambda x: get_tone(x['pinyin']), axis=1)
group = all_idiom[all_idiom['pinyin_tone'].str.startswith(parameter)].copy()
print_max_group(all_idiom, group, num)
return all_idiom, group
elif mode == '3':
all_idiom = all_idiom[all_idiom['word'].str.len() == 4]
parameter_rst = parameter.split(';', 1)
@ -152,11 +165,11 @@ def main(argv):
parameter = parameter[:-5]
hits=hits[:-10]
elif len(group) <= 0:
print('未找到匹配项')
return
break
else:
break
print_max_group(all_idiom, group, num)
return all_idiom, group
return None, None
def filter_group_model2(parameter, group, hits, tones, tone_hits, word_hits):
group = filter_with_target_field(group, 'word', parameter, word_hits)
@ -203,10 +216,16 @@ def filter_with_target_field(group, field_name, values, value_hits):
return group
def print_max_group(all_idiom, group, num):
for item in get_max_group(all_idiom, group, num):
print(item)
def get_max_group(all_idiom, group, num):
group['pinyin_c'] = group.apply(lambda x: (math.log(x['frequency'], 2)/16 + 1) * len(set(trim_space(x['pinyin_r']))), axis=1)
list = group.nlargest(num, ['pinyin_c', 'frequency']).index.tolist()
for i in list:
print(all_idiom.loc[i])
ret_list = group.nlargest(num, ['pinyin_c', 'frequency']).index.tolist()
result = list()
for i in ret_list:
result.append(json.loads(all_idiom.loc[i].to_json(orient = 'index',force_ascii=False)))
return result
'''
parameter: xxx xx xx xxx
@ -240,7 +259,16 @@ def filter_group_mode1(parameter, group, hits):
break
return group
@app.route('/predict', methods=['GET'])
def predict():
mode = request.args.get('mode')
parameter = request.args.get('parameter')
all_idiom, group = filter_logic(mode, parameter)
result = get_max_group(all_idiom, group, 3)
return jsonify({'status': 0, 'message': 'success', 'result': result})
if __name__ == '__main__':
current_work_dir = os.path.dirname(__file__)
os.chdir(current_work_dir)
main(sys.argv[1:])
# main(sys.argv[1:])
app.run(debug=True, host="0.0.0.0")