"""Flask backend exposing a pretrained DouDizhu agent.

Routes:
    POST /predict  -- top-3 suggested moves for a described game state.
    POST /legal    -- legal moves for a hand against a rival move.

Card notations used in this file:
    "real" cards   -- what HTTP clients send/receive ('X' / 'D' for jokers).
    "env" cards    -- RLCard-style characters ('B' / 'R' for jokers).
    DouZero cards  -- integers consumed by the ``utils`` move helpers.
"""
import itertools
import os
import traceback
from collections import Counter, OrderedDict
from heapq import nlargest

import numpy as np
import rlcard
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS

from utils import move_detector as md
from utils import move_selector as ms
from utils.move_generator import MovesGener

app = Flask(__name__)
CORS(app)

# Only used for its action-string -> action-id table (see _extract_state).
env = rlcard.make('doudizhu')

DouZeroCard2RLCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
                      8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
                      13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'}

RLCard2DouZeroCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
                      '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
                      'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30}

EnvCard2RealCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7',
                    '8': '8', '9': '9', 'T': 'T', 'J': 'J', 'Q': 'Q',
                    'K': 'K', 'A': 'A', '2': '2', 'B': 'X', 'R': 'D'}

RealCard2EnvCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7',
                    '8': '8', '9': '9', 'T': 'T', 'J': 'J', 'Q': 'Q',
                    'K': 'K', 'A': 'A', '2': '2', 'X': 'B', 'D': 'R'}

pretrained_dir = 'pretrained/dmc_pretrained'
device = torch.device('cpu')

# One agent per seat, loaded from "<pretrained_dir>/<seat>.pth".
players = []
for i in range(4):
    model_path = os.path.join(pretrained_dir, str(i)+'.pth')
    agent = torch.load(model_path, map_location=device)
    agent.set_device(device)
    players.append(agent)


def _latest_rival_move(card_play_action_seq, lookback=3):
    """Return the newest non-pass move among the last ``lookback`` turns.

    Returns 'pass' when the sequence is empty or all inspected turns are
    passes.  Unlike raw ``[-2]`` / ``[-3]`` indexing, this never raises on
    sequences shorter than ``lookback``.
    """
    for _, cards in reversed(card_play_action_seq[-lookback:]):
        if cards != 'pass':
            return cards
    return 'pass'


