重构提高代码复用

This commit is contained in:
zhiyang7 2021-12-27 09:44:40 +08:00
parent 7b21149add
commit f5994e7711
2 changed files with 102 additions and 176 deletions

197
douzero/env/game.py vendored
View File

@ -22,6 +22,105 @@ bombs[1].extend([[20, 20, 30, 30]])
# Normal bomb # Normal bomb
bombs[2].extend([[x] * 5 for x in cards_idx[:-2]]) bombs[2].extend([[x] * 5 for x in cards_idx[:-2]])
def get_legal_card_play_actions(player_hand_cards, rival_move):
mg = MovesGener(player_hand_cards)
rival_type = md.get_move_type(rival_move)
rival_move_type = rival_type['type']
rival_move_len = rival_type.get('len', 1)
moves = list()
if rival_move_type == md.TYPE_0_PASS:
moves = mg.gen_moves()
elif rival_move_type == md.TYPE_1_SINGLE:
all_moves = mg.gen_type_1_single()
moves = ms.filter_type_1_single(all_moves, rival_move)
elif rival_move_type == md.TYPE_2_PAIR:
all_moves = mg.gen_type_2_pair()
moves = ms.filter_type_2_pair(all_moves, rival_move)
elif rival_move_type == md.TYPE_3_TRIPLE:
all_moves = mg.gen_type_3_triple()
moves = ms.filter_type_3_triple(all_moves, rival_move)
elif rival_move_type == md.TYPE_4_BOMB:
all_moves = mg.gen_type_4_bomb(4)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB5:
all_moves = mg.gen_type_4_bomb(5)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB6:
all_moves = mg.gen_type_4_bomb(6)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB7:
all_moves = mg.gen_type_4_bomb(7)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB8:
all_moves = mg.gen_type_4_bomb(8)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_5_KING_BOMB:
moves = []
# elif rival_move_type == md.TYPE_6_3_1:
# all_moves = mg.gen_type_6_3_1()
# moves = ms.filter_type_6_3_1(all_moves, rival_move)
elif rival_move_type == md.TYPE_7_3_2:
all_moves = mg.gen_type_7_3_2()
moves = ms.filter_type_7_3_2(all_moves, rival_move)
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
# elif rival_move_type == md.TYPE_11_SERIAL_3_1:
# all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len)
# moves = ms.filter_type_11_serial_3_1(all_moves, rival_move)
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
# elif rival_move_type == md.TYPE_13_4_2:
# all_moves = mg.gen_type_13_4_2()
# moves = ms.filter_type_13_4_2(all_moves, rival_move)
# elif rival_move_type == md.TYPE_14_4_22:
# all_moves = mg.gen_type_14_4_22()
# moves = ms.filter_type_14_4_22(all_moves, rival_move)
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
if len(rival_move) != 0: # rival_move is not 'pass'
moves = moves + [[]]
for m in moves:
m.sort()
return moves
class GameEnv(object): class GameEnv(object):
def __init__(self, players): def __init__(self, players):
@ -217,9 +316,6 @@ class GameEnv(object):
self.info_sets[self.acting_player_position].player_hand_cards.sort() self.info_sets[self.acting_player_position].player_hand_cards.sort()
def get_legal_card_play_actions(self): def get_legal_card_play_actions(self):
mg = MovesGener(
self.info_sets[self.acting_player_position].player_hand_cards)
action_sequence = self.card_play_action_seq action_sequence = self.card_play_action_seq
rival_move = [] rival_move = []
@ -232,100 +328,7 @@ class GameEnv(object):
else: else:
rival_move = action_sequence[-1][1] rival_move = action_sequence[-1][1]
rival_type = md.get_move_type(rival_move) return get_legal_card_play_actions(self.info_sets[self.acting_player_position].player_hand_cards, rival_move)
rival_move_type = rival_type['type']
rival_move_len = rival_type.get('len', 1)
moves = list()
if rival_move_type == md.TYPE_0_PASS:
moves = mg.gen_moves()
elif rival_move_type == md.TYPE_1_SINGLE:
all_moves = mg.gen_type_1_single()
moves = ms.filter_type_1_single(all_moves, rival_move)
elif rival_move_type == md.TYPE_2_PAIR:
all_moves = mg.gen_type_2_pair()
moves = ms.filter_type_2_pair(all_moves, rival_move)
elif rival_move_type == md.TYPE_3_TRIPLE:
all_moves = mg.gen_type_3_triple()
moves = ms.filter_type_3_triple(all_moves, rival_move)
elif rival_move_type == md.TYPE_4_BOMB:
all_moves = mg.gen_type_4_bomb(4)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB5:
all_moves = mg.gen_type_4_bomb(5)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB6:
all_moves = mg.gen_type_4_bomb(6)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB7:
all_moves = mg.gen_type_4_bomb(7)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB8:
all_moves = mg.gen_type_4_bomb(8)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_5_KING_BOMB:
moves = []
# elif rival_move_type == md.TYPE_6_3_1:
# all_moves = mg.gen_type_6_3_1()
# moves = ms.filter_type_6_3_1(all_moves, rival_move)
elif rival_move_type == md.TYPE_7_3_2:
all_moves = mg.gen_type_7_3_2()
moves = ms.filter_type_7_3_2(all_moves, rival_move)
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
# elif rival_move_type == md.TYPE_11_SERIAL_3_1:
# all_moves = mg.gen_type_11_serial_3_1(repeat_num=rival_move_len)
# moves = ms.filter_type_11_serial_3_1(all_moves, rival_move)
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
# elif rival_move_type == md.TYPE_13_4_2:
# all_moves = mg.gen_type_13_4_2()
# moves = ms.filter_type_13_4_2(all_moves, rival_move)
# elif rival_move_type == md.TYPE_14_4_22:
# all_moves = mg.gen_type_14_4_22()
# moves = ms.filter_type_14_4_22(all_moves, rival_move)
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
if len(rival_move) != 0: # rival_move is not 'pass'
moves = moves + [[]]
for m in moves:
m.sort()
return moves
def reset(self): def reset(self):
self.card_play_action_seq = [] self.card_play_action_seq = []

