Douzero_Resnet/evaluate_server.py

289 lines
13 KiB
Python
Raw Normal View History

import itertools
2021-12-27 11:22:49 +08:00
import os
2021-12-28 09:18:33 +08:00
import time
import threading
2021-12-27 11:22:49 +08:00
from peewee import JOIN
2021-12-28 16:05:12 +08:00
from douzero.server.orm import init_db, Model, Battle, Baseline
from douzero.server.battle import init_battlefield, tick, baseline_players, positions, idx_position
2021-12-25 19:06:34 +08:00
from flask import Flask, jsonify, request
from flask_cors import CORS
from datetime import datetime
2021-12-27 09:44:40 +08:00
from douzero.env.game import get_legal_card_play_actions
2021-12-25 19:06:34 +08:00
# Flask application serving the HTTP API below; flask_cors is applied with its
# default settings so the routes can be called cross-origin.
app = Flask(import_name=__name__)
CORS(app)
# Integer card codes used by the game environment -> single-character card
# labels used in the HTTP API ('T' is ten; 'X' and 'D' are presumably the two
# jokers -- confirm against douzero.env).
EnvCard2RealCard = {
    3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9',
    10: 'T', 11: 'J', 12: 'Q', 13: 'K', 14: 'A', 17: '2',
    20: 'X', 30: 'D',
}
# Inverse mapping (label -> environment code), derived so the two tables can
# never drift apart.
RealCard2EnvCard = {label: code for code, label in EnvCard2RealCard.items()}
2021-12-25 19:06:34 +08:00
@app.route('/upload', methods=['POST'])
def upload():
    """Register an uploaded model checkpoint and queue it for evaluation.

    Form fields:
        type:       one of 'lite_resnet', 'lite_vanilla', 'legacy_vanilla'.
        position:   a seat name contained in ``positions``.
        frame:      integer frame number of the checkpoint.
        model_file: the checkpoint file (multipart upload).

    The file is saved to ``baselines_server/<type>_<position>_<frame>.ckpt`` and a
    ``Model`` row plus a pending ``Battle`` (status 0) are created. If a model
    with that path is already registered nothing is saved, and success is
    still reported.

    Returns:
        JSON with ``status`` 0 on success, or a negative status naming the
        invalid field (-1 type, -2 position, -3 model_file, -4 frame).
    """
    model_type = request.form.get('type')
    if model_type not in ('lite_resnet', 'lite_vanilla', 'legacy_vanilla'):
        return jsonify({'status': -1, 'message': 'illegal type'})
    position = request.form.get("position")
    if position not in positions:
        return jsonify({'status': -2, 'message': 'illegal position'})
    # A missing or non-numeric frame previously raised inside int() and
    # surfaced as an HTTP 500 instead of a JSON error.
    try:
        frame = int(request.form.get("frame"))
    except (TypeError, ValueError):
        return jsonify({'status': -4, 'message': 'illegal frame'})
    model_file = request.files.get('model_file')
    if model_file is None:
        return jsonify({'status': -3, 'message': 'illegal model_file'})
    path = "baselines_server/%s_%s_%d.ckpt" % (model_type, position, frame)
    if Model.get_or_none(Model.path == path) is None:
        # Make sure the target directory exists before saving the upload.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        model_file.save(path)
        Model.create(path=path, position=position, type=model_type, frame=frame, create_time=datetime.now())
        Battle.create(challenger_path=path, challenger_position=position, status=0)
    return jsonify({'status': 0, 'message': 'success', 'result': ''})
2021-12-28 20:16:24 +08:00
def start_runner(flags, interval=10):
    """Start a background daemon thread that calls ``tick(flags)`` forever.

    Args:
        flags: parsed command-line arguments, passed through to ``tick``.
        interval: seconds to sleep between consecutive ``tick`` calls
            (default 10, the previously hard-coded value).

    Returns:
        The started ``threading.Thread``.
    """
    def start_loop():
        while True:
            try:
                tick(flags)
            except Exception:
                # Keep the loop alive: previously any exception raised by tick()
                # ended the thread and silently stopped all further evaluation.
                import traceback
                traceback.print_exc()
            time.sleep(interval)

    # daemon=True replaces Thread.setDaemon(), which is deprecated since
    # Python 3.10; a daemon thread does not keep the process alive on exit.
    thread = threading.Thread(target=start_loop, daemon=True)
    thread.start()
    return thread
