rlcard-showdown/pve_server/run_dmc.py

438 lines
18 KiB
Python
Raw Normal View History

2021-06-05 14:46:43 +08:00
from utils.move_generator import MovesGener
from utils import move_selector as ms
from utils import move_detector as md
import rlcard
2021-05-30 22:22:28 +08:00
import itertools
2021-06-05 14:46:43 +08:00
import os
from collections import Counter, OrderedDict
from heapq import nlargest
2021-05-30 22:22:28 +08:00
2021-05-29 04:43:44 +08:00
import numpy as np
2021-06-05 14:46:43 +08:00
import torch
2021-05-29 04:43:44 +08:00
from flask import Flask, jsonify, request
from flask_cors import CORS
app = Flask(__name__)
CORS(app)
env = rlcard.make('doudizhu')

# DouZero integer ranks -> RLCard single-character card codes.
DouZeroCard2RLCard = dict(zip((3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, 20, 30),
                              '3456789TJQKA2BR'))
# Inverse of DouZeroCard2RLCard (RLCard code -> DouZero rank), same key order.
RLCard2DouZeroCard = {code: rank for rank, code in DouZeroCard2RLCard.items()}

# RLCard env codes -> front-end codes: ranks map to themselves, jokers B/R become X/D.
EnvCard2RealCard = {**{code: code for code in '3456789TJQKA2'}, 'B': 'X', 'R': 'D'}
# Inverse of EnvCard2RealCard (front-end code -> RLCard env code), same key order.
RealCard2EnvCard = {real: code for code, real in EnvCard2RealCard.items()}

pretrained_dir = 'pretrained/dmc_pretrained'
device = torch.device('cpu')

