diff --git a/douzero/env/game.py b/douzero/env/game.py index 56950e5..7ab3ee6 100644 --- a/douzero/env/game.py +++ b/douzero/env/game.py @@ -22,6 +22,105 @@ bombs[1].extend([[20, 20, 30, 30]]) # Normal bomb bombs[2].extend([[x] * 5 for x in cards_idx[:-2]]) + +def get_legal_card_play_actions(player_hand_cards, rival_move): + mg = MovesGener(player_hand_cards) + + rival_type = md.get_move_type(rival_move) + rival_move_type = rival_type['type'] + rival_move_len = rival_type.get('len', 1) + moves = list() + + if rival_move_type == md.TYPE_0_PASS: + moves = mg.gen_moves() + + elif rival_move_type == md.TYPE_1_SINGLE: + all_moves = mg.gen_type_1_single() + moves = ms.filter_type_1_single(all_moves, rival_move) + + elif rival_move_type == md.TYPE_2_PAIR: + all_moves = mg.gen_type_2_pair() + moves = ms.filter_type_2_pair(all_moves, rival_move) + + elif rival_move_type == md.TYPE_3_TRIPLE: + all_moves = mg.gen_type_3_triple() + moves = ms.filter_type_3_triple(all_moves, rival_move) + + elif rival_move_type == md.TYPE_4_BOMB: + all_moves = mg.gen_type_4_bomb(4) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB5: + all_moves = mg.gen_type_4_bomb(5) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB6: + all_moves = mg.gen_type_4_bomb(6) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB7: + all_moves = mg.gen_type_4_bomb(7) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_4_BOMB8: + all_moves = mg.gen_type_4_bomb(8) + moves = ms.filter_type_4_bomb(all_moves, rival_move) + moves += mg.gen_type_5_king_bomb() + + elif rival_move_type == md.TYPE_5_KING_BOMB: + moves = [] + + # elif rival_move_type == md.TYPE_6_3_1: + # all_moves = mg.gen_type_6_3_1() + # moves = ms.filter_type_6_3_1(all_moves, rival_move) + + elif rival_move_type == md.TYPE_7_3_2: + all_moves = mg.gen_type_7_3_2() + moves = ms.filter_type_7_3_2(all_moves, rival_move) + + elif rival_move_type == md.TYPE_8_SERIAL_SINGLE: + all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len) + moves = ms.filter_type_8_serial_single(all_moves, rival_move) + + elif rival_move_type == md.TYPE_9_SERIAL_PAIR: + all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len) + moves = ms.filter_type_9_serial_pair(all_moves, rival_move) + + elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE: + all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len) + moves = ms.filter_type_10_serial_triple(all_moves, rival_move) + + # elif rival_move_type == md.TYPE_11_SERIAL_3_1: + # all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len) + # moves = ms.filter_type_11_serial_3_1(all_moves, rival_move) + + elif rival_move_type == md.TYPE_12_SERIAL_3_2: + all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len) + moves = ms.filter_type_12_serial_3_2(all_moves, rival_move) + + # elif rival_move_type == md.TYPE_13_4_2: + # all_moves = mg.gen_type_13_4_2() + # moves = ms.filter_type_13_4_2(all_moves, rival_move) + + # elif rival_move_type == md.TYPE_14_4_22: + # all_moves = mg.gen_type_14_4_22() + # moves = ms.filter_type_14_4_22(all_moves, rival_move) + + if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB: + moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() + + if len(rival_move) != 0: # rival_move is not 'pass' + moves = moves + [[]] + + for m in moves: + m.sort() + + return moves + class GameEnv(object): def __init__(self, players): @@ -217,9 +316,6 @@ class GameEnv(object): self.info_sets[self.acting_player_position].player_hand_cards.sort() def get_legal_card_play_actions(self): - mg = MovesGener( - self.info_sets[self.acting_player_position].player_hand_cards) - action_sequence = self.card_play_action_seq rival_move = [] @@ -232,100 +328,7 @@ class GameEnv(object): else: rival_move = action_sequence[-1][1] - rival_type = md.get_move_type(rival_move) - rival_move_type = rival_type['type'] - rival_move_len = rival_type.get('len', 1) - moves = list() - - if rival_move_type == md.TYPE_0_PASS: - moves = mg.gen_moves() - - elif rival_move_type == md.TYPE_1_SINGLE: - all_moves = mg.gen_type_1_single() - moves = ms.filter_type_1_single(all_moves, rival_move) - - elif rival_move_type == md.TYPE_2_PAIR: - all_moves = mg.gen_type_2_pair() - moves = ms.filter_type_2_pair(all_moves, rival_move) - - elif rival_move_type == md.TYPE_3_TRIPLE: - all_moves = mg.gen_type_3_triple() - moves = ms.filter_type_3_triple(all_moves, rival_move) - - elif rival_move_type == md.TYPE_4_BOMB: - all_moves = mg.gen_type_4_bomb(4) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB5: - all_moves = mg.gen_type_4_bomb(5) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB6: - all_moves = mg.gen_type_4_bomb(6) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB7: - all_moves = mg.gen_type_4_bomb(7) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB8: - all_moves = mg.gen_type_4_bomb(8) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_5_KING_BOMB: - moves = [] - - # elif rival_move_type == md.TYPE_6_3_1: - # all_moves = mg.gen_type_6_3_1() - # moves = ms.filter_type_6_3_1(all_moves, rival_move) - - elif rival_move_type == md.TYPE_7_3_2: - all_moves = mg.gen_type_7_3_2() - moves = ms.filter_type_7_3_2(all_moves, rival_move) - - elif rival_move_type == md.TYPE_8_SERIAL_SINGLE: - all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len) - moves = ms.filter_type_8_serial_single(all_moves, rival_move) - - elif rival_move_type == md.TYPE_9_SERIAL_PAIR: - all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len) - moves = ms.filter_type_9_serial_pair(all_moves, rival_move) - - elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE: - all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len) - moves = ms.filter_type_10_serial_triple(all_moves, rival_move) - - # elif rival_move_type == md.TYPE_11_SERIAL_3_1: - # all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len) - # moves = ms.filter_type_11_serial_3_1(all_moves, rival_move) - - elif rival_move_type == md.TYPE_12_SERIAL_3_2: - all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len) - moves = ms.filter_type_12_serial_3_2(all_moves, rival_move) - - # elif rival_move_type == md.TYPE_13_4_2: - # all_moves = mg.gen_type_13_4_2() - # moves = ms.filter_type_13_4_2(all_moves, rival_move) - - # elif rival_move_type == md.TYPE_14_4_22: - # all_moves = mg.gen_type_14_4_22() - # moves = ms.filter_type_14_4_22(all_moves, rival_move) - - if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB: - moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - if len(rival_move) != 0: # rival_move is not 'pass' - moves = moves + [[]] - - for m in moves: - m.sort() - - return moves + return get_legal_card_play_actions(self.info_sets[self.acting_player_position].player_hand_cards, rival_move) def reset(self): self.card_play_action_seq = [] diff --git a/evaluate_server.py b/evaluate_server.py index c2ae440..f3ccd49 100644 --- a/evaluate_server.py +++ b/evaluate_server.py @@ -8,6 +8,7 @@ from datetime import datetime from concurrent.futures import ThreadPoolExecutor from douzero.env.move_generator import MovesGener +from douzero.env.game import get_legal_card_play_actions from douzero.env import move_detector as md, move_selector as ms from douzero.evaluation.deep_agent import DeepAgent @@ -238,7 +239,6 @@ class InfoSet(object): self.num_cards_left = None self.num_cards_left_dict = {} self.all_handcards = {} - # self.three_landlord_cards = None self.card_play_action_seq = None self.other_hand_cards = None self.legal_actions = None @@ -248,84 +248,7 @@ class InfoSet(object): self.bomb_num = None def _get_legal_card_play_actions(player_hand_cards, rival_move): - mg = MovesGener(player_hand_cards) - - rival_type = md.get_move_type(rival_move) - rival_move_type = rival_type['type'] - rival_move_len = rival_type.get('len', 1) - moves = list() - - if rival_move_type == md.TYPE_0_PASS: - moves = mg.gen_moves() - - elif rival_move_type == md.TYPE_1_SINGLE: - all_moves = mg.gen_type_1_single() - moves = ms.filter_type_1_single(all_moves, rival_move) - - elif rival_move_type == md.TYPE_2_PAIR: - all_moves = mg.gen_type_2_pair() - moves = ms.filter_type_2_pair(all_moves, rival_move) - - elif rival_move_type == md.TYPE_3_TRIPLE: - all_moves = mg.gen_type_3_triple() - moves = ms.filter_type_3_triple(all_moves, rival_move) - - elif rival_move_type == md.TYPE_4_BOMB: - all_moves = mg.gen_type_4_bomb(4) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB5: - all_moves = mg.gen_type_4_bomb(5) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB6: - all_moves = mg.gen_type_4_bomb(6) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB7: - all_moves = mg.gen_type_4_bomb(7) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_4_BOMB8: - all_moves = mg.gen_type_4_bomb(8) - moves = ms.filter_type_4_bomb(all_moves, rival_move) - moves += mg.gen_type_5_king_bomb() - - elif rival_move_type == md.TYPE_5_KING_BOMB: - moves = [] - - elif rival_move_type == md.TYPE_7_3_2: - all_moves = mg.gen_type_7_3_2() - moves = ms.filter_type_7_3_2(all_moves, rival_move) - - elif rival_move_type == md.TYPE_8_SERIAL_SINGLE: - all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len) - moves = ms.filter_type_8_serial_single(all_moves, rival_move) - - elif rival_move_type == md.TYPE_9_SERIAL_PAIR: - all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len) - moves = ms.filter_type_9_serial_pair(all_moves, rival_move) - - elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE: - all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len) - moves = ms.filter_type_10_serial_triple(all_moves, rival_move) - - elif rival_move_type == md.TYPE_12_SERIAL_3_2: - all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len) - moves = ms.filter_type_12_serial_3_2(all_moves, rival_move) - - if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB: - moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb() - - if len(rival_move) != 0: # rival_move is not 'pass' - moves = moves + [[]] - - for m in moves: - m.sort() + moves = get_legal_card_play_actions(player_hand_cards, rival_move) moves.sort() moves = list(move for move,_ in itertools.groupby(moves))