重构提高代码复用
This commit is contained in:
parent
7b21149add
commit
f5994e7711
|
@ -22,6 +22,105 @@ bombs[1].extend([[20, 20, 30, 30]])
|
|||
# Normal bomb
|
||||
bombs[2].extend([[x] * 5 for x in cards_idx[:-2]])
|
||||
|
||||
|
||||
def get_legal_card_play_actions(player_hand_cards, rival_move):
|
||||
mg = MovesGener(player_hand_cards)
|
||||
|
||||
rival_type = md.get_move_type(rival_move)
|
||||
rival_move_type = rival_type['type']
|
||||
rival_move_len = rival_type.get('len', 1)
|
||||
moves = list()
|
||||
|
||||
if rival_move_type == md.TYPE_0_PASS:
|
||||
moves = mg.gen_moves()
|
||||
|
||||
elif rival_move_type == md.TYPE_1_SINGLE:
|
||||
all_moves = mg.gen_type_1_single()
|
||||
moves = ms.filter_type_1_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_2_PAIR:
|
||||
all_moves = mg.gen_type_2_pair()
|
||||
moves = ms.filter_type_2_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_3_TRIPLE:
|
||||
all_moves = mg.gen_type_3_triple()
|
||||
moves = ms.filter_type_3_triple(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB:
|
||||
all_moves = mg.gen_type_4_bomb(4)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB5:
|
||||
all_moves = mg.gen_type_4_bomb(5)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB6:
|
||||
all_moves = mg.gen_type_4_bomb(6)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB7:
|
||||
all_moves = mg.gen_type_4_bomb(7)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB8:
|
||||
all_moves = mg.gen_type_4_bomb(8)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_5_KING_BOMB:
|
||||
moves = []
|
||||
|
||||
# elif rival_move_type == md.TYPE_6_3_1:
|
||||
# all_moves = mg.gen_type_6_3_1()
|
||||
# moves = ms.filter_type_6_3_1(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_7_3_2:
|
||||
all_moves = mg.gen_type_7_3_2()
|
||||
moves = ms.filter_type_7_3_2(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
|
||||
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
|
||||
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
|
||||
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
|
||||
|
||||
# elif rival_move_type == md.TYPE_11_SERIAL_3_1:
|
||||
# all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len)
|
||||
# moves = ms.filter_type_11_serial_3_1(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
|
||||
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
|
||||
|
||||
# elif rival_move_type == md.TYPE_13_4_2:
|
||||
# all_moves = mg.gen_type_13_4_2()
|
||||
# moves = ms.filter_type_13_4_2(all_moves, rival_move)
|
||||
|
||||
# elif rival_move_type == md.TYPE_14_4_22:
|
||||
# all_moves = mg.gen_type_14_4_22()
|
||||
# moves = ms.filter_type_14_4_22(all_moves, rival_move)
|
||||
|
||||
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
|
||||
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
if len(rival_move) != 0: # rival_move is not 'pass'
|
||||
moves = moves + [[]]
|
||||
|
||||
for m in moves:
|
||||
m.sort()
|
||||
|
||||
return moves
|
||||
|
||||
class GameEnv(object):
|
||||
|
||||
def __init__(self, players):
|
||||
|
@ -217,9 +316,6 @@ class GameEnv(object):
|
|||
self.info_sets[self.acting_player_position].player_hand_cards.sort()
|
||||
|
||||
def get_legal_card_play_actions(self):
|
||||
mg = MovesGener(
|
||||
self.info_sets[self.acting_player_position].player_hand_cards)
|
||||
|
||||
action_sequence = self.card_play_action_seq
|
||||
|
||||
rival_move = []
|
||||
|
@ -232,100 +328,7 @@ class GameEnv(object):
|
|||
else:
|
||||
rival_move = action_sequence[-1][1]
|
||||
|
||||
rival_type = md.get_move_type(rival_move)
|
||||
rival_move_type = rival_type['type']
|
||||
rival_move_len = rival_type.get('len', 1)
|
||||
moves = list()
|
||||
|
||||
if rival_move_type == md.TYPE_0_PASS:
|
||||
moves = mg.gen_moves()
|
||||
|
||||
elif rival_move_type == md.TYPE_1_SINGLE:
|
||||
all_moves = mg.gen_type_1_single()
|
||||
moves = ms.filter_type_1_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_2_PAIR:
|
||||
all_moves = mg.gen_type_2_pair()
|
||||
moves = ms.filter_type_2_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_3_TRIPLE:
|
||||
all_moves = mg.gen_type_3_triple()
|
||||
moves = ms.filter_type_3_triple(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB:
|
||||
all_moves = mg.gen_type_4_bomb(4)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB5:
|
||||
all_moves = mg.gen_type_4_bomb(5)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB6:
|
||||
all_moves = mg.gen_type_4_bomb(6)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB7:
|
||||
all_moves = mg.gen_type_4_bomb(7)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB8:
|
||||
all_moves = mg.gen_type_4_bomb(8)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_5_KING_BOMB:
|
||||
moves = []
|
||||
|
||||
# elif rival_move_type == md.TYPE_6_3_1:
|
||||
# all_moves = mg.gen_type_6_3_1()
|
||||
# moves = ms.filter_type_6_3_1(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_7_3_2:
|
||||
all_moves = mg.gen_type_7_3_2()
|
||||
moves = ms.filter_type_7_3_2(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
|
||||
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
|
||||
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
|
||||
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
|
||||
|
||||
# elif rival_move_type == md.TYPE_11_SERIAL_3_1:
|
||||
# all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len)
|
||||
# moves = ms.filter_type_11_serial_3_1(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
|
||||
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
|
||||
|
||||
# elif rival_move_type == md.TYPE_13_4_2:
|
||||
# all_moves = mg.gen_type_13_4_2()
|
||||
# moves = ms.filter_type_13_4_2(all_moves, rival_move)
|
||||
|
||||
# elif rival_move_type == md.TYPE_14_4_22:
|
||||
# all_moves = mg.gen_type_14_4_22()
|
||||
# moves = ms.filter_type_14_4_22(all_moves, rival_move)
|
||||
|
||||
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
|
||||
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
if len(rival_move) != 0: # rival_move is not 'pass'
|
||||
moves = moves + [[]]
|
||||
|
||||
for m in moves:
|
||||
m.sort()
|
||||
|
||||
return moves
|
||||
return get_legal_card_play_actions(self.info_sets[self.acting_player_position].player_hand_cards, rival_move)
|
||||
|
||||
def reset(self):
|
||||
self.card_play_action_seq = []
|
||||
|
|
|
@ -8,6 +8,7 @@ from datetime import datetime
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from douzero.env.move_generator import MovesGener
|
||||
from douzero.env.game import get_legal_card_play_actions
|
||||
from douzero.env import move_detector as md, move_selector as ms
|
||||
from douzero.evaluation.deep_agent import DeepAgent
|
||||
|
||||
|
@ -238,7 +239,6 @@ class InfoSet(object):
|
|||
self.num_cards_left = None
|
||||
self.num_cards_left_dict = {}
|
||||
self.all_handcards = {}
|
||||
# self.three_landlord_cards = None
|
||||
self.card_play_action_seq = None
|
||||
self.other_hand_cards = None
|
||||
self.legal_actions = None
|
||||
|
@ -248,84 +248,7 @@ class InfoSet(object):
|
|||
self.bomb_num = None
|
||||
|
||||
def _get_legal_card_play_actions(player_hand_cards, rival_move):
|
||||
mg = MovesGener(player_hand_cards)
|
||||
|
||||
rival_type = md.get_move_type(rival_move)
|
||||
rival_move_type = rival_type['type']
|
||||
rival_move_len = rival_type.get('len', 1)
|
||||
moves = list()
|
||||
|
||||
if rival_move_type == md.TYPE_0_PASS:
|
||||
moves = mg.gen_moves()
|
||||
|
||||
elif rival_move_type == md.TYPE_1_SINGLE:
|
||||
all_moves = mg.gen_type_1_single()
|
||||
moves = ms.filter_type_1_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_2_PAIR:
|
||||
all_moves = mg.gen_type_2_pair()
|
||||
moves = ms.filter_type_2_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_3_TRIPLE:
|
||||
all_moves = mg.gen_type_3_triple()
|
||||
moves = ms.filter_type_3_triple(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB:
|
||||
all_moves = mg.gen_type_4_bomb(4)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB5:
|
||||
all_moves = mg.gen_type_4_bomb(5)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB6:
|
||||
all_moves = mg.gen_type_4_bomb(6)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB7:
|
||||
all_moves = mg.gen_type_4_bomb(7)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_4_BOMB8:
|
||||
all_moves = mg.gen_type_4_bomb(8)
|
||||
moves = ms.filter_type_4_bomb(all_moves, rival_move)
|
||||
moves += mg.gen_type_5_king_bomb()
|
||||
|
||||
elif rival_move_type == md.TYPE_5_KING_BOMB:
|
||||
moves = []
|
||||
|
||||
elif rival_move_type == md.TYPE_7_3_2:
|
||||
all_moves = mg.gen_type_7_3_2()
|
||||
moves = ms.filter_type_7_3_2(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
|
||||
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
|
||||
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
|
||||
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
|
||||
|
||||
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
|
||||
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
|
||||
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
|
||||
|
||||
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
|
||||
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
|
||||
|
||||
if len(rival_move) != 0: # rival_move is not 'pass'
|
||||
moves = moves + [[]]
|
||||
|
||||
for m in moves:
|
||||
m.sort()
|
||||
moves = get_legal_card_play_actions(player_hand_cards, rival_move)
|
||||
|
||||
moves.sort()
|
||||
moves = list(move for move,_ in itertools.groupby(moves))
|
||||
|
|
Loading…
Reference in New Issue