修复BUG

This commit is contained in:
zhiyang7 2022-01-05 11:22:11 +08:00
parent f0883bbf31
commit 8980a97324
1 changed files with 5 additions and 5 deletions

10
douzero/env/env.py vendored
View File

@ -581,7 +581,7 @@ def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num, use_legacy) infoset.bomb_num, use_legacy, compressed_form=compressed_form)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -696,7 +696,7 @@ def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False):
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num, use_legacy) infoset.bomb_num, use_legacy, compressed_form=compressed_form)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -817,7 +817,7 @@ def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num, use_legacy) infoset.bomb_num, use_legacy, compressed_form=compressed_form)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -938,7 +938,7 @@ def _get_obs_landlord_down(infoset, use_legacy = False, compressed_form = False)
num_legal_actions, axis=0) num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num, use_legacy) infoset.bomb_num, use_legacy, compressed_form=compressed_form)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)
@ -1022,7 +1022,7 @@ def _get_obs_general(infoset, position, compressed_form = False):
infoset.played_cards['landlord_down'], compressed_form) infoset.played_cards['landlord_down'], compressed_form)
bomb_num = _get_one_hot_bomb( bomb_num = _get_one_hot_bomb(
infoset.bomb_num) infoset.bomb_num, compressed_form=compressed_form)
bomb_num_batch = np.repeat( bomb_num_batch = np.repeat(
bomb_num[np.newaxis, :], bomb_num[np.newaxis, :],
num_legal_actions, axis=0) num_legal_actions, axis=0)