2021-12-27 10:10:41 +08:00
@app.route('/charts', methods=['GET'])
def charts():
    """Serve ``charts.html`` from the application's static folder."""
    page = "charts.html"
    return app.send_static_file(page)
2021-12-25 22:41:25 +08:00
def _short_model_name(path, strip_landlord_prefix=False):
    """Return the file name of *path* without directory or extension.

    With *strip_landlord_prefix*, every 'landlord_' substring is removed too
    (used for the non-landlord seats, whose names embed e.g. 'landlord_up').
    """
    name = os.path.basename(str(path)).split('.')[0]
    return name.replace('landlord_', '') if strip_landlord_prefix else name


@app.route('/metrics', methods=['GET'])
def metrics():
    """Report challenger battle results for the three highest-ranked baselines.

    Query args:
        type: model type whose challengers are reported.

    Baselines are visited in descending rank. For each one, every model of the
    requested type created in the window from that baseline's ``create_time``
    up to the previously visited baseline's ``create_time`` (up to now for the
    first) is joined with its finished battles (``Battle.status > 0``).

    Returns:
        JSON whose ``result`` maps baseline rank to a dict holding the
        baseline's summary (``'baseline'``) and, per challenger position, a
        mapping of frame -> ``{'wp', 'adp'}`` formatted to 4 decimals.
    """
    model_type = request.args.get('type')
    baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(3)
    end_time = datetime.now()
    all_metrics = {}
    for baseline in baselines:
        baseline_metric = {
            'baseline': {
                'landlord_path': _short_model_name(baseline.landlord_path),
                'landlord_up_path': _short_model_name(baseline.landlord_up_path, True),
                'landlord_down_path': _short_model_name(baseline.landlord_down_path, True),
                'landlord_front_path': _short_model_name(baseline.landlord_front_path, True),
                'landlord_wp': '%.4f' % float(baseline.landlord_wp),
                'landlord_adp': '%.4f' % float(baseline.landlord_adp),
                'farmer_wp': '%.4f' % float(baseline.farmer_wp),
                'farmer_adp': '%.4f' % float(baseline.farmer_adp),
                'create_time': str(baseline.create_time),
            },
            'landlord': {},
            'landlord_up': {},
            'landlord_front': {},
            'landlord_down': {},
        }
        results = (
            Model
            .select(Model.frame, Model.path, Battle.challenger_position, Battle.challenger_wp, Battle.challenger_adp)
            .where(
                Model.type == model_type,
                Model.create_time >= baseline.create_time,
                Model.create_time <= end_time,
            )
            .join(Battle, JOIN.INNER, on=(
                (Battle.challenger_path == Model.path) &
                (Battle.status > 0)
            ))
            .order_by(Model.create_time.asc())
        )
        # The query above already captured the current end_time value; the next
        # (lower-ranked) baseline's window ends where this one's begins.
        end_time = baseline.create_time
        for result in results:
            battle = result.battle
            baseline_metric[str(battle.challenger_position)][result.frame] = {
                'wp': '%.4f' % float(battle.challenger_wp),
                'adp': '%.4f' % float(battle.challenger_adp),
            }
        all_metrics[baseline.rank] = baseline_metric
    return jsonify({'status': 0, 'message': 'success', 'result': all_metrics})
def _get_rival_move(card_play_action_seq):
    """Return the cards of the most recent non-pass play among the last three.

    Each entry of *card_play_action_seq* is ``[player, cards]``, and an empty
    ``cards`` list is a pass. If the last three plays (or all plays, when fewer
    than three exist) are passes, ``[]`` is returned. Short sequences no longer
    raise ``IndexError``.
    """
    for _, cards in reversed(card_play_action_seq[-3:]):
        if cards:
            return cards
    return []


def _confidence_to_win_rate(confidence):
    """Map a raw action confidence onto a win-rate estimate in ``[0, 1]``.

    The confidence is scaled by 7.2, clamped to ``[-1, 1]``, then shifted
    linearly onto ``[0, 1]``.
    """
    scaled = min(max(confidence * 7.2, -1), 1)
    return (scaled + 1) / 2