View File

@ -8,6 +8,7 @@ from datetime import datetime
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from douzero.env.move_generator import MovesGener from douzero.env.move_generator import MovesGener
from douzero.env.game import get_legal_card_play_actions
from douzero.env import move_detector as md, move_selector as ms from douzero.env import move_detector as md, move_selector as ms
from douzero.evaluation.deep_agent import DeepAgent from douzero.evaluation.deep_agent import DeepAgent
@ -238,7 +239,6 @@ class InfoSet(object):
self.num_cards_left = None self.num_cards_left = None
self.num_cards_left_dict = {} self.num_cards_left_dict = {}
self.all_handcards = {} self.all_handcards = {}
# self.three_landlord_cards = None
self.card_play_action_seq = None self.card_play_action_seq = None
self.other_hand_cards = None self.other_hand_cards = None
self.legal_actions = None self.legal_actions = None
@ -248,84 +248,7 @@ class InfoSet(object):
self.bomb_num = None self.bomb_num = None
def _get_legal_card_play_actions(player_hand_cards, rival_move): def _get_legal_card_play_actions(player_hand_cards, rival_move):
mg = MovesGener(player_hand_cards) moves = get_legal_card_play_actions(player_hand_cards, rival_move)
rival_type = md.get_move_type(rival_move)
rival_move_type = rival_type['type']
rival_move_len = rival_type.get('len', 1)
moves = list()
if rival_move_type == md.TYPE_0_PASS:
moves = mg.gen_moves()
elif rival_move_type == md.TYPE_1_SINGLE:
all_moves = mg.gen_type_1_single()
moves = ms.filter_type_1_single(all_moves, rival_move)
elif rival_move_type == md.TYPE_2_PAIR:
all_moves = mg.gen_type_2_pair()
moves = ms.filter_type_2_pair(all_moves, rival_move)
elif rival_move_type == md.TYPE_3_TRIPLE:
all_moves = mg.gen_type_3_triple()
moves = ms.filter_type_3_triple(all_moves, rival_move)
elif rival_move_type == md.TYPE_4_BOMB:
all_moves = mg.gen_type_4_bomb(4)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB5:
all_moves = mg.gen_type_4_bomb(5)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB6:
all_moves = mg.gen_type_4_bomb(6)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB7:
all_moves = mg.gen_type_4_bomb(7)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_4_BOMB8:
all_moves = mg.gen_type_4_bomb(8)
moves = ms.filter_type_4_bomb(all_moves, rival_move)
moves += mg.gen_type_5_king_bomb()
elif rival_move_type == md.TYPE_5_KING_BOMB:
moves = []
elif rival_move_type == md.TYPE_7_3_2:
all_moves = mg.gen_type_7_3_2()
moves = ms.filter_type_7_3_2(all_moves, rival_move)
elif rival_move_type == md.TYPE_8_SERIAL_SINGLE:
all_moves = mg.gen_type_8_serial_single(repeat_num=rival_move_len)
moves = ms.filter_type_8_serial_single(all_moves, rival_move)
elif rival_move_type == md.TYPE_9_SERIAL_PAIR:
all_moves = mg.gen_type_9_serial_pair(repeat_num=rival_move_len)
moves = ms.filter_type_9_serial_pair(all_moves, rival_move)
elif rival_move_type == md.TYPE_10_SERIAL_TRIPLE:
all_moves = mg.gen_type_10_serial_triple(repeat_num=rival_move_len)
moves = ms.filter_type_10_serial_triple(all_moves, rival_move)
elif rival_move_type == md.TYPE_12_SERIAL_3_2:
all_moves = mg.gen_type_12_serial_3_2(repeat_num=rival_move_len)
moves = ms.filter_type_12_serial_3_2(all_moves, rival_move)
if rival_move_type != md.TYPE_0_PASS and rival_move_type < md.TYPE_4_BOMB:
moves = moves + mg.gen_type_4_bomb(4) + mg.gen_type_4_bomb(5) + mg.gen_type_4_bomb(6) + mg.gen_type_4_bomb(7) + mg.gen_type_4_bomb(8) + mg.gen_type_5_king_bomb()
if len(rival_move) != 0: # rival_move is not 'pass'
moves = moves + [[]]
for m in moves:
m.sort()
moves.sort() moves.sort()
moves = list(move for move,_ in itertools.groupby(moves)) moves = list(move for move,_ in itertools.groupby(moves))