调整晋升要求,添加评估接口
This commit is contained in:
parent
c4c008d034
commit
7b21149add
|
@ -12,9 +12,6 @@ from onnxruntime.datasets import get_example
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
def to_numpy(tensor):
|
||||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
||||
|
||||
class LandlordLstmModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -6,7 +6,7 @@ from onnxruntime.datasets import get_example
|
|||
|
||||
from douzero.env.env import get_obs
|
||||
|
||||
def _load_model(position, model_path, model_type, use_legacy, use_lite):
|
||||
def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
|
||||
from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
|
||||
model = None
|
||||
if model_type == "general":
|
||||
|
@ -34,9 +34,9 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite):
|
|||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
model.eval()
|
||||
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||
model_path = model_path + '.onnx'
|
||||
if not os.path.exists(model_path):
|
||||
if use_onnx and not os.path.exists(model_path):
|
||||
onnx_params = model.get_onnx_params(torch.device('cpu'))
|
||||
torch.onnx.export(
|
||||
model,
|
||||
onnx_params['args'],
|
||||
|
@ -53,28 +53,42 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite):
|
|||
|
||||
class DeepAgent:
|
||||
|
||||
def __init__(self, position, model_path):
|
||||
def __init__(self, position, model_path, use_onnx=False):
|
||||
self.use_legacy = True if "legacy" in model_path else False
|
||||
self.lite_model = True if "lite" in model_path else False
|
||||
self.model_type = "general" if "resnet" in model_path else "old"
|
||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model)
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
||||
self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, use_onnx=use_onnx)
|
||||
if use_onnx:
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.onnx_model = None
|
||||
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||
def act(self, infoset):
|
||||
if len(infoset.legal_actions) == 1:
|
||||
def act(self, infoset, with_confidence = False):
|
||||
if not with_confidence and len(infoset.legal_actions) == 1:
|
||||
return infoset.legal_actions[0]
|
||||
|
||||
obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
|
||||
|
||||
# z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
# x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
# if torch.cuda.is_available():
|
||||
# z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
||||
# y_pred = self.model.forward(z_batch, x_batch)['values']
|
||||
# y_pred = y_pred.detach().cpu().numpy()
|
||||
y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]
|
||||
if self.onnx_model is None:
|
||||
z_batch = torch.from_numpy(obs['z_batch']).float()
|
||||
x_batch = torch.from_numpy(obs['x_batch']).float()
|
||||
if torch.cuda.is_available():
|
||||
z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
|
||||
y_pred = self.model.forward(z_batch, x_batch)['values']
|
||||
y_pred = y_pred.detach().cpu().numpy()
|
||||
else:
|
||||
y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0]
|
||||
|
||||
if with_confidence:
|
||||
y_pred = y_pred.flatten()
|
||||
size = min(3, len(y_pred))
|
||||
best_action_index = np.argpartition(y_pred, -size)[-size:]
|
||||
best_action_confidence = y_pred[best_action_index]
|
||||
best_action = [infoset.legal_actions[index] for index in best_action_index]
|
||||
|
||||
return best_action, best_action_confidence
|
||||
|
||||
best_action_index = np.argmax(y_pred, axis=0)[0]
|
||||
best_action = infoset.legal_actions[best_action_index]
|
||||
|
|
|
@ -19,7 +19,7 @@ def load_card_play_models(card_play_model_path_dict):
|
|||
players[position] = RandomAgent()
|
||||
else:
|
||||
from .deep_agent import DeepAgent
|
||||
players[position] = DeepAgent(position, card_play_model_path_dict[position])
|
||||
players[position] = DeepAgent(position, card_play_model_path_dict[position], use_onnx=True)
|
||||
return players
|
||||
|
||||
def mp_simulate(card_play_data_list, card_play_model_path_dict, q, output, title):
|
||||
|
|
|
@ -45,16 +45,16 @@ def battle_logic(baseline : Baseline, battle : Battle):
|
|||
|
||||
challenge_success = False
|
||||
if battle.challenger_position == 'landlord':
|
||||
if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.2:
|
||||
if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.15:
|
||||
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
|
||||
_second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
|
||||
if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.2:
|
||||
if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.15:
|
||||
challenge_success = True
|
||||
else:
|
||||
if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.2:
|
||||
if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.05:
|
||||
landlord_wp, farmer_wp, landlord_adp, farmer_adp = \
|
||||
_second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
|
||||
if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.2:
|
||||
if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.05:
|
||||
challenge_success = True
|
||||
if challenge_success:
|
||||
challenger_baseline['rank'] = baseline.rank + 1
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import itertools
|
||||
|
||||
from douzero.server.orm import Model, Battle, Baseline
|
||||
from douzero.server.battle import tick
|
||||
from flask import Flask, jsonify, request
|
||||
|
@ -5,6 +7,10 @@ from flask_cors import CORS
|
|||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from douzero.env.move_generator import MovesGener
|
||||
from douzero.env import move_detector as md, move_selector as ms
|
||||
from douzero.evaluation.deep_agent import DeepAgent
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
|
@ -12,9 +18,29 @@ Model.create_table()
|
|||
Battle.create_table()
|
||||
Baseline.create_table()
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'}
|
||||
|
||||
threadpool = ThreadPoolExecutor(1)
|
||||
|
||||
|
||||
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
|
||||
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
|
||||
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
|
||||
|
||||
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
|
||||
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
|
||||
'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30}
|
||||
|
||||
baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1)
|
||||
if len(baselines) >= 1:
|
||||
baseline = baselines[0]
|
||||
players = [
|
||||
DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True),
|
||||
DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True),
|
||||
DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True),
|
||||
DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True)
|
||||
]
|
||||
|
||||
@app.route('/upload', methods=['POST'])
|
||||
def upload():
|
||||
type = request.form.get('type')
|
||||
|
@ -71,6 +97,241 @@ def metrics():
|
|||
metrics[baseline.rank] = baseline_metric
|
||||
return jsonify({'status': 0, 'message': 'success', 'result': metrics})
|
||||
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict():
|
||||
if request.method == 'POST':
|
||||
try:
|
||||
# Player postion
|
||||
player_position = request.form.get('player_position')
|
||||
if player_position not in ['0', '1', '2', '3']:
|
||||
return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'})
|
||||
player_position = int(player_position)
|
||||
|
||||
# Player hand cards
|
||||
player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
|
||||
if player_position == 0:
|
||||
if len(player_hand_cards) < 1 or len(player_hand_cards) > 33:
|
||||
return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'})
|
||||
else:
|
||||
if len(player_hand_cards) < 1 or len(player_hand_cards) > 25:
|
||||
return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'})
|
||||
|
||||
# Number cards left
|
||||
num_cards_left = [
|
||||
int(request.form.get('num_cards_left_landlord')),
|
||||
int(request.form.get('num_cards_left_landlord_down')),
|
||||
int(request.form.get('num_cards_left_landlord_front')),
|
||||
int(request.form.get('num_cards_left_landlord_up'))
|
||||
]
|
||||
if num_cards_left[player_position] != len(player_hand_cards):
|
||||
return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
|
||||
if num_cards_left[0] < 0 or num_cards_left[1] < 0 or num_cards_left[2] < 0 or num_cards_left[3] < 0 \
|
||||
or num_cards_left[0] > 33 or num_cards_left[1] > 25 or num_cards_left[2] > 25 or num_cards_left[2] > 25:
|
||||
return jsonify({'status': 5, 'message': 'the number of cards left not in range'})
|
||||
|
||||
# Card play sequence
|
||||
if request.form.get('card_play_action_seq') == '':
|
||||
card_play_action_seq = []
|
||||
else:
|
||||
card_play_action_seq = [['', [RealCard2EnvCard[c] for c in cards]] for cards in request.form.get('card_play_action_seq').split(',')]
|
||||
|
||||
# Other hand cards
|
||||
other_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')]
|
||||
if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
|
||||
return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})
|
||||
|
||||
# Last moves
|
||||
last_moves = []
|
||||
for field in ['last_move_landlord', 'last_move_landlord_down', 'last_move_landlord_front', 'last_move_landlord_up']:
|
||||
last_moves.append([RealCard2EnvCard[c] for c in request.form.get(field)])
|
||||
|
||||
# Played cards
|
||||
played_cards = {}
|
||||
for idx, field in enumerate(['played_cards_landlord', 'played_cards_landlord_down', 'played_cards_landlord_front', 'played_cards_landlord_up']):
|
||||
played_cards[idx_position[idx]] = [RealCard2EnvCard[c] for c in request.form.get(field)]
|
||||
|
||||
# Bomb Num
|
||||
bomb_num = request.form.get('bomb_num')
|
||||
|
||||
# InfoSet
|
||||
info_set = InfoSet()
|
||||
info_set.player_position = idx_position[player_position]
|
||||
info_set.player_hand_cards = player_hand_cards
|
||||
info_set.num_cards_left = num_cards_left
|
||||
info_set.card_play_action_seq = card_play_action_seq
|
||||
info_set.other_hand_cards = other_hand_cards
|
||||
info_set.last_moves = last_moves
|
||||
info_set.played_cards = played_cards
|
||||
info_set.bomb_num = [int(x) for x in str.split(bomb_num, ',')]
|
||||
info_set.num_cards_left_dict['landlord'] = num_cards_left[0]
|
||||
info_set.num_cards_left_dict['landlord_down'] = num_cards_left[1]
|
||||
info_set.num_cards_left_dict['landlord_front'] = num_cards_left[2]
|
||||
info_set.num_cards_left_dict['landlord_up'] = num_cards_left[3]
|
||||
|
||||
# Get rival move and legal_actions
|
||||
rival_move = []
|
||||
if len(card_play_action_seq) != 0:
|
||||
if len(card_play_action_seq[-1][1]) == 0:
|
||||
if len(card_play_action_seq[-2][1]) == 0:
|
||||
rival_move = card_play_action_seq[-3][1]
|
||||
else:
|
||||
rival_move = card_play_action_seq[-2][1]
|
||||
else:
|
||||
rival_move = card_play_action_seq[-1][1]
|
||||
info_set.rival_move = rival_move
|
||||
info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
|
||||
|
||||
# Prediction
|
||||
actions, actions_confidence = players[player_position].act(info_set, True)
|
||||
actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions]
|
||||
result = {}
|
||||
win_rates = {}
|
||||
for i in range(len(actions)):
|
||||
# Here, we calculate the win rate
|
||||
win_rate = max(actions_confidence[i], -1)
|
||||
win_rate = min(win_rate, 1)
|
||||
win_rates[actions[i]] = str(round((win_rate + 1) / 2, 4))
|
||||
result[actions[i]] = str(round(actions_confidence[i], 6))
|
||||
|
||||
############## DEBUG ################
|
||||
if app.debug:
|
||||
print('--------------- DEBUG START --------------')
|
||||
command = 'curl --data "'
|
||||
parameters = []
|
||||
for key in request.form:
|
||||
parameters.append(key+'='+request.form.get(key))
|
||||
print(key+':', request.form.get(key))
|
||||
command += '&'.join(parameters)
|
||||
command += '" "http://127.0.0.1:5000/predict"'
|
||||
print('Command:', command)
|
||||
print('Rival Move:', rival_move)
|
||||
print('legal_actions:', info_set.legal_actions)
|
||||
print('Result:', result)
|
||||
print('--------------- DEBUG END --------------')
|
||||
############## DEBUG ################
|
||||
return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
|
||||
except:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return jsonify({'status': -1, 'message': 'unkown error'})
|
||||
|
||||
@app.route('/legal', methods=['POST'])
|
||||
def legal():
|
||||
if request.method == 'POST':
|
||||
try:
|
||||
player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
|
||||
rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')]
|
||||
legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
|
||||
legal_actions = ','.join([''.join([EnvCard2RealCard[a] for a in action]) for action in legal_actions])
|
||||
return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
|
||||
except:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return jsonify({'status': -1, 'message': 'unkown error'})
|
||||
|
||||
class InfoSet(object):
|
||||
|
||||
def __init__(self):
|
||||
self.player_position = None
|
||||
self.player_hand_cards = None
|
||||
self.num_cards_left = None
|
||||
self.num_cards_left_dict = {}
|
||||
self.all_handcards = {}
|
||||
# self.three_landlord_cards = None
|
||||
self.card_play_action_seq = None
|
||||
self.other_hand_cards = None
|
||||
self.legal_actions = None
|
||||
self.rival_move = None
|
||||
self.last_moves = None
|
||||
self.played_cards = None
|
||||
self.bomb_num = None
|
||||
|
||||
def _get_legal_card_play_actions(player_hand_cards, rival_move):
|
||||
mg = MovesGener(player_hand_cards)
|
||||
|
||||
rival_type = md.get_move_type(rival_move)
|
||||
rival_move_type = rival_type['type']
|
||||
rival_move_len = rival_type.get('len', 1)
|
||||
moves = list()
|
||||
|
||||
if rival_move_type == md.TYPE_0_PASS:
|
||||
moves = mg.gen_moves()
|
||||
|
||||
elif rival_move_type == md.TYPE_1_SINGLE:
|
||||
all_moves = mg.gen_type_1_single()
|
||||
moves = ms.filter_type_1_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_2_PAIR:
|
||||
all_moves = mg.gen_type_2_pair()
|
||||
moves = ms.filter_type_2_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_3_TRIPLE:
|
||||
all_moves = mg.gen_type_3_triple()
|
||||
moves = ms.filter_type_3_triple(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB:
|
||||
all_moves = mg.gen_type_4_bomb(4)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB5:
|
||||
all_moves = mg.gen_type_4_bomb(5)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB6:
|
||||
all_moves = mg.gen_type_4_bomb(6)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB7:
|
||||
all_moves = mg.gen_type_4_bomb(7)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB8:
|
||||
all_moves = mg.gen_type_4_bomb(8)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_5_KING_BOMB:
|
||||
moves = []
|
||||
|
||||
elif rival_move_type == md.TYPE_7_3_2:
|
||||
all_moves = mg.gen_type_7_3_2()
|
||||
moves = ms.filter_type_7_3_2(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
|
||||
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
|
||||
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
|
||||
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
|
||||
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
|
||||
|
||||
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
|
||||
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
if len(rival_move) != 0: # rival_move is not 'pass'
|
||||
moves = moves + [[]]
|
||||
|
||||
for m in moves:
|
||||
m.sort()
|
||||
|
||||
moves.sort()
|
||||
moves = list(move for move,_ in itertools.groupby(moves))
|
||||
|
||||
return moves
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='DouZero evaluation backend')
|
||||
|
|
Loading…
Reference in New Issue