# One pretrained agent per seat, loaded from '<seat>.pth'; seat 0 is the landlord.
players = []
for i in range(4):
    model_path = os.path.join(pretrained_dir, '%d.pth' % i)
    agent = torch.load(model_path, map_location=device)
    agent.set_device(device)
    players.append(agent)
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
@app.route('/predict', methods=['POST'])
def predict():
    """Suggest the agent's top three moves for the player at ``player_position``.

    Form fields (card strings use the front-end codes of ``RealCard2EnvCard``):
        player_position: '0' (landlord) to '3'.
        player_hand_cards: the acting player's hand.
        num_cards_left_landlord, num_cards_left_landlord_down,
        num_cards_left_landlord_front, num_cards_left_landlord_up:
            card counts for seats 0-3.
        card_play_action_seq: comma-separated moves in play order starting
            from seat 0; an empty entry is a pass.
        other_hand_cards: every card held by the other three seats.
        played_cards_landlord, played_cards_landlord_down,
        played_cards_landlord_front, played_cards_landlord_up:
            cards already played by seats 0-3.

    Returns:
        JSON ``{'status': 0, 'message': 'success', 'result': ..., 'win_rates': ...}``
        mapping up to three moves (front-end codes, '' for pass) to the agent's
        value (``result``) and to that value clamped into [0, 1] (``win_rates``).
        Invalid input yields status 1-5 or 7; any unexpected error yields -1.
    """
    if request.method == 'POST':
        try:
            # Player position (0 is the landlord)
            player_position = request.form.get('player_position')
            if player_position not in ['0', '1', '2', '3']:
                return jsonify({'status': 1, 'message': 'player_position must be 0, 1, 2 or 3'})
            player_position = int(player_position)

            # Player hand cards
            player_hand_cards = ''.join(
                [RealCard2EnvCard[c] for c in request.form.get('player_hand_cards')])
            if player_position == 0:
                if len(player_hand_cards) < 1 or len(player_hand_cards) > 33:
                    return jsonify({'status': 2, 'message': 'the number of hand cards should be 1-33'})
            else:
                if len(player_hand_cards) < 1 or len(player_hand_cards) > 25:
                    return jsonify({'status': 3, 'message': 'the number of hand cards should be 1-25'})

            # Number of cards left, indexed by seat
            num_cards_left = [
                int(request.form.get('num_cards_left_landlord')),
                int(request.form.get('num_cards_left_landlord_down')),
                int(request.form.get('num_cards_left_landlord_front')),
                int(request.form.get('num_cards_left_landlord_up'))
            ]
            if num_cards_left[player_position] != len(player_hand_cards):
                return jsonify({'status': 4, 'message': 'the number of cards left do not align with hand cards'})
            if (any(n < 0 for n in num_cards_left)
                    or num_cards_left[0] > 33
                    or any(n > 25 for n in num_cards_left[1:])):
                return jsonify({'status': 5, 'message': 'the number of cards left not in range'})

            # Card play sequence: seats act in turn starting from seat 0,
            # and an empty entry is recorded as 'pass'.
            raw_action_seq = request.form.get('card_play_action_seq')
            if raw_action_seq == '':
                card_play_action_seq = []
            else:
                tmp_seq = [''.join([RealCard2EnvCard[c] for c in cards])
                           for cards in raw_action_seq.split(',')]
                card_play_action_seq = [(i % 4, cards or 'pass')
                                        for i, cards in enumerate(tmp_seq)]

            # Other hand cards
            other_hand_cards = ''.join(
                [RealCard2EnvCard[c] for c in request.form.get('other_hand_cards')])
            if len(other_hand_cards) != sum(num_cards_left) - num_cards_left[player_position]:
                return jsonify({'status': 7, 'message': 'the number of the other hand cards do not align with the number of cards left'})

            # Played cards, indexed by seat
            played_cards = [
                ''.join([RealCard2EnvCard[c] for c in request.form.get(field)])
                for field in ['played_cards_landlord', 'played_cards_landlord_down',
                              'played_cards_landlord_front', 'played_cards_landlord_up']]

            # RLCard state
            state = {
                'current_hand': player_hand_cards,
                'landlord': 0,
                'num_cards_left': num_cards_left,
                'others_hand': other_hand_cards,
                'played_cards': played_cards,
                'self': player_position,
                'trace': card_play_action_seq,
            }

            # The move to beat is the latest non-pass among the previous three
            # plays. No history, or three passes in a row, means a free lead.
            # (The original indexed [-2]/[-3] directly and crashed on short
            # all-pass histories.)
            rival_move = 'pass'
            for _, move in reversed(card_play_action_seq[-3:]):
                if move != 'pass':
                    rival_move = move
                    break
            if rival_move == 'pass':
                rival_move = ''
            rival_move = [RLCard2DouZeroCard[c] for c in rival_move]

            # Legal actions as RLCard strings; the empty move is spelled 'pass'.
            legal_moves = _get_legal_card_play_actions(
                [RLCard2DouZeroCard[c] for c in player_hand_cards], rival_move)
            state['actions'] = [''.join([DouZeroCard2RLCard[c] for c in legal_move]) or 'pass'
                                for legal_move in legal_moves]

            # Prediction
            state = _extract_state(state)
            _, info = players[player_position].eval_step(state)
            values = info['values']
            # The front end spells the pass move as ''.
            if 'pass' in values:
                values[''] = values.pop('pass')

            actions = nlargest(3, values, key=values.get)
            actions_confidence = [values.get(a) for a in actions]
            actions = [''.join([EnvCard2RealCard[c] for c in a]) for a in actions]

            result = {}
            win_rates = {}
            for real_action, confidence in zip(actions, actions_confidence):
                # The reported win rate is the agent's value clamped into [0, 1].
                win_rates[real_action] = str(round(max(min(confidence, 1), 0), 4))
                result[real_action] = str(round(confidence, 6))

            ############## DEBUG ################
            if app.debug:
                print('--------------- DEBUG START --------------')
                command = 'curl --data "'
                parameters = []
                for key in request.form:
                    parameters.append(key+'='+request.form.get(key))
                    print(key+':', request.form.get(key))
                command += '&'.join(parameters)
                command += '" "http://127.0.0.1:5000/predict"'
                print('Command:', command)
                print('Rival Move:', rival_move)
                print('legal_actions:', state['legal_actions'])
                print('Result:', result)
                print('--------------- DEBUG END --------------')
            ############## DEBUG ################
            return jsonify({'status': 0, 'message': 'success', 'result': result, 'win_rates': win_rates})
        except Exception:
            # Report any failure (bad card code, missing field, model error)
            # without letting SystemExit/KeyboardInterrupt be swallowed.
            import traceback
            traceback.print_exc()
            return jsonify({'status': -1, 'message': 'unkown error'})
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
@app.route('/legal', methods=['POST'])
def legal():
    """Return every legal move for ``player_hand_cards`` against ``rival_move``.

    Both form fields use the front-end card codes of ``RealCard2EnvCard``.
    An empty ``rival_move`` becomes ``[]``, the same representation
    ``predict`` passes to ``_get_legal_card_play_actions`` for a free lead.

    Returns:
        JSON ``{'status': 0, 'message': 'success', 'legal_action': ...}`` where
        ``legal_action`` is a comma-separated list of moves in front-end codes
        (the pass move is the empty string), or ``status`` -1 on any error.
    """
    if request.method == 'POST':
        try:
            # Front-end codes -> RLCard env codes -> DouZero integer ranks.
            player_hand_cards = [RLCard2DouZeroCard[RealCard2EnvCard[c]]
                                 for c in request.form.get('player_hand_cards')]
            rival_move = [RLCard2DouZeroCard[RealCard2EnvCard[c]]
                          for c in request.form.get('rival_move')]
            legal_actions = _get_legal_card_play_actions(player_hand_cards, rival_move)
            # DouZero ranks -> RLCard env codes -> front-end codes; the pass move
            # is an empty list, which joins to ''.
            legal_actions = ','.join(
                ''.join([EnvCard2RealCard[DouZeroCard2RLCard[c]] for c in action])
                for action in legal_actions)
            return jsonify({'status': 0, 'message': 'success', 'legal_action': legal_actions})
        except Exception:
            # Report any failure (e.g. an unknown card code) as status -1.
            import traceback
            traceback.print_exc()
            return jsonify({'status': -1, 'message': 'unkown error'})
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
def _extract_state(state):
    """Build the agent input for ``eval_step`` from a raw state dict.

    ``state`` is the dict assembled in ``predict``: ``current_hand`` and
    ``others_hand`` (card strings), ``trace`` (list of ``(seat, cards)``
    pairs, passes spelled 'pass'), ``played_cards`` and ``num_cards_left``
    (indexed by seat), ``self`` (the acting seat, 0 = landlord) and
    ``actions`` (legal actions as card strings).

    Returns:
        OrderedDict with ``obs`` (concatenated feature vector),
        ``legal_actions`` (``env._ACTION_2_ID`` id -> encoded cards),
        ``raw_obs`` (the input ``state``) and ``raw_legal_actions`` (a copy of
        ``state['actions']``).

    NOTE(review): parts of this function look carried over from the
    3-player version. The feature layout must match whatever the 4-player
    models were trained on, so confirm against the training env before
    changing anything:
      * the landlord branch encodes only seats 2 and 1 ("up"/"down"), so
        seat 3 never appears in the landlord's observation;
      * ``teammate_id = 3 - state['self']`` maps seat 3 to seat 0, which is
        the landlord;
      * the one-hot widths 17 and 20 are smaller than the limits ``predict``
        accepts (25 and 33 cards). ``_get_one_hot_array`` indexes
        ``count - 1``, so larger counts raise IndexError;
      * ``env`` is ``rlcard.make('doudizhu')``; verify its ``_ACTION_2_ID``
        contains two-deck actions (e.g. bombs of 5+ cards), otherwise the
        ``legal_actions`` lookup raises KeyError.
    """
    current_hand = _cards2array(state['current_hand'])
    others_hand = _cards2array(state['others_hand'])
    # Most recent move. If the last entry is a pass, the entry before it is
    # used (which may itself be 'pass'). NOTE(review): IndexError if the
    # trace holds a single 'pass'.
    last_action = ''
    if len(state['trace']) != 0:
        if state['trace'][-1][1] == 'pass':
            last_action = state['trace'][-2][1]
        else:
            last_action = state['trace'][-1][1]
    last_action = _cards2array(last_action)
    last_9_actions = _action_seq2array(_process_action_seq(state['trace']))
    if state['self'] == 0:  # landlord
        landlord_up_played_cards = _cards2array(state['played_cards'][2])
        landlord_down_played_cards = _cards2array(state['played_cards'][1])
        landlord_up_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][2], 17)
        landlord_down_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][1], 17)
        obs = np.concatenate((current_hand,
                              others_hand,
                              last_action,
                              last_9_actions,
                              landlord_up_played_cards,
                              landlord_down_played_cards,
                              landlord_up_num_cards_left,
                              landlord_down_num_cards_left))
    else:
        landlord_played_cards = _cards2array(state['played_cards'][0])
        # NOTE(review): there is no ``break``, so after the reversed scan this
        # holds the landlord's EARLIEST move in the trace, not the latest.
        # It is also unbound (UnboundLocalError) if seat 0 never appears.
        # Left as-is in case training used the same computation.
        for i, action in reversed(state['trace']):
            if i == 0:
                last_landlord_action = action
        last_landlord_action = _cards2array(last_landlord_action)
        landlord_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][0], 20)
        teammate_id = 3 - state['self']
        teammate_played_cards = _cards2array(
            state['played_cards'][teammate_id])
        # Same scan pattern as above: without a ``break`` this ends on the
        # teammate's earliest move, defaulting to 'pass' if they never acted.
        last_teammate_action = 'pass'
        for i, action in reversed(state['trace']):
            if i == teammate_id:
                last_teammate_action = action
        last_teammate_action = _cards2array(last_teammate_action)
        teammate_num_cards_left = _get_one_hot_array(
            state['num_cards_left'][teammate_id], 17)
        obs = np.concatenate((current_hand,
                              others_hand,
                              last_action,
                              last_9_actions,
                              landlord_played_cards,
                              teammate_played_cards,
                              last_landlord_action,
                              last_teammate_action,
                              landlord_num_cards_left,
                              teammate_num_cards_left))
    # Map each legal action string to the env's action id and its encoding.
    legal_actions = {env._ACTION_2_ID[action]: _cards2array(
        action) for action in state['actions']}
    extracted_state = OrderedDict({'obs': obs, 'legal_actions': legal_actions})
    extracted_state['raw_obs'] = state
    extracted_state['raw_legal_actions'] = [a for a in state['actions']]
    return extracted_state
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
def _get_legal_card_play_actions(player_hand_cards, rival_move):
    """Return every move the player may make against ``rival_move``.

    Args:
        player_hand_cards: the player's hand as DouZero integer ranks.
        rival_move: the move to beat as DouZero integer ranks; ``[]`` means
            the player leads freely.

    Returns:
        A sorted, duplicate-free list of moves, each a sorted list of ranks.
        When ``rival_move`` is non-empty, the pass move ``[]`` is included.
    """
    mg = MovesGener(player_hand_cards)
    rival_type = md.get_move_type(rival_move)
    rival_move_type = rival_type['type']
    rival_move_len = rival_type.get('len', 1)

    # Largest same-rank bomb size generated here.
    max_bomb_size = 8
    # Same-rank bomb move types -> number of identical cards in the bomb.
    bomb_sizes = {md.TYPE_4_BOMB: 4, md.TYPE_4_BOMB5: 5, md.TYPE_4_BOMB6: 6,
                  md.TYPE_4_BOMB7: 7, md.TYPE_4_BOMB8: 8}

    moves = list()
    if rival_move_type == md.TYPE_0_PASS:
        moves = mg.gen_moves()
    elif rival_move_type == md.TYPE_1_SINGLE:
        moves = ms.filter_type_1_single(mg.gen_type_1_single(), rival_move)
    elif rival_move_type == md.TYPE_2_PAIR:
        moves = ms.filter_type_2_pair(mg.gen_type_2_pair(), rival_move)
    elif rival_move_type == md.TYPE_3_TRIPLE:
        moves = ms.filter_type_3_triple(mg.gen_type_3_triple(), rival_move)
    elif rival_move_type in bomb_sizes:
        # Same-size bombs that pass ms.filter_type_4_bomb, every longer bomb,
        # and the king bomb.
        size = bomb_sizes[rival_move_type]
        moves = ms.filter_type_4_bomb(mg.gen_type_4_bomb(size), rival_move)
        for longer in range(size + 1, max_bomb_size + 1):
            moves += mg.gen_type_4_bomb(longer)
        moves += mg.gen_type_5_king_bomb()
    elif rival_move_type == md.TYPE_5_KING_BOMB:
        moves = []
    elif rival_move_type == md.TYPE_7_3_2:
        moves = ms.filter_type_7_3_2(mg.gen_type_7_3_2(), rival_move)
    elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
        moves = ms.filter_type_8_serial_single(
            mg.gen_type_8_serial_single(repeat_num=rival_move_len), rival_move)
    elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
        moves = ms.filter_type_9_serial_pair(
            mg.gen_type_9_serial_pair(repeat_num=rival_move_len), rival_move)
    elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
        moves = ms.filter_type_10_serial_triple(
            mg.gen_type_10_serial_triple(repeat_num=rival_move_len), rival_move)
    elif rival_move_type == md.TYPE_12_SERIAL_3_2:
        moves = ms.filter_type_12_serial_3_2(
            mg.gen_type_12_serial_3_2(repeat_num=rival_move_len), rival_move)

    # Any bomb or the king bomb may also be played over a non-bomb move.
    # Tested by membership rather than ``rival_move_type < md.TYPE_4_BOMB``,
    # so the result does not depend on how move_detector numbers its types
    # (with DouZero's numbering, 3+2 and chains sort above TYPE_4_BOMB).
    if (rival_move_type != md.TYPE_0_PASS
            and rival_move_type not in bomb_sizes
            and rival_move_type != md.TYPE_5_KING_BOMB):
        for size in range(4, max_bomb_size + 1):
            moves = moves + mg.gen_type_4_bomb(size)
        moves = moves + mg.gen_type_5_king_bomb()

    if rival_move:  # there is something to beat, so passing is allowed
        moves = moves + [[]]

    for m in moves:
        m.sort()
    moves.sort()
    # groupby over the sorted list collapses adjacent duplicates.
    return [move for move, _ in itertools.groupby(moves)]
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
Card2Column = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4, '8': 5, '9': 6, 'T': 7,
'J': 8, 'Q': 9, 'K': 10, 'A': 11, '2': 12}
2021-12-18 15:49:37 +08:00
NumOnes2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
1: np.array([1, 0, 0, 0, 0, 0, 0, 0]),
2: np.array([1, 1, 0, 0, 0, 0, 0, 0]),
3: np.array([1, 1, 1, 0, 0, 0, 0, 0]),
4: np.array([1, 1, 1, 1, 0, 0, 0, 0]),
5: np.array([1, 1, 1, 1, 1, 0, 0, 0]),
6: np.array([1, 1, 1, 1, 1, 1, 0, 0]),
7: np.array([1, 1, 1, 1, 1, 1, 1, 0]),
8: np.array([1, 1, 1, 1, 1, 1, 1, 1])}
2021-05-29 04:43:44 +08:00
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
def _cards2array(cards):
if cards == 'pass':
2021-12-18 15:49:37 +08:00
return np.zeros(108, dtype=np.int8)
2021-05-29 04:43:44 +08:00
matrix = np.zeros([4, 13], dtype=np.int8)
2021-12-18 15:49:37 +08:00
jokers = np.zeros(4, dtype=np.int8)
2021-05-29 04:43:44 +08:00
counter = Counter(cards)
for card, num_times in counter.items():
if card == 'B':
jokers[0] = 1
2021-12-18 15:49:37 +08:00
if num_times == 2:
jokers[1] = 1
2021-05-29 04:43:44 +08:00
elif card == 'R':
2021-12-18 15:49:37 +08:00
jokers[2] = 1
if num_times == 2:
jokers[3] = 1
2021-05-29 04:43:44 +08:00
else:
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
return np.concatenate((matrix.flatten('F'), jokers))
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
def _get_one_hot_array(num_left_cards, max_num_cards):
one_hot = np.zeros(max_num_cards, dtype=np.int8)
one_hot[num_left_cards - 1] = 1
return one_hot
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
def _action_seq2array(action_seq_list):
2021-12-18 15:49:37 +08:00
action_seq_array = np.ones((len(action_seq_list), 108)) * -1 # Default Value -1 for not using area
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :108] = _cards2array(list_cards[1])
2021-05-29 04:43:44 +08:00
return action_seq_array
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
def _process_action_seq(sequence, length=9):
sequence = [action[1] for action in sequence[-length:]]
if len(sequence) < length:
empty_sequence = ['' for _ in range(length - len(sequence))]
empty_sequence.extend(sequence)
sequence = empty_sequence
return sequence
2021-06-05 14:46:43 +08:00
2021-05-29 04:43:44 +08:00
if __name__ == '__main__':
    # Command-line entry point: start the Flask dev server, optionally in debug mode.
    import argparse

    cli = argparse.ArgumentParser(description='DouZero backend')
    cli.add_argument('--debug', action='store_true')
    options = cli.parse_args()
    app.run(debug=options.debug)