Douzero_Resnet/evaluate_server.py

292 lines
13 KiB
Python

import itertools
import os
from peewee import JOIN
from douzero.server.orm import Model, Battle, Baseline
from douzero.server.battle import tick
from flask import Flask, jsonify, request
from flask_cors import CORS
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from douzero.env.game import get_legal_card_play_actions
from douzero.evaluation.deep_agent import DeepAgent
# Flask application; CORS is enabled for every route.
app = Flask(__name__)
CORS(app)

# Create the ORM tables this server reads and writes.
Model.create_table()
Battle.create_table()
Baseline.create_table()

# Seat names accepted by /upload's ``position`` field.
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
# Seat index (as sent in /predict's ``player_position``) -> seat name.
idx_position = dict(enumerate(['landlord', 'landlord_down', 'landlord_front', 'landlord_up']))
# Single worker: ticks submitted by /upload are executed one at a time.
threadpool = ThreadPoolExecutor(max_workers=1)
# Integer card codes used by the environment -> single-character labels used
# in request/response strings. Codes 10-14 map to 'T', 'J', 'Q', 'K', 'A',
# 17 maps to '2', and 20 / 30 map to 'X' / 'D' (presumably the two jokers).
EnvCard2RealCard = {
    3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
    8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
    13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D',
}
# Exact inverse of EnvCard2RealCard (label -> env code), same key order.
RealCard2EnvCard = {label: code for code, label in EnvCard2RealCard.items()}
# Agents used by /predict, indexed by seat number in ``idx_position`` order
# (0 landlord, 1 landlord_down, 2 landlord_front, 3 landlord_up).
# Defined up front so a missing baseline or a failed load leaves an empty list
# rather than an undefined name.
players = []
# Load the four ONNX agents from the highest-ranked baseline, if one exists.
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
if len(baselines) >= 1:
    baseline = baselines[0]
    try:
        players = [
            DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True),
            DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True),
            DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True),
            DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True)
        ]
    except Exception:
        # Best-effort: keep the server starting (upload/metrics still work), but
        # report why /predict will be unable to answer instead of hiding it.
        import traceback
        traceback.print_exc()
@app.route('/upload', methods=['POST'])
def upload():
    """Register an uploaded checkpoint and queue it for evaluation battles.

    Form fields:
      type        one of 'lite_resnet', 'lite_vanilla', 'legacy_vanilla'
      position    one of ``positions``
      frame       integer training frame
      model_file  uploaded checkpoint file

    A checkpoint path not seen before is saved under ``baselines_server/``, and
    a ``Model`` row plus a pending ``Battle`` (status 0) are created. An already
    known path is not saved again. In both cases a ``tick`` is submitted to
    ``threadpool``.

    Returns JSON ``status`` 0 on success, or -1 illegal type, -2 illegal
    position, -3 missing model_file, -4 missing or non-integer frame.
    """
    model_type = request.form.get('type')
    if model_type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']:
        return jsonify({'status': -1, 'message': 'illegal type'})
    position = request.form.get("position")
    if position not in positions:
        return jsonify({'status': -2, 'message': 'illegal position'})
    try:
        frame = int(request.form.get("frame"))
    except (TypeError, ValueError):
        # Missing (None) or non-numeric frame: report it instead of a 500 error.
        return jsonify({'status': -4, 'message': 'illegal frame'})
    model_file = request.files.get('model_file')
    if model_file is None:
        return jsonify({'status': -3, 'message': 'illegal model_file'})
    # type/position are whitelisted and frame is an int, so the path is safe.
    path = "baselines_server/%s_%s_%d.ckpt" % (model_type, position, frame)
    model = Model.get_or_none(Model.path == path)
    if model is None:
        # Saving fails if the target directory does not exist yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        model_file.save(path)
        Model.create(path=path, position=position, type=model_type, frame=frame, create_time=datetime.now())
        Battle.create(challenger_path=path, challenger_position=position, status=0)
    threadpool.submit(tick)
    return jsonify({'status': 0, 'message': 'success', 'result': ''})