@app.route('/predict', methods=['POST'])
def predict():
    """Suggest up to three moves for the game state posted as form data.

    Form fields read: player_position, player_hand_cards,
    num_cards_left_landlord / _landlord_down / _landlord_front /
    _landlord_up, card_play_action_seq (comma separated, empty item =
    pass), other_hand_cards, played_cards_landlord / _landlord_down /
    _landlord_front / _landlord_up.

    Returns a JSON object with ``status`` (0 on success, 1-7 for validation
    failures, -1 for any unexpected exception), ``message`` and, on
    success, ``result`` (move -> rounded model value) and ``win_rates``
    (move -> value clipped to [0, 1]).
    """
    if request.method == 'POST':
        try:
            # Player position
            player_position = request.form.get('player_position')
            if player_position not in ['0', '1', '2', '3']:
                return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'})
            player_position = int(player_position)

            # Player hand cards
            player_hand_cards = ''.join(
                [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')])
            if player_position == 0:
                if len(player_hand_cards) < 1 or len(player_hand_cards) > 33:
                    return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'})
            else:
                if len(player_hand_cards) < 1 or len(player_hand_cards) > 25:
                    return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'})

            # Number cards left
            num_cards_left = [
                int(request.form.get('num_cards_left_landlord')),
                int(request.form.get('num_cards_left_landlord_down')),
                int(request.form.get('num_cards_left_landlord_front')),
                int(request.form.get('num_cards_left_landlord_up'))
            ]
            if num_cards_left[player_position] != len(player_hand_cards):
                return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
            # Seat 0 (landlord) may hold up to 33 cards, every other seat 25.
            max_cards_left = (33, 25, 25, 25)
            if any(count < 0 or count > cap
                   for count, cap in zip(num_cards_left, max_cards_left)):
                return jsonify({'status': 5, 'message': 'the number of cards left not in range'})

            # Three landlord cards
            # three_landlord_cards = ''.join(
            #     [RealCard2EnvCard[c] for c in request.form.get('three_landlord_cards')])
            # if len(three_landlord_cards) < 0 or len(three_landlord_cards) > 3:
            #     return jsonify({'status': 6, 'message': 'the number of landlord cards should be 0-3'})

            # Card play sequence: the k-th item is attributed to seat k % 4.
            raw_seq = request.form.get('card_play_action_seq')
            card_play_action_seq = []
            if raw_seq != '':
                for turn, cards in enumerate(raw_seq.split(',')):
                    move = ''.join([RealCard2EnvCard[c] for c in cards])
                    card_play_action_seq.append((turn % 4, move or 'pass'))

            # Other hand cards
            other_hand_cards = ''.join(
                [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')])
            if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
                return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})

            # Played cards
            played_cards = []
            for field in ['played_cards_landlord', 'played_cards_landlord_down',
                          'played_cards_landlord_front', 'played_cards_landlord_up']:
                played_cards.append(
                    ''.join([RealCard2EnvCard[c] for c in request.form.get(field)]))

            # RLCard state
            state = {}
            state['current_hand'] = player_hand_cards
            state['landlord'] = 0
            state['num_cards_left'] = num_cards_left
            state['others_hand'] = other_hand_cards
            state['played_cards'] = played_cards
            # state['seen_cards'] = three_landlord_cards
            state['self'] = player_position
            state['trace'] = card_play_action_seq

            # Get rival move and legal_actions
            rival_move = _latest_rival_move(card_play_action_seq)
            if rival_move == 'pass':
                rival_move = ''
            rival_move = [RLCard2DouZeroCard[c] for c in rival_move]
            legal_moves = _get_legal_card_play_actions(
                [RLCard2DouZeroCard[c] for c in player_hand_cards], rival_move)
            # The empty move is spelled 'pass' in env notation.
            state['actions'] = [
                ''.join([DouZeroCard2RLCard[c] for c in a]) or 'pass'
                for a in legal_moves]

            # Prediction
            state = _extract_state(state)
            _, info = players[player_position].eval_step(state)
            values = info['values']
            # Clients expect the pass move as the empty string.
            if 'pass' in values:
                values[''] = values.pop('pass')

            actions = nlargest(3, values, key=values.get)
            actions_confidence = [values.get(a) for a in actions]
            actions = [''.join([EnvCard2RealCard[c] for c in a])
                       for a in actions]
            result = {}
            win_rates = {}
            for move, confidence in zip(actions, actions_confidence):
                # Here, we calculate the win rate (value clipped to [0, 1])
                win_rate = max(min(confidence, 1), 0)
                win_rates[move] = str(round(win_rate, 4))
                result[move] = str(round(confidence, 6))

            ############## DEBUG ################
            if app.debug:
                print('--------------- DEBUG START --------------')
                command = 'curl --data "'
                parameters = []
                for key in request.form:
                    parameters.append(key+'='+request.form.get(key))
                    print(key+':', request.form.get(key))
                command += '&'.join(parameters)
                command += '" "http://127.0.0.1:5000/predict"'
                print('Command:', command)
                print('Rival Move:', rival_move)
                print('legal_actions:', state['legal_actions'])
                print('Result:', result)
                print('--------------- DEBUG END --------------')
            ############## DEBUG ################
            return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
        except Exception:
            # Boundary handler: log and report a generic failure.  Narrower
            # than a bare ``except`` so SystemExit/KeyboardInterrupt escape.
            traceback.print_exc()
            return jsonify({'status': -1, 'message': 'unkown error'})


@app.route('/legal', methods=['POST'])
def legal():
    """Return the legal moves for a hand against ``rival_move``.

    Form fields read: player_hand_cards, rival_move (empty string means
    the rival passed).  On success the JSON ``legal_action`` field is a
    comma-separated string of moves in real-card notation; the pass move is
    the empty item.  Any exception yields status -1.
    """
    if request.method == 'POST':
        try:
            player_hand_cards = [RLCard2DouZeroCard[RealCard2EnvCard[c]]
                                 for c in request.form.get('player_hand_cards')]
            # An empty form value naturally becomes the empty move [].
            rival_move = [RLCard2DouZeroCard[RealCard2EnvCard[c]]
                          for c in request.form.get('rival_move')]
            legal_actions = _get_legal_card_play_actions(
                player_hand_cards, rival_move)
            legal_actions = ','.join(
                ''.join([EnvCard2RealCard[DouZeroCard2RLCard[c]] for c in a])
                for a in legal_actions)
            return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
        except Exception:
            traceback.print_exc()
            return jsonify({'status': -1, 'message': 'unkown error'})


def _extract_state(state):
    """Encode a raw state dict into the structure passed to ``eval_step``.

    Reads ``current_hand``, ``others_hand``, ``trace``, ``played_cards``,
    ``num_cards_left``, ``self`` and ``actions`` from ``state``.

    Returns an OrderedDict with keys ``obs`` (1-D feature vector),
    ``legal_actions`` (action id -> card encoding), ``raw_obs`` (the input
    dict) and ``raw_legal_actions`` (copy of ``state['actions']``).

    NOTE(review): the feature layout must match what the pretrained agents
    were trained on, which is not visible from this file -- confirm the
    resulting ``obs`` length against the model before relying on it.
    """
    trace = state['trace']
    current_hand = _cards2array(state['current_hand'])
    others_hand = _cards2array(state['others_hand'])

    last_action = ''
    if trace:
        last_action = trace[-1][1]
        # Look one turn further back when the newest turn is a pass
        # (guarded so a single-entry trace cannot raise IndexError).
        if last_action == 'pass' and len(trace) >= 2:
            last_action = trace[-2][1]
    last_action = _cards2array(last_action)

    # _action_seq2array yields a (rows, 108) matrix; np.concatenate needs
    # every piece to be 1-D, so flatten the history before joining.
    last_9_actions = _action_seq2array(_process_action_seq(trace)).flatten()

    if state['self'] == 0:  # landlord
        # NOTE(review): only seats 2 and 1 are encoded here; seat 3 of the
        # 4-player layout is not used -- confirm this is intended.
        landlord_up_played_cards = _cards2array(state['played_cards'][2])
        landlord_down_played_cards = _cards2array(state['played_cards'][1])
        # NOTE(review): /predict accepts up to 25 cards for these seats but
        # the one-hot width is 17; larger counts raise IndexError.
        landlord_up_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][2], 17)
        landlord_down_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][1], 17)
        obs = np.concatenate((current_hand,
                              others_hand,
                              last_action,
                              last_9_actions,
                              landlord_up_played_cards,
                              landlord_down_played_cards,
                              landlord_up_num_cards_left,
                              landlord_down_num_cards_left))
    else:
        landlord_played_cards = _cards2array(state['played_cards'][0])
        # Default avoids an unbound name when seat 0 is absent from trace.
        last_landlord_action = 'pass'
        # NOTE(review): there is no ``break``, so after scanning newest to
        # oldest the value kept is seat 0's EARLIEST action in the trace,
        # despite the variable name.  Left unchanged because the pretrained
        # model may depend on this encoding -- confirm.
        for i, action in reversed(trace):
            if i == 0:
                last_landlord_action = action
        last_landlord_action = _cards2array(last_landlord_action)
        # NOTE(review): /predict accepts up to 33 landlord cards but the
        # one-hot width is 20; larger counts raise IndexError.
        landlord_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][0], 20)

        # NOTE(review): for seat 3 this evaluates to 0 (the landlord seat);
        # verify the teammate mapping for the 4-player layout.
        teammate_id = 3 - state['self']
        teammate_played_cards = _cards2array(
            state['played_cards'][teammate_id])
        last_teammate_action = 'pass'
        # Same no-``break`` scan as above (keeps the earliest match).
        for i, action in reversed(trace):
            if i == teammate_id:
                last_teammate_action = action
        last_teammate_action = _cards2array(last_teammate_action)
        teammate_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][teammate_id], 17)
        obs = np.concatenate((current_hand,
                              others_hand,
                              last_action,
                              last_9_actions,
                              landlord_played_cards,
                              teammate_played_cards,
                              last_landlord_action,
                              last_teammate_action,
                              landlord_num_cards_left,
                              teammate_num_cards_left))

    # NOTE(review): raises KeyError for any action string missing from the
    # rlcard env's table -- presumably possible for 5+ card bombs; verify.
    legal_actions = {env._ACTION_2_ID[action]: _cards2array(action)
                     for action in state['actions']}
    extracted_state = OrderedDict({'obs': obs, 'legal_actions': legal_actions})
    extracted_state['raw_obs'] = state
    extracted_state['raw_legal_actions'] = list(state['actions'])
    return extracted_state


