From 7b21149add926ceba3b730b484a4e0439b36554c Mon Sep 17 00:00:00 2001 From: ZaneYork Date: Sun, 26 Dec 2021 20:11:05 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=99=8B=E5=8D=87=E8=A6=81?= =?UTF-8?q?=E6=B1=82=EF=BC=8C=E6=B7=BB=E5=8A=A0=E8=AF=84=E4=BC=B0=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/models.py | 3 - douzero/evaluation/deep_agent.py | 44 ++++-- douzero/evaluation/simulation.py | 2 +- douzero/server/battle.py | 8 +- evaluate_server.py | 261 +++++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 23 deletions(-) diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index c710ef9..89a2793 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -12,9 +12,6 @@ from onnxruntime.datasets import get_example from torch import nn import torch.nn.functional as F -def to_numpy(tensor): - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() - class LandlordLstmModel(nn.Module): def __init__(self): super().__init__() diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index 32ef4c4..18d6ccf 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -6,7 +6,7 @@ from onnxruntime.datasets import get_example from douzero.env.env import get_obs -def _load_model(position, model_path, model_type, use_legacy, use_lite): +def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False): from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite model = None if model_type == "general": @@ -34,9 +34,9 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite): if torch.cuda.is_available(): model.cuda() model.eval() - onnx_params = model.get_onnx_params(torch.device('cpu')) model_path = model_path + '.onnx' - if not os.path.exists(model_path): + if use_onnx and not os.path.exists(model_path): + onnx_params = model.get_onnx_params(torch.device('cpu')) torch.onnx.export( model, onnx_params['args'], @@ -53,28 +53,42 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite): class DeepAgent: - def __init__(self, position, model_path): + def __init__(self, position, model_path, use_onnx=False): self.use_legacy = True if "legacy" in model_path else False self.lite_model = True if "lite" in model_path else False self.model_type = "general" if "resnet" in model_path else "old" - self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model) - self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider']) + self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, use_onnx=use_onnx) + if use_onnx: + self.onnx_model = onnxruntime.InferenceSession(get_example(os.path.abspath(model_path + '.onnx')), providers=['CPUExecutionProvider']) + else: + self.onnx_model = None self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'} - def act(self, infoset): - if len(infoset.legal_actions) == 1: + def act(self, infoset, with_confidence = False): + if not with_confidence and len(infoset.legal_actions) == 1: return infoset.legal_actions[0] obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model) - # z_batch = torch.from_numpy(obs['z_batch']).float() - # x_batch = torch.from_numpy(obs['x_batch']).float() - # if torch.cuda.is_available(): - # z_batch, x_batch = z_batch.cuda(), x_batch.cuda() - # y_pred = self.model.forward(z_batch, x_batch)['values'] - # y_pred = y_pred.detach().cpu().numpy() - y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0] + if self.onnx_model is None: + z_batch = torch.from_numpy(obs['z_batch']).float() + x_batch = torch.from_numpy(obs['x_batch']).float() + if torch.cuda.is_available(): + z_batch, x_batch = z_batch.cuda(), x_batch.cuda() + y_pred = self.model.forward(z_batch, x_batch)['values'] + y_pred = y_pred.detach().cpu().numpy() + else: + y_pred = self.onnx_model.run(None, {'z_batch': obs['z_batch'], 'x_batch': obs['x_batch']})[0] + + if with_confidence: + y_pred = y_pred.flatten() + size = min(3, len(y_pred)) + best_action_index = np.argpartition(y_pred, -size)[-size:] + best_action_confidence = y_pred[best_action_index] + best_action = [infoset.legal_actions[index] for index in best_action_index] + + return best_action, best_action_confidence best_action_index = np.argmax(y_pred, axis=0)[0] best_action = infoset.legal_actions[best_action_index] diff --git a/douzero/evaluation/simulation.py b/douzero/evaluation/simulation.py index e2a5c71..d5a6551 100644 --- a/douzero/evaluation/simulation.py +++ b/douzero/evaluation/simulation.py @@ -19,7 +19,7 @@ def load_card_play_models(card_play_model_path_dict): players[position] = RandomAgent() else: from .deep_agent import DeepAgent - players[position] = DeepAgent(position, card_play_model_path_dict[position]) + players[position] = DeepAgent(position, card_play_model_path_dict[position], use_onnx=True) return players def mp_simulate(card_play_data_list, card_play_model_path_dict, q, output, title): diff --git a/douzero/server/battle.py b/douzero/server/battle.py index a474a3d..05a2ef3 100644 --- a/douzero/server/battle.py +++ b/douzero/server/battle.py @@ -45,16 +45,16 @@ def battle_logic(baseline : Baseline, battle : Battle): challenge_success = False if battle.challenger_position == 'landlord': - if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.2: + if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.15: landlord_wp, farmer_wp, landlord_adp, farmer_adp = \ _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp) - if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.2: + if baseline.landlord_wp == 0 or landlord_wp / float(baseline.landlord_wp) > 1.15: challenge_success = True else: - if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.2: + if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.05: landlord_wp, farmer_wp, landlord_adp, farmer_adp = \ _second_eval(landlord_wp, farmer_wp, landlord_adp, farmer_adp) - if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.2: + if baseline.farmer_wp == 0 or farmer_wp / float(baseline.farmer_wp) > 1.05: challenge_success = True if challenge_success: challenger_baseline['rank'] = baseline.rank + 1 diff --git a/evaluate_server.py b/evaluate_server.py index 52ea422..c2ae440 100644 --- a/evaluate_server.py +++ b/evaluate_server.py @@ -1,3 +1,5 @@ +import itertools + from douzero.server.orm import Model, Battle, Baseline from douzero.server.battle import tick from flask import Flask, jsonify, request @@ -5,6 +7,10 @@ from flask_cors import CORS from datetime import datetime from concurrent.futures import ThreadPoolExecutor +from douzero.env.move_generator import MovesGener +from douzero.env import move_detector as md, move_selector as ms +from douzero.evaluation.deep_agent import DeepAgent + app = Flask(__name__) CORS(app) @@ -12,9 +18,29 @@ Model.create_table() Battle.create_table() Baseline.create_table() positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] +idx_position = {0: 'landlord', 1: 'landlord_down', 2: 'landlord_front', 3:'landlord_up'} threadpool = ThreadPoolExecutor(1) + +EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', + 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', + 13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'} + +RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, + '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, + 'K': 13, 'A': 14, '2': 17, 'X': 20, 'D': 30} + +baselines = Baseline.select().order_by(Baseline.rank.desc()).limit(1) +if len(baselines) >= 1: + baseline = baselines[0] + players = [ + DeepAgent('landlord', str(baseline.landlord_path), use_onnx=True), + DeepAgent('landlord_down', str(baseline.landlord_down_path), use_onnx=True), + DeepAgent('landlord_front', str(baseline.landlord_front_path), use_onnx=True), + DeepAgent('landlord_up', str(baseline.landlord_up_path), use_onnx=True) + ] + @app.route('/upload', methods=['POST']) def upload(): type = request.form.get('type') @@ -71,6 +97,241 @@ def metrics(): metrics[baseline.rank] = baseline_metric return jsonify({'status': 0, 'message': 'success', 'result': metrics}) + +@app.route('/predict', methods=['POST']) +def predict(): + if request.method == 'POST': + try: + # Player postion + player_position = request.form.get('player_position') + if player_position not in ['0', '1', '2', '3']: + return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'}) + player_position = int(player_position) + + # Player hand cards + player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')] + if player_position == 0: + if len(player_hand_cards) < 1 or len(player_hand_cards) > 33: + return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'}) + else: + if len(player_hand_cards) < 1 or len(player_hand_cards) > 25: + return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'}) + + # Number cards left + num_cards_left = [ + int(request.form.get('num_cards_left_landlord')), + int(request.form.get('num_cards_left_landlord_down')), + int(request.form.get('num_cards_left_landlord_front')), + int(request.form.get('num_cards_left_landlord_up')) + ] + if num_cards_left[player_position] != len(player_hand_cards): + return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'}) + if num_cards_left[0] < 0 or num_cards_left[1] < 0 or num_cards_left[2] < 0 or num_cards_left[3] < 0 \ + or num_cards_left[0] > 33 or num_cards_left[1] > 25 or num_cards_left[2] > 25 or num_cards_left[2] > 25: + return jsonify({'status': 5, 'message': 'the number of cards left not in range'}) + + # Card play sequence + if request.form.get('card_play_action_seq') == '': + card_play_action_seq = [] + else: + card_play_action_seq = [['', [RealCard2EnvCard[c] for c in cards]] for cards in request.form.get('card_play_action_seq').split(',')] + + # Other hand cards + other_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')] + if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]: + return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'}) + + # Last moves + last_moves = [] + for field in ['last_move_landlord', 'last_move_landlord_down', 'last_move_landlord_front', 'last_move_landlord_up']: + last_moves.append([RealCard2EnvCard[c] for c in request.form.get(field)]) + + # Played cards + played_cards = {} + for idx, field in enumerate(['played_cards_landlord', 'played_cards_landlord_down', 'played_cards_landlord_front', 'played_cards_landlord_up']): + played_cards[idx_position[idx]] = [RealCard2EnvCard[c] for c in request.form.get(field)] + + # Bomb Num + bomb_num = request.form.get('bomb_num') + + # InfoSet + info_set = InfoSet() + info_set.player_position = idx_position[player_position] + info_set.player_hand_cards = player_hand_cards + info_set.num_cards_left = num_cards_left + info_set.card_play_action_seq = card_play_action_seq + info_set.other_hand_cards = other_hand_cards + info_set.last_moves = last_moves + info_set.played_cards = played_cards + info_set.bomb_num = [int(x) for x in str.split(bomb_num, ',')] + info_set.num_cards_left_dict['landlord'] = num_cards_left[0] + info_set.num_cards_left_dict['landlord_down'] = num_cards_left[1] + info_set.num_cards_left_dict['landlord_front'] = num_cards_left[2] + info_set.num_cards_left_dict['landlord_up'] = num_cards_left[3] + + # Get rival move and legal_actions + rival_move = [] + if len(card_play_action_seq) != 0: + if len(card_play_action_seq[-1][1]) == 0: + if len(card_play_action_seq[-2][1]) == 0: + rival_move = card_play_action_seq[-3][1] + else: + rival_move = card_play_action_seq[-2][1] + else: + rival_move = card_play_action_seq[-1][1] + info_set.rival_move = rival_move + info_set.legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move) + + # Prediction + actions, actions_confidence = players[player_position].act(info_set, True) + actions = [''.join([EnvCard2RealCard[a] for a in action]) for action in actions] + result = {} + win_rates = {} + for i in range(len(actions)): + # Here, we calculate the win rate + win_rate = max(actions_confidence[i], -1) + win_rate = min(win_rate, 1) + win_rates[actions[i]] = str(round((win_rate + 1) / 2, 4)) + result[actions[i]] = str(round(actions_confidence[i], 6)) + + ############## DEBUG ################ + if app.debug: + print('--------------- DEBUG START --------------') + command = 'curl --data "' + parameters = [] + for key in request.form: + parameters.append(key+'='+request.form.get(key)) + print(key+':', request.form.get(key)) + command += '&'.join(parameters) + command += '" "http://127.0.0.1:5000/predict"' + print('Command:', command) + print('Rival Move:', rival_move) + print('legal_actions:', info_set.legal_actions) + print('Result:', result) + print('--------------- DEBUG END --------------') + ############## DEBUG ################ + return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates}) + except: + import traceback + traceback.print_exc() + return jsonify({'status': -1, 'message': 'unkown error'}) + +@app.route('/legal', methods=['POST']) +def legal(): + if request.method == 'POST': + try: + player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')] + rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')] + legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move) + legal_actions = ','.join([''.join([EnvCard2RealCard[a] for a in action]) for action in legal_actions]) + return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions}) + except: + import traceback + traceback.print_exc() + return jsonify({'status': -1, 'message': 'unkown error'}) + +class InfoSet(object): + + def __init__(self): + self.player_position = None + self.player_hand_cards = None + self.num_cards_left = None + self.num_cards_left_dict = {} + self.all_handcards = {} + # self.three_landlord_cards = None + self.card_play_action_seq = None + self.other_hand_cards = None + self.legal_actions = None + self.rival_move = None + self.last_moves = None + self.played_cards = None + self.bomb_num = None + +def _get_legal_card_play_actions(player_hand_cards, rival_move): + mg = MovesGener(player_hand_cards) + + rival_type = md.get_move_type(rival_move) + rival_move_type = rival_type['type'] + rival_move_len = rival_type.get('len', 1) + moves = list() + + if rival_move_type == md.TYPE_0_PASS: + moves = mg.gen_moves() + + elif rival_move_type == md.TYPE_1_SINGLE: + all_moves = mg.gen_type_1_single() + moves = ms.filter_type_1_single(all_moves, rival_move) + + elif rival_move_type == md.TYPE_2_PAIR: + all_moves = mg.gen_type_2_pair() + moves = ms.filter_type_2_pair(all_moves, rival_move) + + elif rival_move_type == md.TYPE_3_TRIPLE: + all_moves = mg.gen_type_3_triple() + moves = ms.filter_type_3_triple(all_moves, rival_move) + + elif rival_move_type == md.TYPE_4_BOMB: + all_moves = mg.gen_type_4_bomb(4) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB5: + all_moves = mg.gen_type_4_bomb(5) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB6: + all_moves = mg.gen_type_4_bomb(6) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB7: + all_moves = mg.gen_type_4_bomb(7) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB8: + all_moves = mg.gen_type_4_bomb(8) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_5_KING_BOMB: + moves = [] + + elif rival_move_type == md.TYPE_7_3_2: + all_moves = mg.gen_type_7_3_2() + moves = ms.filter_type_7_3_2(all_moves, rival_move) + + elif rival_move_type == md.TYPE_8_SERIAL_SINGLE: + all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len) + moves = ms.filter_type_8_serial_single(all_moves, rival_move) + + elif rival_move_type == md.TYPE_9_SERIAL_PAIR: + all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len) + moves = ms.filter_type_9_serial_pair(all_moves, rival_move) + + elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE: + all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len) + moves = ms.filter_type_10_serial_triple(all_moves, rival_move) + + elif rival_move_type == md.TYPE_12_SERIAL_3_2: + all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len) + moves = ms.filter_type_12_serial_3_2(all_moves, rival_move) + + if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB: + moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + if len(rival_move) != 0: # rival_move is not 'pass' + moves = moves + [[]] + + for m in moves: + m.sort() + + moves.sort() + moves = list(move for move,_ in itertools.groupby(moves)) + + return moves + if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='DouZero evaluation backend')