diff --git a/douzero/env/env.py b/douzero/env/env.py index c774018..85c9d4b 100644 --- a/douzero/env/env.py +++ b/douzero/env/env.py @@ -581,7 +581,7 @@ def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -696,7 +696,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False): num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -817,7 +817,7 @@ def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -938,7 +938,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False) num_legal_actions, axis=0) bomb_num = _get_one_hot_bomb( - infoset.bomb_num, use_legacy) + infoset.bomb_num, use_legacy, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0) @@ -1022,7 +1022,7 @@ def _get_obs_general(infoset, position, compressed_form = False): infoset.played_cards['landlord_down'], compressed_form) bomb_num = _get_one_hot_bomb( - infoset.bomb_num) + infoset.bomb_num, compressed_form=compressed_form) bomb_num_batch = np.repeat( bomb_num[np.newaxis, :], num_legal_actions, axis=0)