def _bombs_from(mg, min_size):
    """Return all bombs of ``min_size``..8 cards plus the king bomb."""
    bombs = []
    for size in range(min_size, 9):
        bombs += mg.gen_type_4_bomb(size)
    return bombs + mg.gen_type_5_king_bomb()


def _get_legal_card_play_actions(player_hand_cards, rival_move):
    """List the moves ``player_hand_cards`` may play against ``rival_move``.

    Both arguments are lists of DouZero integer cards; an empty
    ``rival_move`` means the player leads.  The result is a sorted,
    de-duplicated list of moves (each a sorted list of ints).  When
    ``rival_move`` is non-empty the empty move ``[]`` (pass) is included.
    """
    mg = MovesGener(player_hand_cards)

    rival_type = md.get_move_type(rival_move)
    rival_move_type = rival_type['type']
    rival_move_len = rival_type.get('len', 1)
    moves = list()

    if rival_move_type == md.TYPE_0_PASS:
        moves = mg.gen_moves()

    elif rival_move_type == md.TYPE_1_SINGLE:
        all_moves = mg.gen_type_1_single()
        moves = ms.filter_type_1_single(all_moves, rival_move)

    elif rival_move_type == md.TYPE_2_PAIR:
        all_moves = mg.gen_type_2_pair()
        moves = ms.filter_type_2_pair(all_moves, rival_move)

    elif rival_move_type == md.TYPE_3_TRIPLE:
        all_moves = mg.gen_type_3_triple()
        moves = ms.filter_type_3_triple(all_moves, rival_move)

    # A bomb is answered by a higher bomb of the same size, any bomb with
    # more cards, or the king bomb.
    elif rival_move_type == md.TYPE_4_BOMB:
        all_moves = mg.gen_type_4_bomb(4)
        moves = ms.filter_type_4_bomb(all_moves, rival_move)
        moves += _bombs_from(mg, 5)

    elif rival_move_type == md.TYPE_4_BOMB5:
        all_moves = mg.gen_type_4_bomb(5)
        moves = ms.filter_type_4_bomb(all_moves, rival_move)
        moves += _bombs_from(mg, 6)

    elif rival_move_type == md.TYPE_4_BOMB6:
        all_moves = mg.gen_type_4_bomb(6)
        moves = ms.filter_type_4_bomb(all_moves, rival_move)
        moves += _bombs_from(mg, 7)

    elif rival_move_type == md.TYPE_4_BOMB7:
        all_moves = mg.gen_type_4_bomb(7)
        moves = ms.filter_type_4_bomb(all_moves, rival_move)
        moves += _bombs_from(mg, 8)

    elif rival_move_type == md.TYPE_4_BOMB8:
        all_moves = mg.gen_type_4_bomb(8)
        moves = ms.filter_type_4_bomb(all_moves, rival_move)
        moves += mg.gen_type_5_king_bomb()

    elif rival_move_type == md.TYPE_5_KING_BOMB:
        moves = []

    elif rival_move_type == md.TYPE_7_3_2:
        all_moves = mg.gen_type_7_3_2()
        moves = ms.filter_type_7_3_2(all_moves, rival_move)

    elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
        all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
        moves = ms.filter_type_8_serial_single(all_moves, rival_move)

    elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
        all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
        moves = ms.filter_type_9_serial_pair(all_moves, rival_move)

    elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
        all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
        moves = ms.filter_type_10_serial_triple(all_moves, rival_move)

    elif rival_move_type == md.TYPE_12_SERIAL_3_2:
        all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
        moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)

    # NOTE(review): this relies on the numeric ordering of the md.TYPE_*
    # constants; whether the 3+2 / serial types compare below TYPE_4_BOMB
    # is not visible from this file -- confirm bombs are offered against
    # every non-bomb move type as intended.
    if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
        moves = moves + _bombs_from(mg, 4)

    if len(rival_move) != 0:  # rival_move is not 'pass'
        moves = moves + [[]]

    # Canonicalise each move, then sort so groupby can drop duplicates.
    for m in moves:
        m.sort()

    moves.sort()
    moves = [move for move, _ in itertools.groupby(moves)]

    return moves