@app.route('/predict', methods=['POST'])
def predict():
    """Score the legal plays for one seat with its baseline player.

    Cards are strings of single-character labels (see ``RealCard2EnvCard``).

    Form fields:
        player_position: '0'-'3', an index into the seat order landlord,
            landlord_down, landlord_front, landlord_up.
        player_hand_cards: the seat's own hand.
        num_cards_left_<seat>: integer card count for each seat.
        card_play_action_seq: comma-separated plays so far ('' when none;
            an empty item is a pass).
        other_hand_cards: all cards held by the other seats.
        last_move_<seat>: each seat's most recent move.
        played_cards_<seat>: cards each seat has played.
        bomb_num: comma-separated integers.

    Returns:
        JSON with ``status`` 0 plus ``result`` (action -> confidence) and
        ``win_rates`` (action -> estimated win rate). A positive ``status`` is
        returned for validation failures, and -1 for any unexpected error.
    """
    seat_order = ('landlord', 'landlord_down', 'landlord_front', 'landlord_up')
    try:
        # Player position
        player_position = request.form.get('player_position')
        if player_position not in ['0', '1', '2', '3']:
            return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'})
        player_position = int(player_position)

        # Player hand cards: the landlord may hold up to 33 cards, other seats up to 25.
        player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
        if player_position == 0:
            if not 1 <= len(player_hand_cards) <= 33:
                return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'})
        elif not 1 <= len(player_hand_cards) <= 25:
            return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'})

        # Number of cards left per seat, in seat_order.
        num_cards_left = [int(request.form.get('num_cards_left_' + seat)) for seat in seat_order]
        if num_cards_left[player_position] != len(player_hand_cards):
            return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
        # Every seat is range-checked; the previous check tested index 2 twice
        # and never bounded landlord_up (index 3).
        max_cards_left = (33, 25, 25, 25)
        if any(not 0 <= left <= limit for left, limit in zip(num_cards_left, max_cards_left)):
            return jsonify({'status': 5, 'message': 'the number of cards left not in range'})

        # Card play sequence
        raw_seq = request.form.get('card_play_action_seq')
        if raw_seq == '':
            card_play_action_seq = []
        else:
            card_play_action_seq = [['', [RealCard2EnvCard[c] for c in cards]] for cards in raw_seq.split(',')]

        # Other hand cards
        other_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')]
        if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
            return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})

        # Last moves
        last_move_dict = {
            seat: [RealCard2EnvCard[c] for c in request.form.get('last_move_' + seat)]
            for seat in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
        }

        # Played cards, keyed by idx_position of each seat's index.
        played_cards = {
            idx_position[idx]: [RealCard2EnvCard[c] for c in request.form.get('played_cards_' + seat)]
            for idx, seat in enumerate(seat_order)
        }

        bomb_num = request.form.get('bomb_num')

        # InfoSet
        info_set = InfoSet()
        info_set.player_position = idx_position[player_position]
        info_set.player_hand_cards = player_hand_cards
        info_set.num_cards_left = num_cards_left
        info_set.card_play_action_seq = card_play_action_seq
        info_set.other_hand_cards = other_hand_cards
        info_set.last_move_dict = last_move_dict
        info_set.played_cards = played_cards
        info_set.bomb_num = [int(x) for x in bomb_num.split(',')]
        for seat, left in zip(seat_order, num_cards_left):
            info_set.num_cards_left_dict[seat] = left

        # Rival move and legal actions
        rival_move = _get_rival_move(card_play_action_seq)
        info_set.last_move = rival_move
        info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)

        # Prediction
        actions, actions_confidence = baseline_players[player_position].act(info_set, True)
        actions = [''.join(EnvCard2RealCard[a] for a in action) for action in actions]
        result = {}
        win_rates = {}
        for action, confidence in zip(actions, actions_confidence):
            win_rates[action] = str(round(_confidence_to_win_rate(confidence), 4))
            result[action] = str(round(confidence, 6))

        if app.debug:
            print('--------------- DEBUG START --------------')
            parameters = []
            for key in request.form:
                parameters.append(key + '=' + request.form.get(key))
                print(key + ':', request.form.get(key))
            command = 'curl --data "' + '&'.join(parameters) + '" "http://127.0.0.1:5000/predict"'
            print('Command:', command)
            print('Rival Move:', rival_move)
            print('legal_actions:', info_set.legal_actions)
            print('Result:', result)
            print('--------------- DEBUG END --------------')
        return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        import traceback
        traceback.print_exc()
        return jsonify({'status': -1, 'message': 'unkown error'})
