optimize ux; fix run dmc

This commit is contained in:
Songyi Huang 2021-06-04 23:46:43 -07:00
parent 4facaaceb6
commit 7ab2d3fa44
6 changed files with 138 additions and 66 deletions

View File

@ -1,35 +1,36 @@
import os
from utils.move_generator import MovesGener
from utils import move_selector as ms
from utils import move_detector as md
import rlcard
import itertools
import torch
import numpy as np
from heapq import nlargest
import os
from collections import Counter, OrderedDict
from heapq import nlargest
import numpy as np
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
app = Flask(__name__)
CORS(app)
from utils.move_generator import MovesGener
from utils import move_detector as md, move_selector as ms
import rlcard
env = rlcard.make('doudizhu')
DouZeroCard2RLCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'}
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'}
RLCard2DouZeroCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30}
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30}
EnvCard2RealCard = {'3': '3', '4':'4', '5': '5', '6': '6', '7': '7',
EnvCard2RealCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7',
'8': '8', '9': '9', 'T': 'T', 'J': 'J', 'Q': 'Q',
'K': 'K', 'A': 'A', '2': '2', 'B': 'X', 'R': 'D'}
RealCard2EnvCard = {'3': '3', '4':'4', '5': '5', '6': '6', '7': '7',
RealCard2EnvCard = {'3': '3', '4': '4', '5': '5', '6': '6', '7': '7',
'8': '8', '9': '9', 'T': 'T', 'J': 'J', 'Q': 'Q',
'K': 'K', 'A': 'A', '2': '2', 'X': 'B', 'D': 'R'}
@ -41,7 +42,8 @@ for i in range(3):
agent = torch.load(model_path, map_location=device)
agent.set_device(device)
players.append(agent)
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
@ -53,7 +55,8 @@ def predict():
player_position = int(player_position)
# Player hand cards
player_hand_cards = ''.join([RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')])
player_hand_cards = ''.join(
[RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')])
if player_position == 0:
if len(player_hand_cards) < 1 or len(player_hand_cards) > 20:
return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-20'})
@ -62,14 +65,16 @@ def predict():
return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-17'})
# Number cards left
num_cards_left = [int(request.form.get('num_cards_left_landlord')), int(request.form.get('num_cards_left_landlord_down')), int(request.form.get('num_cards_left_landlord_up'))]
num_cards_left = [int(request.form.get('num_cards_left_landlord')), int(request.form.get(
'num_cards_left_landlord_down')), int(request.form.get('num_cards_left_landlord_up'))]
if num_cards_left[player_position] != len(player_hand_cards):
return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
if num_cards_left[0] < 0 or num_cards_left[1] < 0 or num_cards_left[2] < 0 or num_cards_left[0] > 20 or num_cards_left[1] > 17 or num_cards_left[2] > 17:
return jsonify({'status': 5, 'message': 'the number of cards left not in range'})
# Three landlord cards
three_landlord_cards = ''.join([RealCard2EnvCard[c] for c in request.form.get('three_landlord_cards')])
three_landlord_cards = ''.join(
[RealCard2EnvCard[c] for c in request.form.get('three_landlord_cards')])
if len(three_landlord_cards) < 0 or len(three_landlord_cards) > 3:
return jsonify({'status': 6, 'message': 'the number of landlord cards should be 0-3'})
@ -77,23 +82,26 @@ def predict():
if request.form.get('card_play_action_seq') == '':
card_play_action_seq = []
else:
tmp_seq = [''.join([RealCard2EnvCard[c] for c in cards]) for cards in request.form.get('card_play_action_seq').split(',')]
tmp_seq = [''.join([RealCard2EnvCard[c] for c in cards])
for cards in request.form.get('card_play_action_seq').split(',')]
for i in range(len(tmp_seq)):
if tmp_seq[i] == '':
tmp_seq[i] = 'pass'
card_play_action_seq = []
for i in range(len(tmp_seq)):
card_play_action_seq.append((i%3, tmp_seq[i]))
card_play_action_seq.append((i % 3, tmp_seq[i]))
# Other hand cards
other_hand_cards = ''.join([RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')])
other_hand_cards = ''.join(
[RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')])
if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})
# Played cards
played_cards = []
for field in ['played_cards_landlord', 'played_cards_landlord_down', 'played_cards_landlord_up']:
played_cards.append(''.join([RealCard2EnvCard[c] for c in request.form.get(field)]))
played_cards.append(
''.join([RealCard2EnvCard[c] for c in request.form.get(field)]))
# RLCard state
state = {}
@ -116,13 +124,14 @@ def predict():
if rival_move == 'pass':
rival_move = ''
rival_move = [RLCard2DouZeroCard[c] for c in rival_move]
state['actions'] = _get_legal_card_play_actions([RLCard2DouZeroCard[c] for c in player_hand_cards], rival_move)
state['actions'] = [''.join([DouZeroCard2RLCard[c] for c in a]) for a in state['actions']]
state['actions'] = _get_legal_card_play_actions(
[RLCard2DouZeroCard[c] for c in player_hand_cards], rival_move)
state['actions'] = [
''.join([DouZeroCard2RLCard[c] for c in a]) for a in state['actions']]
for i in range(len(state['actions'])):
if state['actions'][i] == '':
state['actions'][i] = 'pass'
# Prediction
state = _extract_state(state)
action, info = players[player_position].eval_step(state)
@ -132,10 +141,13 @@ def predict():
if i == 'pass':
info['values'][''] = info['values']['pass']
del info['values']['pass']
break
actions = nlargest(3, info['values'], key=info['values'].get)
actions_confidence = [info['values'].get(action) for action in actions]
actions = [''.join([EnvCard2RealCard[c] for c in action]) for action in actions]
actions_confidence = [info['values'].get(
action) for action in actions]
actions = [''.join([EnvCard2RealCard[c] for c in action])
for action in actions]
result = {}
win_rates = {}
for i in range(len(actions)):
@ -167,28 +179,36 @@ def predict():
traceback.print_exc()
return jsonify({'status': -1, 'message': 'unkown error'})
@app.route('/legal', methods=['POST'])
def legal():
if request.method == 'POST':
try:
player_hand_cards = [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')]
rival_move = [RealCard2EnvCard[c] for c in request.form.get('rival_move')]
player_hand_cards = [RealCard2EnvCard[c]
for c in request.form.get('player_hand_cards')]
rival_move = [RealCard2EnvCard[c]
for c in request.form.get('rival_move')]
if rival_move == '':
rival_move = 'pass'
player_hand_cards = [RLCard2DouZeroCard[c] for c in player_hand_cards]
rival_move = [RLCard2DouZeroCard[c] for c in rival_move]
legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
legal_actions = [''.join([DouZeroCard2RLCard[c] for c in a]) for a in legal_actions]
player_hand_cards = [RLCard2DouZeroCard[c]
for c in player_hand_cards]
rival_move = [RLCard2DouZeroCard[c] for c in rival_move]
legal_actions = _get_legal_card_play_actions(
player_hand_cards, rival_move)
legal_actions = [''.join([DouZeroCard2RLCard[c]
for c in a]) for a in legal_actions]
for i in range(len(legal_actions)):
if legal_actions[i] == 'pass':
legal_actions[i] = ''
legal_actions = ','.join([''.join([EnvCard2RealCard[c] for c in action]) for action in legal_actions])
legal_actions = ','.join(
[''.join([EnvCard2RealCard[c] for c in action]) for action in legal_actions])
return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
except:
import traceback
traceback.print_exc()
return jsonify({'status': -1, 'message': 'unkown error'})
def _extract_state(state):
current_hand = _cards2array(state['current_hand'])
others_hand = _cards2array(state['others_hand'])
@ -203,11 +223,13 @@ def _extract_state(state):
last_9_actions = _action_seq2array(_process_action_seq(state['trace']))
if state['self'] == 0: # landlord
if state['self'] == 0: # landlord
landlord_up_played_cards = _cards2array(state['played_cards'][2])
landlord_down_played_cards = _cards2array(state['played_cards'][1])
landlord_up_num_cards_left = _get_one_hot_array(state['num_cards_left'][2], 17)
landlord_down_num_cards_left = _get_one_hot_array(state['num_cards_left'][1], 17)
landlord_up_num_cards_left = _get_one_hot_array(
state['num_cards_left'][2], 17)
landlord_down_num_cards_left = _get_one_hot_array(
state['num_cards_left'][1], 17)
obs = np.concatenate((current_hand,
others_hand,
last_action,
@ -222,16 +244,19 @@ def _extract_state(state):
if i == 0:
last_landlord_action = action
last_landlord_action = _cards2array(last_landlord_action)
landlord_num_cards_left = _get_one_hot_array(state['num_cards_left'][0], 20)
landlord_num_cards_left = _get_one_hot_array(
state['num_cards_left'][0], 20)
teammate_id = 3 - state['self']
teammate_played_cards = _cards2array(state['played_cards'][teammate_id])
teammate_played_cards = _cards2array(
state['played_cards'][teammate_id])
last_teammate_action = 'pass'
for i, action in reversed(state['trace']):
if i == teammate_id:
last_teammate_action = action
last_teammate_action = _cards2array(last_teammate_action)
teammate_num_cards_left = _get_one_hot_array(state['num_cards_left'][teammate_id], 17)
teammate_num_cards_left = _get_one_hot_array(
state['num_cards_left'][teammate_id], 17)
obs = np.concatenate((current_hand,
others_hand,
last_action,
@ -243,12 +268,14 @@ def _extract_state(state):
landlord_num_cards_left,
teammate_num_cards_left))
legal_actions = {env._ACTION_2_ID[action]: _cards2array(action) for action in state['actions']}
legal_actions = {env._ACTION_2_ID[action]: _cards2array(
action) for action in state['actions']}
extracted_state = OrderedDict({'obs': obs, 'legal_actions': legal_actions})
extracted_state['raw_obs'] = state
extracted_state['raw_legal_actions'] = [a for a in state['actions']]
return extracted_state
def _get_legal_card_play_actions(player_hand_cards, rival_move):
mg = MovesGener(player_hand_cards)
@ -326,10 +353,11 @@ def _get_legal_card_play_actions(player_hand_cards, rival_move):
m.sort()
moves.sort()
moves = list(move for move,_ in itertools.groupby(moves))
moves = list(move for move, _ in itertools.groupby(moves))
return moves
Card2Column = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4, '8': 5, '9': 6, 'T': 7,
'J': 8, 'Q': 9, 'K': 10, 'A': 11, '2': 12}
@ -339,6 +367,7 @@ NumOnes2Array = {0: np.array([0, 0, 0, 0]),
3: np.array([1, 1, 1, 0]),
4: np.array([1, 1, 1, 1])}
def _cards2array(cards):
if cards == 'pass':
return np.zeros(54, dtype=np.int8)
@ -355,12 +384,14 @@ def _cards2array(cards):
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
return np.concatenate((matrix.flatten('F'), jokers))
def _get_one_hot_array(num_left_cards, max_num_cards):
one_hot = np.zeros(max_num_cards, dtype=np.int8)
one_hot[num_left_cards - 1] = 1
return one_hot
def _action_seq2array(action_seq_list):
action_seq_array = np.zeros((len(action_seq_list), 54), np.int8)
for row, cards in enumerate(action_seq_list):
@ -368,6 +399,7 @@ def _action_seq2array(action_seq_list):
action_seq_array = action_seq_array.flatten()
return action_seq_array
def _process_action_seq(sequence, length=9):
sequence = [action[1] for action in sequence[-length:]]
if len(sequence) < length:
@ -376,6 +408,7 @@ def _process_action_seq(sequence, length=9):
sequence = empty_sequence
return sequence
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='DouZero backend')

View File

@ -26,8 +26,9 @@ code {
}
.citation {
font-family: 'Rockwell', monospace, PingFangSC-Regular, sans-serif;
margin-top: 20px;
padding: 5px;
padding: 6px;
-webkit-user-select: text;
-ms-user-select: text;
@ -40,6 +41,7 @@ code {
}
pre {
margin-top: 5px;
padding: 10px;
border-radius: 5px;
color: #64686d;
@ -48,6 +50,8 @@ code {
}
}
#upload-model-note {
a {
color: #3f51b5;

View File

@ -117,7 +117,9 @@ class DoudizhuGameBoard extends React.Component {
return (
<div
className={`playingCards loose ${fadeClassName} ${
this.props.gamePlayable && cardSelectable ? 'selectable' : 'unselectable'
this.props.gameStatus === 'playing' && this.props.gamePlayable && cardSelectable
? 'selectable'
: 'unselectable'
}`}
>
<ul className="hand" style={{ width: computeHandCardsWidth(cards.length, 12) }}>
@ -243,7 +245,7 @@ class DoudizhuGameBoard extends React.Component {
{t('doudizhu.deselect')}
</Button> */}
<Button
disabled={this.props.isHintDisabled}
disabled={this.props.isHintDisabled || this.props.gameStatus !== 'playing'}
onClick={(e) => {
e.stopPropagation();
this.props.handleMainPlayerAct('hint');
@ -255,7 +257,7 @@ class DoudizhuGameBoard extends React.Component {
{t('doudizhu.hint')}
</Button>
<Button
disabled={this.props.isPassDisabled}
disabled={this.props.isPassDisabled || this.props.gameStatus !== 'playing'}
onClick={(e) => {
e.stopPropagation();
this.props.handleMainPlayerAct('pass');
@ -267,7 +269,11 @@ class DoudizhuGameBoard extends React.Component {
{t('doudizhu.pass')}
</Button>
<Button
disabled={!this.props.selectedCards || this.props.selectedCards.length === 0}
disabled={
!this.props.selectedCards ||
this.props.selectedCards.length === 0 ||
this.props.gameStatus !== 'playing'
}
onClick={(e) => {
console.log('play', e.stopPropagation);
e.stopPropagation();

View File

@ -4,6 +4,7 @@
"play_again": "Play Again",
"cancel": "Cancel",
"turn": "Turn",
"reset": "Reset",
"doudizhu": {
"ai_hand_faceup": "AI Hand Face-Up",

View File

@ -1,9 +1,10 @@
{
"hidden": "隐藏",
"waiting...": "等待中...",
"play_again": "再次游玩",
"play_again": "再来一局",
"cancel": "取消",
"turn": "回合",
"reset": "重置",
"doudizhu": {
"ai_hand_faceup": "显示AI手牌",

View File

@ -63,7 +63,6 @@ let playedCardsLandlordUp = [];
let legalActions = { turn: -1, actions: [] };
let hintIdx = -1;
let gameEndDialogTitle = '';
let statisticRows = [];
let syncGameStatus = localStorage.getItem('LOCALE') ? 'ready' : 'localeSelection';
function PvEDoudizhuDemoView() {
@ -86,6 +85,7 @@ function PvEDoudizhuDemoView() {
const [hideRivalHand, setHideRivalHand] = useState(true);
const [hidePredictionArea, setHidePredictionArea] = useState(true);
const [locale, setLocale] = useState(localStorage.getItem('LOCALE') || 'en');
const [statisticRows, setStatisticRows] = useState([]);
const cardArr2DouzeroFormat = (cards) => {
return cards
@ -228,6 +228,7 @@ function PvEDoudizhuDemoView() {
}
if (newHand.length === 0) {
setHideRivalHand(false);
const winner = playerInfo[gameState.currentPlayer];
// update game overall history
@ -279,9 +280,9 @@ function PvEDoudizhuDemoView() {
setTimeout(() => {
gameEndDialogTitle =
winner.role === 'peasant' ? t('doudizhu.peasants_win') : t('doudizhu.landlord_win');
statisticRows = [
setStatisticRows([
{
role: 'Landlord',
role: t('doudizhu.landlord'),
win: gameStatistics.landlordWinNum,
total: gameStatistics.landlordGameNum,
winRate: gameStatistics.landlordGameNum
@ -289,7 +290,7 @@ function PvEDoudizhuDemoView() {
: '-',
},
{
role: 'Landlord Up',
role: t('doudizhu.landlord_up'),
win: gameStatistics.landlordUpWinNum,
total: gameStatistics.landlordUpGameNum,
winRate: gameStatistics.landlordUpGameNum
@ -298,7 +299,7 @@ function PvEDoudizhuDemoView() {
: '-',
},
{
role: 'Landlord Down',
role: t('doudizhu.landlord_down'),
win: gameStatistics.landlordDownWinNum,
total: gameStatistics.landlordDownGameNum,
winRate: gameStatistics.landlordDownGameNum
@ -315,10 +316,10 @@ function PvEDoudizhuDemoView() {
? ((gameStatistics.totalWinNum / gameStatistics.totalGameNum) * 100).toFixed(2) + '%'
: '-',
},
];
]);
setIsGameEndDialogOpen(true);
}, 300);
}, 2000);
} else {
setConsiderationTime(initConsiderationTime);
// manually trigger timer if consideration time equals initConsiderationTime
@ -540,6 +541,36 @@ function PvEDoudizhuDemoView() {
}, considerationTimeDeduction);
};
const handleResetStatistics = () => {
localStorage.removeItem('GAME_STATISTICS');
setStatisticRows([
{
role: 'Landlord',
win: 0,
total: 0,
winRate: '-',
},
{
role: 'Landlord Up',
win: 0,
total: 0,
winRate: '-',
},
{
role: 'Landlord Down',
win: 0,
total: 0,
winRate: '-',
},
{
role: 'All',
win: 0,
total: 0,
winRate: '-',
},
]);
};
const handleCloseGameEndDialog = () => {
// reset all game state for new game
shuffledDoudizhuDeck = shuffleArray(fullDoudizhuDeck.slice());
@ -578,6 +609,7 @@ function PvEDoudizhuDemoView() {
});
setSelectedCards([]); // user selected hand card
setPredictionRes({ prediction: [], hands: [] });
setHideRivalHand(hidePredictionArea);
setGameStatus('ready');
syncGameStatus = 'ready';
@ -884,14 +916,7 @@ function PvEDoudizhuDemoView() {
</TableContainer>
</DialogContent>
<DialogActions>
<Button
onClick={() => {
// todo: disable all action (pass, deselect) if cancel
setIsGameEndDialogOpen(false);
}}
>
{t('cancel')}
</Button>
<Button onClick={() => handleResetStatistics()}>{t('reset')}</Button>
<Button
onClick={() => handleCloseGameEndDialog()}
color="primary"
@ -1077,7 +1102,7 @@ function PvEDoudizhuDemoView() {
</Paper>
</div>
<div className="citation">
{locale === 'en' ? (
{/* {locale === 'en' ? (
<>
This demo is based on{' '}
<a href="https://github.com/datamllab/rlcard" target="_blank">
@ -1101,8 +1126,10 @@ function PvEDoudizhuDemoView() {
</a>{' '}
项目如果这些项目帮到您请添加引用:
</>
)}
)} */}
Zha, Daochen, Kwei-Herng Lai, Songyi Huang, Yuanpu Cao, Keerthana Reddy, Juan Vargas, Alex Nguyen,
Ruzhe Wei, Junyu Guo, and Xia Hu. "RLCard: A Platform for Reinforcement Learning in Card Games." In
IJCAI. 2020.
<pre>
{`@article{zha2019rlcard,
title={RLCard: A Toolkit for Reinforcement Learning in Card Games},