From 8980a973244777261032b86d25b5d60b0a7fa85c Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Wed, 5 Jan 2022 11:22:11 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DBUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/env/env.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/douzero/env/env.py b/douzero/env/env.py index c774018..85c9d4b 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -581,7 +581,7 @@ def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -696,7 +696,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -817,7 +817,7 @@ def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -938,7 +938,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False) num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -1022,7 +1022,7 @@ def _get_obs_general(infoset, position, compressed_form = False): infoset.played_cards['landlord_down'], compressed_form) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0)