"""Flask backend serving Dou Dizhu move predictions from pretrained agents.

Endpoints:
    POST /predict -- rank the current player's top moves with the agent for
                     that seat.
    POST /legal   -- list the legal moves for a hand against a rival move.

Three card notations are used in this file:
    * client ("real") characters -- jokers are 'X' and 'D';
    * RLCard ("env") characters  -- jokers are 'B' and 'R';
    * DouZero integers           -- used by the move generator/detector.
"""
import os
import traceback
from collections import Counter, OrderedDict
from heapq import nlargest

import numpy as np
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS

import rlcard
from utils import move_detector as md, move_selector as ms
from utils.move_generator import MovesGener

app = Flask(__name__)
CORS(app)

# RLCard Dou Dizhu environment; this module only reads its action-to-id table.
env = rlcard.make('doudizhu')

DouZeroCard2RLCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                      13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'}

RLCard2DouZeroCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
                      'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30}

EnvCard2RealCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7', '8': '8', '9': '9', 'T': 'T', 'J': 'J',
                    'Q': 'Q', 'K': 'K', 'A': 'A', '2': '2', 'B': 'X', 'R': 'D'}

RealCard2EnvCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7', '8': '8', '9': '9', 'T': 'T', 'J': 'J',
                    'Q': 'Q', 'K': 'K', 'A': 'A', '2': '2', 'X': 'B', 'D': 'R'}

pretrained_dir = 'pretrained/dmc_pretrained'
device = torch.device('cpu')

# One agent per seat: index 0 is the landlord, 1 landlord_down, 2 landlord_up
# (the same order used for num_cards_left / played_cards below).
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
players = []
for i in range(3):
    model_path = os.path.join(pretrained_dir, str(i) + '.pth')
    agent = torch.load(model_path, map_location=device)
    agent.set_device(device)
    players.append(agent)


def _real_to_env_cards(real_cards):
    """Translate client card characters into an RLCard card string."""
    return ''.join(RealCard2EnvCard[c] for c in real_cards)


def _env_to_real_cards(env_cards):
    """Translate RLCard card characters into a client card string."""
    return ''.join(EnvCard2RealCard[c] for c in env_cards)


def _env_to_douzero(env_cards):
    """Translate RLCard card characters into a list of DouZero integers."""
    return [RLCard2DouZeroCard[c] for c in env_cards]


def _douzero_to_env(douzero_cards):
    """Translate DouZero integers into an RLCard card string ('' for none)."""
    return ''.join(DouZeroCard2RLCard[c] for c in douzero_cards)


def _parse_action_seq(raw_seq):
    """Parse the comma-separated client move history into (seat, move) pairs.

    Seats are assigned round-robin starting from the landlord (seat 0); an
    empty item becomes 'pass'. An empty string yields an empty history.
    """
    if raw_seq == '':
        return []
    moves = [_real_to_env_cards(cards) or 'pass' for cards in raw_seq.split(',')]
    return [(i % 3, move) for i, move in enumerate(moves)]


def _get_last_move(trace):
    """Return the move the next player must answer, or '' if the history is empty.

    If the latest move is a pass, the one before it is returned (which may
    itself be 'pass'). A history consisting of a single pass raises IndexError.
    """
    if not trace:
        return ''
    if trace[-1][1] == 'pass':
        return trace[-2][1]
    return trace[-1][1]