@app.route('/battle_tick', methods=['POST'])
def battle_tick():
    """Run ``tick()`` synchronously in the request thread and report success.

    Unlike ``/upload``, this does not go through ``threadpool``.
    """
    tick()
    response = {'status': 0, 'message': 'success', 'result': ''}
    return jsonify(response)
@app.route('/charts', methods=['GET'])
def charts():
    """Serve the static ``charts.html`` page (presumably the /metrics front-end)."""
    page = "charts.html"
    return app.send_static_file(page)
def _baseline_model_label(path, strip_landlord=False):
    """Return the file name of *path* without its directory and extension.

    Everything from the first '.' of the base name is dropped. When
    *strip_landlord* is true, every 'landlord_' substring is also removed.
    """
    label = os.path.basename(str(path)).split('.')[0]
    return label.replace('landlord_', '') if strip_landlord else label


@app.route('/metrics', methods=['GET'])
def metrics():
    """Report challenger results grouped under the three top-ranked baselines.

    Query arg ``type`` selects which ``Model.type`` to report. Baselines are
    walked highest rank first. Each one collects the models created between
    its own ``create_time`` and the upper bound: now for the first baseline,
    then the previously visited baseline's ``create_time``. Only battles with
    status > 0 and != 3 are included.

    Result shape: ``{rank: {'baseline': {...summary...},
    '<seat>': {frame: {'wp': str, 'adp': str}}}}``.
    """
    model_type = request.args.get('type')
    top_baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(3)
    end_time = datetime.now()
    all_metrics = {}
    for baseline in top_baselines:
        baseline_metric = {
            'baseline': {
                'landlord_path': _baseline_model_label(baseline.landlord_path),
                'landlord_up_path': _baseline_model_label(baseline.landlord_up_path, strip_landlord=True),
                'landlord_down_path': _baseline_model_label(baseline.landlord_down_path, strip_landlord=True),
                'landlord_front_path': _baseline_model_label(baseline.landlord_front_path, strip_landlord=True),
                'landlord_wp': '%.4f' % float(baseline.landlord_wp),
                'landlord_adp': '%.4f' % float(baseline.landlord_adp),
                'farmer_wp': '%.4f' % float(baseline.farmer_wp),
                'farmer_adp': '%.4f' % float(baseline.farmer_adp),
                'create_time': str(baseline.create_time)
            },
            'landlord': {},
            'landlord_up': {},
            'landlord_front': {},
            'landlord_down': {}
        }
        results = (
            Model
            .select(Model.frame, Model.path, Battle.challenger_position, Battle.challenger_wp, Battle.challenger_adp)
            .where(
                Model.type == model_type,
                Model.create_time >= baseline.create_time,
                Model.create_time <= end_time
            )
            .join(Battle, JOIN.INNER, on=(
                (Battle.challenger_path == Model.path) &
                (Battle.status > 0) &
                (Battle.status != 3))
            )
            .order_by(Model.create_time.asc())
        )
        # The next (lower-ranked) baseline only covers models older than this one.
        # The query above already captured the previous bound when it was built.
        end_time = baseline.create_time
        for result in results:
            battle = result.battle
            baseline_metric[str(battle.challenger_position)][result.frame] = {
                'wp': '%.4f' % float(battle.challenger_wp),
                'adp': '%.4f' % float(battle.challenger_adp)
            }
        all_metrics[baseline.rank] = baseline_metric
    return jsonify({'status': 0, 'message': 'success', 'result': all_metrics})
def _parse_cards(text):
    """Convert a string of card labels (e.g. '33TJ') into env card codes.

    Raises TypeError when *text* is None and KeyError on an unknown label.
    """
    return [RealCard2EnvCard[c] for c in text]


def _get_rival_move(card_play_action_seq):
    """Return the move the current player has to respond to.

    Scans at most the last three ``[label, cards]`` entries, newest first, and
    returns the first non-empty card list. An empty list means none of those
    entries is a real play. Histories shorter than three entries are handled
    without an IndexError.
    """
    for _label, move in reversed(card_play_action_seq[-3:]):
        if move:
            return move
    return []