# Column of each non-joker rank in the card matrix built by _cards2array.
Card2Column = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4, '8': 5, '9': 6, 'T': 7,
               'J': 8, 'Q': 9, 'K': 10, 'A': 11, '2': 12}

# Thermometer encoding of "how many copies" (0-8) as a length-8 vector.
NumOnes2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
                 1: np.array([1, 0, 0, 0, 0, 0, 0, 0]),
                 2: np.array([1, 1, 0, 0, 0, 0, 0, 0]),
                 3: np.array([1, 1, 1, 0, 0, 0, 0, 0]),
                 4: np.array([1, 1, 1, 1, 0, 0, 0, 0]),
                 5: np.array([1, 1, 1, 1, 1, 0, 0, 0]),
                 6: np.array([1, 1, 1, 1, 1, 1, 0, 0]),
                 7: np.array([1, 1, 1, 1, 1, 1, 1, 0]),
                 8: np.array([1, 1, 1, 1, 1, 1, 1, 1])}


def _cards2array(cards):
    """Encode a card string (env notation) as a length-108 int8 vector.

    'pass' and the empty string both encode to all zeros.  The first 104
    entries are 13 ranks x 8 copy slots (rank-major); the last 4 are joker
    flags: B present, two B's, R present, two R's.
    """
    if cards == 'pass':
        return np.zeros(108, dtype=np.int8)

    # 8 rows: NumOnes2Array vectors have 8 entries and 8 * 13 + 4 == 108,
    # the width of the 'pass' vector above.  (A 4-row matrix cannot accept
    # the length-8 column assignment below.)
    matrix = np.zeros([8, 13], dtype=np.int8)
    jokers = np.zeros(4, dtype=np.int8)
    counter = Counter(cards)
    for card, num_times in counter.items():
        if card == 'B':
            jokers[0] = 1
            if num_times == 2:
                jokers[1] = 1
        elif card == 'R':
            jokers[2] = 1
            if num_times == 2:
                jokers[3] = 1
        else:
            matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
    return np.concatenate((matrix.flatten('F'), jokers))


