调整晋升要求,添加评估接口

This commit is contained in:
ZaneYork 2021-12-26 20:11:05 +08:00
parent c4c008d034
commit 7b21149add
5 changed files with 295 additions and 23 deletions

View File

@ -12,9 +12,6 @@ from onnxruntime.datasets import get_example
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Tensors that track gradients are detached first, because ``.numpy()``
    cannot be called on a tensor that requires grad.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
class LandlordLstmModel(nn.Module): class LandlordLstmModel(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -6,7 +6,7 @@ from onnxruntime.datasets import get_example
from douzero.env.env import get_obs from douzero.env.env import get_obs
def _load_model(position, model_path, model_type, use_legacy, use_lite): def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
model = None model = None
if model_type == "general": if model_type == "general":
@ -34,9 +34,9 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite):
if torch.cuda.is_available(): if torch.cuda.is_available():
model.cuda() model.cuda()
model.eval() model.eval()
onnx_params = model.get_onnx_params(torch.device('cpu'))
model_path = model_path + '.onnx' model_path = model_path + '.onnx'
if not os.path.exists(model_path): if use_onnx and not os.path.exists(model_path):
onnx_params = model.get_onnx_params(torch.device('cpu'))
torch.onnx.export( torch.onnx.export(
model, model,
onnx_params['args'], onnx_params['args'],
@ -53,28 +53,42 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite):
class DeepAgent: class DeepAgent:
def __init__(self, position, model_path): def __init__(self, position, model_path, use_onnx=False):
self.use_legacy = True if "legacy" in model_path else False self.use_legacy = True if "legacy" in model_path else False
self.lite_model = True if "lite" in model_path else False self.lite_model = True if "lite" in model_path else False
self.model_type = "general" if "resnet" in model_path else "old" self.model_type = "general" if "resnet" in model_path else "old"
self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model) self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, use_onnx=use_onnx)
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider']) if use_onnx:
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
else:
self.onnx_model = None
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'} 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
def act(self, infoset): def act(self, infoset, with_confidence = False):
if len(infoset.legal_actions) == 1: if not with_confidence and len(infoset.legal_actions) == 1:
return infoset.legal_actions[0] return infoset.legal_actions[0]
obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model) obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
# z_batch = torch.from_numpy(obs['z_batch']).float() if self.onnx_model is None:
# x_batch = torch.from_numpy(obs['x_batch']).float() z_batch = torch.from_numpy(obs['z_batch']).float()
# if torch.cuda.is_available(): x_batch = torch.from_numpy(obs['x_batch']).float()
# z_batch, x_batch = z_batch.cuda(), x_batch.cuda() if torch.cuda.is_available():
# y_pred = self.model.forward(z_batch, x_batch)['values'] z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
# y_pred = y_pred.detach().cpu().numpy() y_pred = self.model.forward(z_batch, x_batch)['values']
y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0] y_pred = y_pred.detach().cpu().numpy()
else:
y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]
if with_confidence:
y_pred = y_pred.flatten()
size = min(3, len(y_pred))
best_action_index = np.argpartition(y_pred, -size)[-size:]
best_action_confidence = y_pred[best_action_index]
best_action = [infoset.legal_actions[index] for index in best_action_index]
return best_action, best_action_confidence
best_action_index = np.argmax(y_pred, axis=0)[0] best_action_index = np.argmax(y_pred, axis=0)[0]
best_action = infoset.legal_actions[best_action_index] best_action = infoset.legal_actions[best_action_index]

View File

@ -19,7 +19,7 @@ def load_card_play_models(card_play_model_path_dict):
players[position] = RandomAgent() players[position] = RandomAgent()
else: else:
from .deep_agent import DeepAgent from .deep_agent import DeepAgent
players[position] = DeepAgent(position, card_play_model_path_dict[position]) players[position] = DeepAgent(position, card_play_model_path_dict[position], use_onnx=True)
return players return players
def mp_simulate(card_play_data_list, card_play_model_path_dict, q, output, title): def mp_simulate(card_play_data_list, card_play_model_path_dict, q, output, title):

View File

@ -45,16 +45,16 @@ def battle_logic(baseline : Baseline, battle : Battle):
challenge_success = False challenge_success = False
if battle.challenger_position == 'landlord': if battle.challenger_position == 'landlord':
if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.2: if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.15:
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \ landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
_second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp) _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.2: if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.15:
challenge_success = True challenge_success = True
else: else:
if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.2: if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.05:
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \ landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
_second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp) _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.2: if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.05:
challenge_success = True challenge_success = True
if challenge_success: if challenge_success:
challenger_baseline['rank'] = baseline.rank + 1 challenger_baseline['rank'] = baseline.rank + 1

