diff --git a/pve_server/run_dmc.py b/pve_server/run_dmc.py index 2f73727..d3e32e9 100644 --- a/pve_server/run_dmc.py +++ b/pve_server/run_dmc.py @@ -1,35 +1,36 @@ -import os +from utils.move_generator import MovesGener +from utils import move_selector as ms +from utils import move_detector as md +import rlcard import itertools - -import torch -import numpy as np -from heapq import nlargest +import os from collections import Counter, OrderedDict +from heapq import nlargest + +import numpy as np +import torch from flask import Flask, jsonify, request from flask_cors import CORS app = Flask(__name__) CORS(app) -from utils.move_generator import MovesGener -from utils import move_detector as md, move_selector as ms -import rlcard env = rlcard.make('doudizhu') DouZeroCard2RLCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7', - 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', - 13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'} + 8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q', + 13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'} RLCard2DouZeroCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7, - '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, - 'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30} + '8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12, + 'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30} -EnvCard2RealCard = {'3': '3', '4':'4', '5': '5', '6': '6', '7': '7', +EnvCard2RealCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7', '8': '8', '9': '9', 'T': 'T', 'J': 'J', 'Q': 'Q', 'K': 'K', 'A': 'A', '2': '2', 'B': 'X', 'R': 'D'} -RealCard2EnvCard = {'3': '3', '4':'4', '5': '5', '6': '6', '7': '7', +RealCard2EnvCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7', '8': '8', '9': '9', 'T': 'T', 'J': 'J', 'Q': 'Q', 'K': 'K', 'A': 'A', '2': '2', 'X': 'B', 'D': 'R'} @@ -41,7 +42,8 @@ for i in range(3): agent = torch.load(model_path, map_location=device) agent.set_device(device) players.append(agent) - + + @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': @@ -53,7 +55,8 @@ def predict(): player_position = int(player_position) # Player hand cards - player_hand_cards = ''.join([RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]) + player_hand_cards = ''.join( + [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]) if player_position == 0: if len(player_hand_cards) < 1 or len(player_hand_cards) > 20: return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-20'}) @@ -62,14 +65,16 @@ def predict(): return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-17'}) # Number cards left - num_cards_left = [int(request.form.get('num_cards_left_landlord')), int(request.form.get('num_cards_left_landlord_down')), int(request.form.get('num_cards_left_landlord_up'))] + num_cards_left = [int(request.form.get('num_cards_left_landlord')), int(request.form.get( + 'num_cards_left_landlord_down')), int(request.form.get('num_cards_left_landlord_up'))] if num_cards_left[player_position] != len(player_hand_cards): return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'}) if num_cards_left[0] < 0 or num_cards_left[1] < 0 or num_cards_left[2] < 0 or num_cards_left[0] > 20 or num_cards_left[1] > 17 or num_cards_left[2] > 17: return jsonify({'status': 5, 'message': 'the number of cards left not in range'}) # Three landlord cards - three_landlord_cards = ''.join([RealCard2EnvCard[c] for c in request.form.get('three_landlord_cards')]) + three_landlord_cards = ''.join( + [RealCard2EnvCard[c] for c in request.form.get('three_landlord_cards')]) if len(three_landlord_cards) < 0 or len(three_landlord_cards) > 3: return jsonify({'status': 6, 'message': 'the number of landlord cards should be 0-3'}) @@ -77,23 +82,26 @@ def predict(): if request.form.get('card_play_action_seq') == '': card_play_action_seq = [] else: - tmp_seq = [''.join([RealCard2EnvCard[c] for c in cards]) for cards in request.form.get('card_play_action_seq').split(',')] + tmp_seq = [''.join([RealCard2EnvCard[c] for c in cards]) + for cards in request.form.get('card_play_action_seq').split(',')] for i in range(len(tmp_seq)): if tmp_seq[i] == '': tmp_seq[i] = 'pass' card_play_action_seq = [] for i in range(len(tmp_seq)): - card_play_action_seq.append((i%3, tmp_seq[i])) + card_play_action_seq.append((i % 3, tmp_seq[i])) # Other hand cards - other_hand_cards = ''.join([RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')]) + other_hand_cards = ''.join( + [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')]) if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]: return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'}) # Played cards played_cards = [] for field in ['played_cards_landlord', 'played_cards_landlord_down', 'played_cards_landlord_up']: - played_cards.append(''.join([RealCard2EnvCard[c] for c in request.form.get(field)])) + played_cards.append( + ''.join([RealCard2EnvCard[c] for c in request.form.get(field)])) # RLCard state state = {} @@ -116,13 +124,14 @@ def predict(): if rival_move == 'pass': rival_move = '' rival_move = [RLCard2DouZeroCard[c] for c in rival_move] - state['actions'] = _get_legal_card_play_actions([RLCard2DouZeroCard[c] for c in player_hand_cards], rival_move) - state['actions'] = [''.join([DouZeroCard2RLCard[c] for c in a]) for a in state['actions']] + state['actions'] = _get_legal_card_play_actions( + [RLCard2DouZeroCard[c] for c in player_hand_cards], rival_move) + state['actions'] = [ + ''.join([DouZeroCard2RLCard[c] for c in a]) for a in state['actions']] for i in range(len(state['actions'])): if state['actions'][i] == '': state['actions'][i] = 'pass' - # Prediction state = _extract_state(state) action, info = players[player_position].eval_step(state) @@ -132,10 +141,13 @@ def predict(): if i == 'pass': info['values'][''] = info['values']['pass'] del info['values']['pass'] + break actions = nlargest(3, info['values'], key=info['values'].get) - actions_confidence = [info['values'].get(action) for action in actions] - actions = [''.join([EnvCard2RealCard[c] for c in action]) for action in actions] + actions_confidence = [info['values'].get( + action) for action in actions] + actions = [''.join([EnvCard2RealCard[c] for c in action]) + for action in actions] result = {} win_rates = {} for i in range(len(actions)): @@ -167,28 +179,36 @@ def predict(): traceback.print_exc() return jsonify({'status': -1, 'message': 'unkown error'}) + @app.route('/legal', methods=['POST']) def legal(): if request.method == 'POST': try: - player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')] - rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')] + player_hand_cards = [RealCard2EnvCard[c] + for c in request.form.get('player_hand_cards')] + rival_move = [RealCard2EnvCard[c] + for c in request.form.get('rival_move')] if rival_move == '': rival_move = 'pass' - player_hand_cards = [RLCard2DouZeroCard[c] for c in player_hand_cards] - rival_move = [RLCard2DouZeroCard[c] for c in rival_move] - legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move) - legal_actions = [''.join([DouZeroCard2RLCard[c] for c in a]) for a in legal_actions] + player_hand_cards = [RLCard2DouZeroCard[c] + for c in player_hand_cards] + rival_move = [RLCard2DouZeroCard[c] for c in rival_move] + legal_actions = _get_legal_card_play_actions( + player_hand_cards, rival_move) + legal_actions = [''.join([DouZeroCard2RLCard[c] + for c in a]) for a in legal_actions] for i in range(len(legal_actions)): if legal_actions[i] == 'pass': legal_actions[i] = '' - legal_actions = ','.join([''.join([EnvCard2RealCard[c] for c in action]) for action in legal_actions]) + legal_actions = ','.join( + [''.join([EnvCard2RealCard[c] for c in action]) for action in legal_actions]) return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions}) except: import traceback traceback.print_exc() return jsonify({'status': -1, 'message': 'unkown error'}) + def _extract_state(state): current_hand = _cards2array(state['current_hand']) others_hand = _cards2array(state['others_hand']) @@ -203,11 +223,13 @@ def _extract_state(state): last_9_actions = _action_seq2array(_process_action_seq(state['trace'])) - if state['self'] == 0: # landlord + if state['self'] == 0: # landlord landlord_up_played_cards = _cards2array(state['played_cards'][2]) landlord_down_played_cards = _cards2array(state['played_cards'][1]) - landlord_up_num_cards_left = _get_one_hot_array(state['num_cards_left'][2], 17) - landlord_down_num_cards_left = _get_one_hot_array(state['num_cards_left'][1], 17) + landlord_up_num_cards_left = _get_one_hot_array( + state['num_cards_left'][2], 17) + landlord_down_num_cards_left = _get_one_hot_array( + state['num_cards_left'][1], 17) obs = np.concatenate((current_hand, others_hand, last_action, @@ -222,16 +244,19 @@ def _extract_state(state): if i == 0: last_landlord_action = action last_landlord_action = _cards2array(last_landlord_action) - landlord_num_cards_left = _get_one_hot_array(state['num_cards_left'][0], 20) + landlord_num_cards_left = _get_one_hot_array( + state['num_cards_left'][0], 20) teammate_id = 3 - state['self'] - teammate_played_cards = _cards2array(state['played_cards'][teammate_id]) + teammate_played_cards = _cards2array( + state['played_cards'][teammate_id]) last_teammate_action = 'pass' for i, action in reversed(state['trace']): if i == teammate_id: last_teammate_action = action last_teammate_action = _cards2array(last_teammate_action) - teammate_num_cards_left = _get_one_hot_array(state['num_cards_left'][teammate_id], 17) + teammate_num_cards_left = _get_one_hot_array( + state['num_cards_left'][teammate_id], 17) obs = np.concatenate((current_hand, others_hand, last_action, @@ -243,12 +268,14 @@ def _extract_state(state): landlord_num_cards_left, teammate_num_cards_left)) - legal_actions = {env._ACTION_2_ID[action]: _cards2array(action) for action in state['actions']} + legal_actions = {env._ACTION_2_ID[action]: _cards2array( + action) for action in state['actions']} extracted_state = OrderedDict({'obs': obs, 'legal_actions': legal_actions}) extracted_state['raw_obs'] = state extracted_state['raw_legal_actions'] = [a for a in state['actions']] return extracted_state + def _get_legal_card_play_actions(player_hand_cards, rival_move): mg = MovesGener(player_hand_cards) @@ -326,10 +353,11 @@ def _get_legal_card_play_actions(player_hand_cards, rival_move): m.sort() moves.sort() - moves = list(move for move,_ in itertools.groupby(moves)) + moves = list(move for move, _ in itertools.groupby(moves)) return moves + Card2Column = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4, '8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9, 'K': 10, 'A': 11, '2': 12} @@ -339,6 +367,7 @@ NumOnes2Array = {0: np.array([0, 0, 0, 0]), 3: np.array([1, 1, 1, 0]), 4: np.array([1, 1, 1, 1])} + def _cards2array(cards): if cards == 'pass': return np.zeros(54, dtype=np.int8) @@ -355,12 +384,14 @@ def _cards2array(cards): matrix[:, Card2Column[card]] = NumOnes2Array[num_times] return np.concatenate((matrix.flatten('F'), jokers)) + def _get_one_hot_array(num_left_cards, max_num_cards): one_hot = np.zeros(max_num_cards, dtype=np.int8) one_hot[num_left_cards - 1] = 1 return one_hot + def _action_seq2array(action_seq_list): action_seq_array = np.zeros((len(action_seq_list), 54), np.int8) for row, cards in enumerate(action_seq_list): @@ -368,6 +399,7 @@ def _action_seq2array(action_seq_list): action_seq_array = action_seq_array.flatten() return action_seq_array + def _process_action_seq(sequence, length=9): sequence = [action[1] for action in sequence[-length:]] if len(sequence) < length: @@ -376,6 +408,7 @@ def _process_action_seq(sequence, length=9): sequence = empty_sequence return sequence + if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='DouZero backend') diff --git a/src/assets/index.scss b/src/assets/index.scss index 5e1c736..58ec769 100644 --- a/src/assets/index.scss +++ b/src/assets/index.scss @@ -26,8 +26,9 @@ code { } .citation { + font-family: 'Rockwell', monospace, PingFangSC-Regular, sans-serif; margin-top: 20px; - padding: 5px; + padding: 6px; -webkit-user-select: text; -ms-user-select: text; @@ -40,6 +41,7 @@ code { } pre { + margin-top: 5px; padding: 10px; border-radius: 5px; color: #64686d; @@ -48,6 +50,8 @@ code { } } + + #upload-model-note { a { color: #3f51b5; diff --git a/src/components/GameBoard/DoudizhuGameBoard.js b/src/components/GameBoard/DoudizhuGameBoard.js index 7658d9c..7c68b9b 100644 --- a/src/components/GameBoard/DoudizhuGameBoard.js +++ b/src/components/GameBoard/DoudizhuGameBoard.js @@ -117,7 +117,9 @@ class DoudizhuGameBoard extends React.Component { return (
{`@article{zha2019rlcard, title={RLCard: A Toolkit for Reinforcement Learning in Card Games},