@app.route('/predict', methods=['POST'])
def predict():
    """Rank the current player's top-3 moves with the agent for their seat.

    Form fields (cards in client notation):
        player_position: '0' (landlord), '1' (landlord_down) or '2' (landlord_up).
        player_hand_cards: the player's own hand.
        num_cards_left_landlord / num_cards_left_landlord_down /
        num_cards_left_landlord_up: integer card counts.
        three_landlord_cards: 0-3 cards.
        card_play_action_seq: comma-separated moves; an empty item is a pass.
        other_hand_cards: the other two players' cards combined (length must
            equal their card counts).
        played_cards_landlord / played_cards_landlord_down /
        played_cards_landlord_up: cards each seat has played.

    Returns:
        JSON with status 0 and ``result`` / ``win_rates`` keyed by move (''
        for pass) on success; a positive status and message on validation
        failure; status -1 on any unexpected error (which is also printed).
    """
    try:
        # Player position
        player_position = request.form.get('player_position')
        if player_position not in ['0', '1', '2']:
            return jsonify({'status': 1, 'message': 'player_position must be 0, 1, or 2'})
        player_position = int(player_position)

        # Player hand cards
        player_hand_cards = _real_to_env_cards(request.form.get('player_hand_cards'))
        if player_position == 0:
            if len(player_hand_cards) < 1 or len(player_hand_cards) > 20:
                return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-20'})
        else:
            if len(player_hand_cards) < 1 or len(player_hand_cards) > 17:
                return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-17'})

        # Number of cards left, ordered [landlord, landlord_down, landlord_up]
        num_cards_left = [int(request.form.get('num_cards_left_landlord')),
                          int(request.form.get('num_cards_left_landlord_down')),
                          int(request.form.get('num_cards_left_landlord_up'))]
        if num_cards_left[player_position] != len(player_hand_cards):
            return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
        if (min(num_cards_left) < 0 or num_cards_left[0] > 20
                or num_cards_left[1] > 17 or num_cards_left[2] > 17):
            return jsonify({'status': 5, 'message': 'the number of cards left not in range'})

        # Three landlord cards
        three_landlord_cards = _real_to_env_cards(request.form.get('three_landlord_cards'))
        if len(three_landlord_cards) > 3:
            return jsonify({'status': 6, 'message': 'the number of landlord cards should be 0-3'})

        # Card play sequence
        card_play_action_seq = _parse_action_seq(request.form.get('card_play_action_seq'))

        # Other hand cards
        other_hand_cards = _real_to_env_cards(request.form.get('other_hand_cards'))
        if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
            return jsonify({'status': 7,
                            'message': 'the number of the other hand cards do not align with the number of cards left'})

        # Played cards, ordered [landlord, landlord_down, landlord_up]
        played_cards = [_real_to_env_cards(request.form.get(field))
                        for field in ['played_cards_landlord', 'played_cards_landlord_down',
                                      'played_cards_landlord_up']]

        # RLCard state
        state = {
            'current_hand': player_hand_cards,
            'landlord': 0,
            'num_cards_left': num_cards_left,
            'others_hand': other_hand_cards,
            'played_cards': played_cards,
            'seen_cards': three_landlord_cards,
            'self': player_position,
            'trace': card_play_action_seq,
        }

        # Rival move ('' / [] means the player is free to lead) and legal actions
        rival_move = _get_last_move(card_play_action_seq)
        if rival_move == 'pass':
            rival_move = ''
        rival_move = _env_to_douzero(rival_move)
        legal_moves = _get_legal_card_play_actions(_env_to_douzero(player_hand_cards), rival_move)
        state['actions'] = [_douzero_to_env(move) or 'pass' for move in legal_moves]

        # Prediction
        state = _extract_state(state)
        _, info = players[player_position].eval_step(state)
        values = info['values']
        # The client spells pass as ''. Re-key without looping over the dict:
        # mutating a dict while iterating it raises RuntimeError on Python 3.8+.
        if 'pass' in values:
            values[''] = values.pop('pass')

        result = {}
        win_rates = {}
        for env_action in nlargest(3, values, key=values.get):
            confidence = values[env_action]
            real_action = _env_to_real_cards(env_action)
            # The raw value is clamped into [0, 1] to be shown as a win rate.
            win_rate = max(min(confidence, 1), 0)
            win_rates[real_action] = str(round(win_rate, 4))
            result[real_action] = str(round(confidence, 6))

        ############## DEBUG ################
        if app.debug:
            print('--------------- DEBUG START --------------')
            parameters = []
            for key in request.form:
                parameters.append(key + '=' + request.form.get(key))
                print(key + ':', request.form.get(key))
            command = 'curl --data "' + '&'.join(parameters) + '" "http://127.0.0.1:5000/predict"'
            print('Command:', command)
            print('Rival Move:', rival_move)
            print('legal_actions:', state['legal_actions'])
            print('Result:', result)
            print('--------------- DEBUG END --------------')
        ############## DEBUG ################

        return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
    except Exception:
        traceback.print_exc()
        return jsonify({'status': -1, 'message': 'unkown error'})


@app.route('/legal', methods=['POST'])
def legal():
    """Return every legal move for a hand against a rival move.

    Form fields (client notation): player_hand_cards, rival_move. An empty
    rival_move becomes [], which _get_legal_card_play_actions treats as a pass
    (the player is free to lead).

    Returns:
        JSON with status 0 and ``legal_action``: a comma-separated list of
        moves in client notation, pass rendered as an empty item; status -1
        on any unexpected error (which is also printed).
    """
    try:
        player_hand_cards = _env_to_douzero(_real_to_env_cards(request.form.get('player_hand_cards')))
        rival_move = _env_to_douzero(_real_to_env_cards(request.form.get('rival_move')))
        legal_moves = _get_legal_card_play_actions(player_hand_cards, rival_move)
        legal_actions = ','.join(_env_to_real_cards(_douzero_to_env(move)) for move in legal_moves)
        return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
    except Exception:
        traceback.print_exc()
        return jsonify({'status': -1, 'message': 'unkown error'})