@app.route('/legal', methods=['POST'])
def legal():
    """Return the legal plays for a hand against a given rival move.

    Form fields (card-label strings, see ``RealCard2EnvCard``):
        player_hand_cards: the hand to play from.
        rival_move: the move to answer; '' gives an empty rival move.

    Returns:
        JSON whose ``legal_action`` is a comma-separated list of plays, or
        ``status`` -1 on any error (e.g. an unknown card label or missing field).
    """
    try:
        player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
        rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')]
        legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
        legal_actions = ','.join(''.join(EnvCard2RealCard[a] for a in action) for action in legal_actions)
        return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        import traceback
        traceback.print_exc()
        return jsonify({'status': -1, 'message': 'unkown error'})
class InfoSet:
    """Mutable container of game-state fields passed to a baseline player's ``act``.

    Every field starts as ``None`` or an empty container; the ``/predict``
    handler fills in the ones it needs before calling the player.
    """

    def __init__(self):
        # The acting seat and its own cards.
        self.player_position = None
        self.player_hand_cards = None
        # Cards left per seat, both as a list and keyed by seat name.
        self.num_cards_left = None
        self.num_cards_left_dict, self.all_handcards = {}, {}
        self.card_play_action_seq = None
        self.other_hand_cards = None
        self.legal_actions = None
        self.last_move = None
        # Most recent move per seat; each seat gets its own fresh list.
        self.last_move_dict = {
            seat: [] for seat in ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')
        }
        self.last_moves = None
        self.played_cards = None
        self.bomb_num = None
def _get_legal_card_play_actions(player_hand_cards, rival_move):
    """Return the sorted, duplicate-free legal plays for a hand against a move.

    The list returned by ``get_legal_card_play_actions`` is sorted in place so
    that equal plays become adjacent. Then only the first play of each run of
    equal plays is kept.
    """
    candidates = get_legal_card_play_actions(player_hand_cards, rival_move)
    candidates.sort()
    unique = []
    for move in candidates:
        if not unique or unique[-1] != move:
            unique.append(move)
    return unique
2021-12-25 19:06:34 +08:00
def parse_bool_arg(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False' and '0') as True, so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    import argparse
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='DouZero evaluation backend')
    parser.add_argument('--debug', action='store_true')
    # Accepts "--enable_task", "--enable_task true" and "--enable_task false";
    # the former type=bool made "--enable_task False" enable the task.
    parser.add_argument('--enable_task', type=parse_bool_arg, nargs='?', const=True)
    parser.add_argument('--update_threshold', type=float, default=0.09)
    parser.add_argument('--db_host', type=str, default='127.0.0.1')
    parser.add_argument('--db_schema', type=str, default='dou_model')
    parser.add_argument('--db_user', type=str, default='douzero')
    # NOTE(review): a database password is hard-coded as the default and is
    # committed to source control; consider supplying it only via --db_passwd.
    parser.add_argument('--db_passwd', type=str, default='VwjT6e0qf2t9iOH0')
    args = parser.parse_args()
    init_db(args)
    init_battlefield(args)
    # With --debug, Flask's reloader re-executes this script in a child process
    # (marked by WERKZEUG_RUN_MAIN=true) while the parent only supervises it.
    # Start the runner in one process only so battles are not ticked twice.
    if args.enable_task and (not args.debug or os.environ.get('WERKZEUG_RUN_MAIN') == 'true'):
        start_runner(args)
    # NOTE(review): debug=True with host 0.0.0.0 exposes the interactive
    # Werkzeug debugger to the network; only use --debug on trusted networks.
    app.run(debug=args.debug, host="0.0.0.0")