View File

@ -1,3 +1,5 @@
import itertools
from douzero.server.orm import Model, Battle, Baseline from douzero.server.orm import Model, Battle, Baseline
from douzero.server.battle import tick from douzero.server.battle import tick
from flask import Flask, jsonify, request from flask import Flask, jsonify, request
@ -5,6 +7,10 @@ from flask_cors import CORS
from datetime import datetime from datetime import datetime
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from douzero.env.move_generator import MovesGener
from douzero.env import move_detector as md, move_selector as ms
from douzero.evaluation.deep_agent import DeepAgent
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
@ -12,9 +18,29 @@ Model.create_table()
Battle.create_table() Battle.create_table()
Baseline.create_table() Baseline.create_table()
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'}
threadpool = ThreadPoolExecutor(1) threadpool = ThreadPoolExecutor(1)
# Maps the integer card codes used by the environment to the one-character
# card labels that the HTTP endpoints emit in their responses.
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
# Inverse of EnvCard2RealCard: maps each one-character card label received in
# a request back to the environment's integer card code.
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
# Agents that serve the prediction endpoint, listed in the order landlord,
# landlord_down, landlord_front, landlord_up. They are loaded once at import
# time from the highest-ranked baseline. ``players`` is always bound: it stays
# an empty list when no baseline has been recorded yet, so the module-level
# name never depends on the database contents.
players = []
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
if len(baselines) >= 1:
    baseline = baselines[0]
    players = [
        DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True),
        DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True),
        DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True),
        DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True),
    ]