def _get_one_hot_array(num_left_cards, max_num_cards):
    """Return a length-``max_num_cards`` int8 one-hot for a card count.

    Slot ``num_left_cards - 1`` is set.  A count of 0 therefore indexes -1
    and sets the LAST slot; a count above ``max_num_cards`` raises
    IndexError.
    """
    one_hot = np.zeros(max_num_cards, dtype=np.int8)
    one_hot[num_left_cards - 1] = 1

    return one_hot


def _action_seq2array(action_seq_list):
    """Encode a list of past actions as a (len, 108) float array.

    Each element may be a card string (as produced by
    ``_process_action_seq``) or a ``(position, cards)`` pair.  Padding
    elements -- the empty string or an empty sequence -- leave their row at
    the default value -1; every other element is encoded with
    ``_cards2array``.
    """
    # Default Value -1 for not using area
    action_seq_array = np.ones((len(action_seq_list), 108)) * -1
    for row, entry in enumerate(action_seq_list):
        if len(entry) == 0:
            continue  # padding slot
        # Strings are the cards themselves; pairs carry them at index 1.
        # (Indexing a string with [1] would pick its second character.)
        cards = entry if isinstance(entry, str) else entry[1]
        action_seq_array[row, :108] = _cards2array(cards)
    return action_seq_array


def _process_action_seq(sequence, length=9):
    """Return the cards of the last ``length`` turns, left-padded with ''.

    ``sequence`` holds ``(position, cards)`` pairs; only the cards are kept.
    """
    recent = [turn[1] for turn in sequence[-length:]]
    padding = [''] * (length - len(recent))
    return padding + recent


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='DouZero backend')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    app.run(debug=args.debug)