def _extract_state(state):
    """Encode a raw state dict into the observation the agents consume.

    Args:
        state: dict with keys current_hand, others_hand, trace, self,
            played_cards, num_cards_left and actions (RLCard notation).

    Returns:
        OrderedDict with 'obs' (concatenated int8 feature vector),
        'legal_actions' (env action id -> 54-dim card array), 'raw_obs'
        (the input dict) and 'raw_legal_actions' (copy of state['actions']).
    """
    current_hand = _cards2array(state['current_hand'])
    others_hand = _cards2array(state['others_hand'])
    last_action = _cards2array(_get_last_move(state['trace']))
    last_9_actions = _action_seq2array(_process_action_seq(state['trace']))

    if state['self'] == 0:  # landlord
        landlord_up_played_cards = _cards2array(state['played_cards'][2])
        landlord_down_played_cards = _cards2array(state['played_cards'][1])
        landlord_up_num_cards_left = _get_one_hot_array(state['num_cards_left'][2], 17)
        landlord_down_num_cards_left = _get_one_hot_array(state['num_cards_left'][1], 17)
        obs = np.concatenate((current_hand,
                              others_hand,
                              last_action,
                              last_9_actions,
                              landlord_up_played_cards,
                              landlord_down_played_cards,
                              landlord_up_num_cards_left,
                              landlord_down_num_cards_left))
    else:
        landlord_played_cards = _cards2array(state['played_cards'][0])
        # NOTE(review): these reversed loops never break, so they keep the
        # EARLIEST landlord/teammate move, not the latest. This appears to mirror
        # RLCard's own encoding, which the pretrained models presumably saw in
        # training -- confirm before "fixing" it. If the landlord has no move in
        # the trace, last_landlord_action is unbound and this raises.
        for i, action in reversed(state['trace']):
            if i == 0:
                last_landlord_action = action
        last_landlord_action = _cards2array(last_landlord_action)
        landlord_num_cards_left = _get_one_hot_array(state['num_cards_left'][0], 20)

        teammate_id = 3 - state['self']  # seat 1 <-> seat 2
        teammate_played_cards = _cards2array(state['played_cards'][teammate_id])
        last_teammate_action = 'pass'
        for i, action in reversed(state['trace']):
            if i == teammate_id:
                last_teammate_action = action
        last_teammate_action = _cards2array(last_teammate_action)
        teammate_num_cards_left = _get_one_hot_array(state['num_cards_left'][teammate_id], 17)
        obs = np.concatenate((current_hand,
                              others_hand,
                              last_action,
                              last_9_actions,
                              landlord_played_cards,
                              teammate_played_cards,
                              last_landlord_action,
                              last_teammate_action,
                              landlord_num_cards_left,
                              teammate_num_cards_left))

    legal_actions = {env._ACTION_2_ID[action]: _cards2array(action) for action in state['actions']}
    extracted_state = OrderedDict({'obs': obs, 'legal_actions': legal_actions})
    extracted_state['raw_obs'] = state
    extracted_state['raw_legal_actions'] = list(state['actions'])
    return extracted_state