@app.route('/upload', methods=['POST']) @app.route('/upload', methods=['POST'])
def upload(): def upload():
type = request.form.get('type') type = request.form.get('type')
@ -71,6 +97,241 @@ def metrics():
metrics[baseline.rank] = baseline_metric metrics[baseline.rank] = baseline_metric
return jsonify({'status': 0, 'message': 'success', 'result': metrics}) return jsonify({'status': 0, 'message': 'success', 'result': metrics})
@app.route('/predict', methods=['POST'])
def predict():
    """Score the best candidate moves for the game state described in the form.

    Form fields read from the request:
      * ``player_position``: '0'..'3', an index into ``idx_position``.
      * ``player_hand_cards`` / ``other_hand_cards``: strings of card labels.
      * ``num_cards_left_<position>``: integer card counts for each seat.
      * ``card_play_action_seq``: comma-separated moves ('' when no move yet;
        an empty item is a pass).
      * ``last_move_<position>`` / ``played_cards_<position>``: card strings.
      * ``bomb_num``: comma-separated integers.

    Returns a JSON response. ``status`` 0 carries ``result`` (raw confidence
    per move) and ``win_rates`` (confidence clamped to [-1, 1] and rescaled
    to [0, 1]), both keyed by the move's card string. ``status`` 1-5 and 7
    report validation failures; ``status`` -1 reports any other error, whose
    traceback is printed.
    """
    if request.method == 'POST':
        try:
            # Player position
            player_position = request.form.get('player_position')
            if player_position not in ['0', '1', '2', '3']:
                return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'})
            player_position = int(player_position)

            # Player hand cards
            player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
            if player_position == 0:
                if len(player_hand_cards) < 1 or len(player_hand_cards) > 33:
                    return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'})
            else:
                if len(player_hand_cards) < 1 or len(player_hand_cards) > 25:
                    return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'})

            # Number cards left, in the order landlord, down, front, up
            num_cards_left = [
                int(request.form.get('num_cards_left_landlord')),
                int(request.form.get('num_cards_left_landlord_down')),
                int(request.form.get('num_cards_left_landlord_front')),
                int(request.form.get('num_cards_left_landlord_up'))
            ]
            if num_cards_left[player_position] != len(player_hand_cards):
                return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
            # Every seat is range-checked against its own limit. The previous
            # code tested index 2 twice and never checked index 3.
            max_cards_left = [33, 25, 25, 25]
            if any(count < 0 or count > limit
                   for count, limit in zip(num_cards_left, max_cards_left)):
                return jsonify({'status': 5, 'message': 'the number of cards left not in range'})

            # Card play sequence
            if request.form.get('card_play_action_seq') == '':
                card_play_action_seq = []
            else:
                card_play_action_seq = [
                    ['', [RealCard2EnvCard[c] for c in cards]]
                    for cards in request.form.get('card_play_action_seq').split(',')
                ]

            # Other hand cards
            other_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')]
            if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
                return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})

            # Last moves
            last_moves = []
            for field in ['last_move_landlord', 'last_move_landlord_down', 'last_move_landlord_front', 'last_move_landlord_up']:
                last_moves.append([RealCard2EnvCard[c] for c in request.form.get(field)])

            # Played cards
            played_cards = {}
            for idx, field in enumerate(['played_cards_landlord', 'played_cards_landlord_down', 'played_cards_landlord_front', 'played_cards_landlord_up']):
                played_cards[idx_position[idx]] = [RealCard2EnvCard[c] for c in request.form.get(field)]

            # Bomb Num
            bomb_num = request.form.get('bomb_num')

            # InfoSet
            info_set = InfoSet()
            info_set.player_position = idx_position[player_position]
            info_set.player_hand_cards = player_hand_cards
            info_set.num_cards_left = num_cards_left
            info_set.card_play_action_seq = card_play_action_seq
            info_set.other_hand_cards = other_hand_cards
            info_set.last_moves = last_moves
            info_set.played_cards = played_cards
            info_set.bomb_num = [int(x) for x in bomb_num.split(',')]
            info_set.num_cards_left_dict['landlord'] = num_cards_left[0]
            info_set.num_cards_left_dict['landlord_down'] = num_cards_left[1]
            info_set.num_cards_left_dict['landlord_front'] = num_cards_left[2]
            info_set.num_cards_left_dict['landlord_up'] = num_cards_left[3]

            # Get rival move and legal_actions. The rival move is the most
            # recent non-pass among the last three actions; if all of them are
            # passes (or there is no history) the player leads freely. Slicing
            # keeps short histories from raising IndexError.
            rival_move = []
            for _, cards in reversed(card_play_action_seq[-3:]):
                if len(cards) != 0:
                    rival_move = cards
                    break
            info_set.rival_move = rival_move
            info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)

            # Prediction
            actions, actions_confidence = players[player_position].act(info_set, True)
            actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions]
            result = {}
            win_rates = {}
            for action, confidence in zip(actions, actions_confidence):
                # Here, we calculate the win rate: clamp to [-1, 1], then
                # rescale to [0, 1].
                win_rate = min(max(confidence, -1), 1)
                win_rates[action] = str(round((win_rate + 1) / 2, 4))
                result[action] = str(round(confidence, 6))

            ############## DEBUG ################
            if app.debug:
                print('--------------- DEBUG START --------------')
                command = 'curl --data "'
                parameters = []
                for key in request.form:
                    parameters.append(key+'='+request.form.get(key))
                    print(key+':', request.form.get(key))
                command += '&'.join(parameters)
                command += '" "http://127.0.0.1:5000/predict"'
                print('Command:', command)
                print('Rival Move:', rival_move)
                print('legal_actions:', info_set.legal_actions)
                print('Result:', result)
                print('--------------- DEBUG END --------------')
            ############## DEBUG ################
            return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
        except Exception:
            # Catch Exception rather than everything so that SystemExit and
            # KeyboardInterrupt still propagate.
            import traceback
            traceback.print_exc()
            return jsonify({'status': -1, 'message': 'unkown error'})
@app.route('/legal', methods=['POST'])
def legal():
    """Return the legal moves for a hand in response to a rival move.

    Reads ``player_hand_cards`` and ``rival_move`` from the form, each a
    string of card labels. On success the JSON response has ``status`` 0 and
    ``legal_action``, a comma-separated list of moves written as card
    strings. Any error is reported as ``status`` -1 after its traceback is
    printed.
    """
    if request.method == 'POST':
        try:
            player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
            rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')]
            legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
            legal_actions = ','.join(
                ''.join(EnvCard2RealCard[a] for a in action)
                for action in legal_actions
            )
            return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
        except Exception:
            # Catch Exception rather than everything so that SystemExit and
            # KeyboardInterrupt still propagate.
            import traceback
            traceback.print_exc()
            return jsonify({'status': -1, 'message': 'unkown error'})
