From 84775d52e66a4a386284c66fcbb9d0ae2acea0d6 Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Tue, 21 Dec 2021 14:58:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E9=80=9F=E9=99=8D=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/utils.py | 27 ++++++++++++++++++++------- douzero/env/env.py | 27 ++++++++++++++++++++------- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/douzero/dmc/utils.py b/douzero/dmc/utils.py index ac0025b..7df335d 100644 --- a/douzero/dmc/utils.py +++ b/douzero/dmc/utils.py @@ -26,6 +26,16 @@ NumOnes2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]), 7: np.array([1, 1, 1, 1, 1, 1, 1, 0]), 8: np.array([1, 1, 1, 1, 1, 1, 1, 1])} +NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]), + 1: np.array([1, 0, 0, 0, 0, 0, 0, 0]), + 3: np.array([1, 1, 0, 0, 0, 0, 0, 0]), + 4: np.array([0, 0, 1, 0, 0, 0, 0, 0]), + 5: np.array([1, 0, 1, 0, 0, 0, 0, 0]), + 7: np.array([1, 1, 1, 0, 0, 0, 0, 0]), + 12: np.array([0, 0, 1, 1, 0, 0, 0, 0]), + 13: np.array([1, 0, 1, 1, 0, 0, 0, 0]), + 15: np.array([1, 1, 1, 1, 0, 0, 0, 0])} + shandle = logging.StreamHandler() shandle.setFormatter( logging.Formatter( @@ -195,20 +205,23 @@ def _cards2tensor(list_cards): if len(list_cards) == 0: return torch.zeros(108, dtype=torch.int8) - matrix = np.zeros([8, 13], dtype=np.int8) - jokers = np.zeros(4, dtype=np.int8) + matrix = np.zeros([8, 14], dtype=np.int8) counter = Counter(list_cards) + joker_cnt = 0 for card, num_times in counter.items(): if card < 20: matrix[:, Card2Column[card]] = NumOnes2Array[num_times] elif card == 20: - jokers[0] = 1 if num_times == 2: - jokers[1] = 1 + joker_cnt |= 0b11 + else: + joker_cnt |= 0b01 elif card == 30: - jokers[2] = 1 if num_times == 2: - jokers[3] = 1 - matrix = np.concatenate((matrix.flatten('F'), jokers)) + joker_cnt |= 0b1100 + else: + joker_cnt |= 0b0100 + matrix[:, 13] = NumOnesJoker2Array[joker_cnt] + matrix = matrix.flatten('F')[:-4] matrix = torch.from_numpy(matrix) return matrix diff --git a/douzero/env/env.py b/douzero/env/env.py index 822ee32..9faab63 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -20,6 +20,16 @@ NumOnes2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]), 7: np.array([1, 1, 1, 1, 1, 1, 1, 0]), 8: np.array([1, 1, 1, 1, 1, 1, 1, 1])} +NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]), + 1: np.array([1, 0, 0, 0, 0, 0, 0, 0]), + 3: np.array([1, 1, 0, 0, 0, 0, 0, 0]), + 4: np.array([0, 0, 1, 0, 0, 0, 0, 0]), + 5: np.array([1, 0, 1, 0, 0, 0, 0, 0]), + 7: np.array([1, 1, 1, 0, 0, 0, 0, 0]), + 12: np.array([0, 0, 1, 1, 0, 0, 0, 0]), + 13: np.array([1, 0, 1, 1, 0, 0, 0, 0]), + 15: np.array([1, 1, 1, 1, 0, 0, 0, 0])} + deck = [] for i in range(3, 15): @@ -296,21 +306,24 @@ def _cards2array(list_cards): if len(list_cards) == 0: return np.zeros(108, dtype=np.int8) - matrix = np.zeros([8, 13], dtype=np.int8) - jokers = np.zeros(4, dtype=np.int8) + matrix = np.zeros([8, 14], dtype=np.int8) counter = Counter(list_cards) + joker_cnt = 0 for card, num_times in counter.items(): if card < 20: matrix[:, Card2Column[card]] = NumOnes2Array[num_times] elif card == 20: - jokers[0] = 1 if num_times == 2: - jokers[1] = 1 + joker_cnt |= 0b11 + else: + joker_cnt |= 0b01 elif card == 30: - jokers[2] = 1 if num_times == 2: - jokers[3] = 1 - return np.concatenate((matrix.flatten('F'), jokers)) + joker_cnt |= 0b1100 + else: + joker_cnt |= 0b0100 + matrix[:, 13] = NumOnesJoker2Array[joker_cnt] + return matrix.flatten('F')[:-4] # def _action_seq_list2array(action_seq_list):