@app.route('/predict', methods=['POST'])
def predict():
    """Score every legal play for one seat with the loaded baseline agent.

    Form fields (card strings use the labels of ``RealCard2EnvCard``):
      player_position        '0'-'3', a key of ``idx_position``
      player_hand_cards      that seat's hand (1-33 cards for seat 0, else 1-25)
      num_cards_left_<seat>  card count per seat (seat 0 <= 33, others <= 25)
      card_play_action_seq   comma-separated play history, '' if empty
      other_hand_cards       cards held by all other seats combined
      last_move_<seat>       each seat's last play
      played_cards_<seat>    all cards each seat has played
      bomb_num               comma-separated integers

    Returns JSON with status 0, ``result`` (action -> raw agent value) and
    ``win_rates`` (value clamped to [-1, 1] and mapped linearly onto [0, 1]).
    Validation failures return status 1-5 or 7. Any other error, including
    "no agents loaded", returns status -1.
    """
    try:
        # Seat names in index order: landlord, landlord_down, landlord_front, landlord_up.
        seats = [idx_position[i] for i in range(4)]

        # Player position
        player_position = request.form.get('player_position')
        if player_position not in ['0', '1', '2', '3']:
            return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'})
        player_position = int(player_position)

        # Player hand cards
        player_hand_cards = _parse_cards(request.form.get('player_hand_cards'))
        if player_position == 0:
            if not 1 <= len(player_hand_cards) <= 33:
                return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'})
        elif not 1 <= len(player_hand_cards) <= 25:
            return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'})

        # Number of cards left per seat
        num_cards_left = [int(request.form.get('num_cards_left_' + seat)) for seat in seats]
        if num_cards_left[player_position] != len(player_hand_cards):
            return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
        # Bound-check every seat, including landlord_up (index 3).
        max_cards_left = [33, 25, 25, 25]
        if any(not 0 <= count <= limit for count, limit in zip(num_cards_left, max_cards_left)):
            return jsonify({'status': 5, 'message': 'the number of cards left not in range'})

        # Card play sequence ('' entries between commas are passes)
        raw_seq = request.form.get('card_play_action_seq')
        if raw_seq == '':
            card_play_action_seq = []
        else:
            card_play_action_seq = [['', _parse_cards(cards)] for cards in raw_seq.split(',')]

        # Other hand cards
        other_hand_cards = _parse_cards(request.form.get('other_hand_cards'))
        if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
            return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})

        # Last moves and played cards, per seat
        last_moves = [_parse_cards(request.form.get('last_move_' + seat)) for seat in seats]
        played_cards = {seat: _parse_cards(request.form.get('played_cards_' + seat)) for seat in seats}

        # Bomb num
        bomb_num = request.form.get('bomb_num')

        # InfoSet
        info_set = InfoSet()
        info_set.player_position = idx_position[player_position]
        info_set.player_hand_cards = player_hand_cards
        info_set.num_cards_left = num_cards_left
        info_set.card_play_action_seq = card_play_action_seq
        info_set.other_hand_cards = other_hand_cards
        info_set.last_moves = last_moves
        info_set.played_cards = played_cards
        info_set.bomb_num = [int(x) for x in str.split(bomb_num, ',')]
        for seat, count in zip(seats, num_cards_left):
            info_set.num_cards_left_dict[seat] = count

        # Rival move and legal actions
        rival_move = _get_rival_move(card_play_action_seq)
        info_set.rival_move = rival_move
        info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)

        # Prediction (IndexError here if no agents were loaded)
        actions, actions_confidence = players[player_position].act(info_set, True)
        actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions]
        result = {}
        win_rates = {}
        for action, confidence in zip(actions, actions_confidence):
            # Clamp the raw value to [-1, 1], then map it linearly onto [0, 1].
            win_rate = min(max(confidence, -1), 1)
            win_rates[action] = str(round((win_rate + 1) / 2, 4))
            result[action] = str(round(confidence, 6))

        ############## DEBUG ################
        if app.debug:
            print('--------------- DEBUG START --------------')
            command = 'curl --data "'
            parameters = []
            for key in request.form:
                parameters.append(key + '=' + request.form.get(key))
                print(key + ':', request.form.get(key))
            command += '&'.join(parameters)
            command += '" "http://127.0.0.1:5000/predict"'
            print('Command:', command)
            print('Rival Move:', rival_move)
            print('legal_actions:', info_set.legal_actions)
            print('Result:', result)
            print('--------------- DEBUG END --------------')
        ############## DEBUG ################
        return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
    except Exception:
        # Malformed input (missing field, unknown card label, bad int) or
        # agent failure: log it and return the generic error response.
        import traceback
        traceback.print_exc()
        return jsonify({'status': -1, 'message': 'unkown error'})