class InfoSet(object):
    """Plain container for the game state handed to an agent's ``act``.

    Every field starts as ``None`` except the two dictionaries, which start
    empty; the caller fills the fields in after construction.
    """

    def __init__(self):
        # Seat of the acting player, its hand, and the per-seat card counts.
        self.player_position = self.player_hand_cards = self.num_cards_left = None
        # Fresh dictionaries per instance.
        self.num_cards_left_dict = dict()
        self.all_handcards = dict()
        # self.three_landlord_cards = None
        # Play history and the cards not held by the acting player.
        self.card_play_action_seq = self.other_hand_cards = None
        # Moves available now and the move they must answer.
        self.legal_actions = self.rival_move = None
        # Remaining per-seat records and the bomb counters.
        self.last_moves = self.played_cards = self.bomb_num = None
def _get_legal_card_play_actions(player_hand_cards, rival_move):
    """List the moves the hand may legally play in answer to ``rival_move``.

    Args:
        player_hand_cards: card codes held by the player.
        rival_move: card codes of the move to beat; empty means a pass, in
            which case every move the hand can form is allowed.

    Returns:
        A sorted list of moves without duplicates, each move itself a sorted
        list of card codes. An empty move (pass) is included whenever
        ``rival_move`` is not empty.
    """
    generator = MovesGener(player_hand_cards)
    rival_info = md.get_move_type(rival_move)
    rival_kind = rival_info['type']
    rival_len = rival_info.get('len', 1)

    # Bomb move types in ascending card count: position i is a (4 + i)-card bomb.
    bomb_kinds = (md.TYPE_4_BOMB, md.TYPE_4_BOMB5, md.TYPE_4_BOMB6,
                  md.TYPE_4_BOMB7, md.TYPE_4_BOMB8)

    def bombs_of_at_least(min_size):
        # Bombs of min_size..8 cards in ascending size, then the king bomb.
        stronger = []
        for size in range(min_size, 9):
            stronger = stronger + generator.gen_type_4_bomb(size)
        return stronger + generator.gen_type_5_king_bomb()

    candidates = []
    if rival_kind == md.TYPE_0_PASS:
        candidates = generator.gen_moves()
    elif rival_kind == md.TYPE_1_SINGLE:
        candidates = ms.filter_type_1_single(generator.gen_type_1_single(), rival_move)
    elif rival_kind == md.TYPE_2_PAIR:
        candidates = ms.filter_type_2_pair(generator.gen_type_2_pair(), rival_move)
    elif rival_kind == md.TYPE_3_TRIPLE:
        candidates = ms.filter_type_3_triple(generator.gen_type_3_triple(), rival_move)
    elif rival_kind in bomb_kinds:
        # Same-size bombs go through the filter; every larger bomb and the
        # king bomb are added unfiltered.
        bomb_size = 4 + bomb_kinds.index(rival_kind)
        candidates = ms.filter_type_4_bomb(generator.gen_type_4_bomb(bomb_size), rival_move)
        candidates += bombs_of_at_least(bomb_size + 1)
    elif rival_kind == md.TYPE_5_KING_BOMB:
        candidates = []
    elif rival_kind == md.TYPE_7_3_2:
        candidates = ms.filter_type_7_3_2(generator.gen_type_7_3_2(), rival_move)
    elif rival_kind == md.TYPE_8_SERIAL_SINGLE:
        candidates = ms.filter_type_8_serial_single(
            generator.gen_type_8_serial_single(repeat_num=rival_len), rival_move)
    elif rival_kind == md.TYPE_9_SERIAL_PAIR:
        candidates = ms.filter_type_9_serial_pair(
            generator.gen_type_9_serial_pair(repeat_num=rival_len), rival_move)
    elif rival_kind == md.TYPE_10_SERIAL_TRIPLE:
        candidates = ms.filter_type_10_serial_triple(
            generator.gen_type_10_serial_triple(repeat_num=rival_len), rival_move)
    elif rival_kind == md.TYPE_12_SERIAL_3_2:
        candidates = ms.filter_type_12_serial_3_2(
            generator.gen_type_12_serial_3_2(repeat_num=rival_len), rival_move)

    # Move types numbered below the 4-card bomb also admit every bomb as a reply.
    if rival_kind != md.TYPE_0_PASS and rival_kind < md.TYPE_4_BOMB:
        candidates = candidates + bombs_of_at_least(4)

    if len(rival_move) > 0:  # rival_move is not 'pass', so passing is allowed
        candidates = candidates + [[]]

    # Normalise each move, then sort so groupby can collapse equal moves.
    for move in candidates:
        move.sort()
    candidates.sort()
    return [move for move, _ in itertools.groupby(candidates)]
if __name__ == '__main__': if __name__ == '__main__':
import argparse import argparse
parser = argparse.ArgumentParser(description='DouZero evaluation backend') parser = argparse.ArgumentParser(description='DouZero evaluation backend')