def _get_legal_card_play_actions(player_hand_cards, rival_move):
    """List the moves (sorted DouZero int lists) that can answer rival_move.

    An empty rival_move means a free lead: every generated move is legal and
    pass is not. Otherwise only same-type moves the selector lets through,
    plus bombs/king bomb where applicable, plus pass ([]) are returned.
    """
    mg = MovesGener(player_hand_cards)

    rival_type = md.get_move_type(rival_move)
    rival_move_type = rival_type['type']
    rival_move_len = rival_type.get('len', 1)
    moves = list()

    if rival_move_type == md.TYPE_0_PASS:
        moves = mg.gen_moves()

    elif rival_move_type == md.TYPE_1_SINGLE:
        all_moves = mg.gen_type_1_single()
        moves = ms.filter_type_1_single(all_moves, rival_move)

    elif rival_move_type == md.TYPE_2_PAIR:
        all_moves = mg.gen_type_2_pair()
        moves = ms.filter_type_2_pair(all_moves, rival_move)

    elif rival_move_type == md.TYPE_3_TRIPLE:
        all_moves = mg.gen_type_3_triple()
        moves = ms.filter_type_3_triple(all_moves, rival_move)

    elif rival_move_type == md.TYPE_4_BOMB:
        all_moves = mg.gen_type_4_bomb() + mg.gen_type_5_king_bomb()
        moves = ms.filter_type_4_bomb(all_moves, rival_move)

    elif rival_move_type == md.TYPE_5_KING_BOMB:
        moves = []

    elif rival_move_type == md.TYPE_6_3_1:
        all_moves = mg.gen_type_6_3_1()
        moves = ms.filter_type_6_3_1(all_moves, rival_move)

    elif rival_move_type == md.TYPE_7_3_2:
        all_moves = mg.gen_type_7_3_2()
        moves = ms.filter_type_7_3_2(all_moves, rival_move)

    elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
        all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
        moves = ms.filter_type_8_serial_single(all_moves, rival_move)

    elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
        all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
        moves = ms.filter_type_9_serial_pair(all_moves, rival_move)

    elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
        all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
        moves = ms.filter_type_10_serial_triple(all_moves, rival_move)

    elif rival_move_type == md.TYPE_11_SERIAL_3_1:
        all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len)
        moves = ms.filter_type_11_serial_3_1(all_moves, rival_move)

    elif rival_move_type == md.TYPE_12_SERIAL_3_2:
        all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
        moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)

    elif rival_move_type == md.TYPE_13_4_2:
        all_moves = mg.gen_type_13_4_2()
        moves = ms.filter_type_13_4_2(all_moves, rival_move)

    elif rival_move_type == md.TYPE_14_4_22:
        all_moves = mg.gen_type_14_4_22()
        moves = ms.filter_type_14_4_22(all_moves, rival_move)

    # Bombs can beat any non-bomb move (a free lead already includes them).
    if rival_move_type not in [md.TYPE_0_PASS, md.TYPE_4_BOMB, md.TYPE_5_KING_BOMB]:
        moves = moves + mg.gen_type_4_bomb() + mg.gen_type_5_king_bomb()

    if len(rival_move) != 0:  # rival_move is not 'pass'
        moves = moves + [[]]

    for m in moves:
        m.sort()

    return moves


# Column of each non-joker rank in the 4x13 card matrix.
Card2Column = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4, '8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9, 'K': 10, 'A': 11,
               '2': 12}

# Count of a rank (0-4) -> the column written into the card matrix.
NumOnes2Array = {0: np.array([0, 0, 0, 0]),
                 1: np.array([1, 0, 0, 0]),
                 2: np.array([1, 1, 0, 0]),
                 3: np.array([1, 1, 1, 0]),
                 4: np.array([1, 1, 1, 1])}


def _cards2array(cards):
    """Encode an RLCard card string as a 54-element int8 vector.

    The 4x13 rank-count matrix is flattened column by column ('F' order) and
    followed by two joker flags ('B', then 'R'). 'pass' and '' encode to zeros.
    """
    if cards == 'pass':
        return np.zeros(54, dtype=np.int8)

    matrix = np.zeros([4, 13], dtype=np.int8)
    jokers = np.zeros(2, dtype=np.int8)
    counter = Counter(cards)
    for card, num_times in counter.items():
        if card == 'B':
            jokers[0] = 1
        elif card == 'R':
            jokers[1] = 1
        else:
            matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
    return np.concatenate((matrix.flatten('F'), jokers))


def _get_one_hot_array(num_left_cards, max_num_cards):
    """One-hot encode a card count in 1..max_num_cards at index count - 1.

    NOTE(review): a count of 0 wraps to the last slot via negative indexing.
    """
    one_hot = np.zeros(max_num_cards, dtype=np.int8)
    one_hot[num_left_cards - 1] = 1
    return one_hot


def _action_seq2array(action_seq_list):
    """Encode a list of card strings as one flat int8 array (54 per move)."""
    action_seq_array = np.zeros((len(action_seq_list), 54), np.int8)
    for row, cards in enumerate(action_seq_list):
        action_seq_array[row, :] = _cards2array(cards)
    return action_seq_array.flatten()


def _process_action_seq(sequence, length=9):
    """Return the last `length` moves of a (seat, move) trace, left-padded with ''."""
    moves = [action[1] for action in sequence[-length:]]
    return [''] * (length - len(moves)) + moves


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='DouZero backend')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    app.run(debug=args.debug)