@app.route('/legal', methods=['POST'])
def legal():
    """List the legal plays for a hand against a rival move.

    Form fields ``player_hand_cards`` and ``rival_move`` are card-label strings
    (see ``RealCard2EnvCard``). Responds with ``legal_action``: the sorted,
    deduplicated plays in label form, joined by commas. Any error (missing
    field, unknown label) is logged and answered with status -1.
    """
    try:
        player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
        rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')]
        legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
        encoded_actions = ','.join(''.join(EnvCard2RealCard[a] for a in action) for action in legal_actions)
        return jsonify({'status': 0, 'message': 'success', 'legal_action': encoded_actions})
    except Exception:
        import traceback
        traceback.print_exc()
        return jsonify({'status': -1, 'message': 'unkown error'})
class InfoSet:
    """Plain attribute bag describing one decision point, passed to ``DeepAgent.act``.

    A fresh instance has every sequence/scalar field set to ``None`` and two
    empty dicts that are private to the instance; ``predict`` fills the fields.
    """

    def __init__(self):
        # Seat name whose turn it is (a value of ``idx_position``).
        self.player_position: str | None = None
        # Env card codes held by that seat.
        self.player_hand_cards: list | None = None
        # Cards left per seat, in ``idx_position`` order.
        self.num_cards_left: list | None = None
        # The same counts keyed by seat name.
        self.num_cards_left_dict: dict = {}
        # Not set by ``predict``; presumably read by the agent — verify.
        self.all_handcards: dict = {}
        # Play history as ``[label, cards]`` pairs.
        self.card_play_action_seq: list | None = None
        # Cards held by all other seats combined.
        self.other_hand_cards: list | None = None
        # Deduplicated legal plays for this seat.
        self.legal_actions: list | None = None
        # The play this seat has to respond to (empty when none).
        self.rival_move: list | None = None
        # Each seat's last play, in ``idx_position`` order.
        self.last_moves: list | None = None
        # Seat name -> all cards that seat has played.
        self.played_cards: dict | None = None
        # Integers parsed from the ``bomb_num`` form field.
        self.bomb_num: list | None = None
def _get_legal_card_play_actions(player_hand_cards, rival_move):
    """Return the legal plays for *player_hand_cards* against *rival_move*.

    The list returned by ``get_legal_card_play_actions`` is sorted in place,
    and ``itertools.groupby`` then collapses adjacent equal entries. Each
    distinct play therefore appears exactly once, in ascending order.
    """
    candidates = get_legal_card_play_actions(player_hand_cards, rival_move)
    candidates.sort()
    return [play for play, _duplicates in itertools.groupby(candidates)]
if __name__ == '__main__':
    import argparse

    # Command-line entry point: serve the API on all network interfaces.
    # NOTE(review): --debug turns on Flask debug mode while bound to 0.0.0.0;
    # avoid that on untrusted networks.
    cli = argparse.ArgumentParser(description='DouZero evaluation backend')
    cli.add_argument('--debug', action='store_true')
    options = cli.parse_args()
    app.run(host="0.0.0.0", debug=options.debug)