提速降费

This commit is contained in:
zhiyang7 2021-12-21 15:34:27 +08:00
parent 84775d52e6
commit 42066ae7a9
1 changed files with 1 additions and 31 deletions

32
douzero/env/env.py vendored
View File

@ -355,7 +355,7 @@ def _action_seq_list2array(action_seq_list, new_model=True):
if new_model:
# position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
action_seq_array = np.ones((len(action_seq_list), 108)) * -1 # Default Value -1 for not using area
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :108] = _cards2array(list_cards[1])
@ -876,12 +876,6 @@ def _get_obs_general(infoset, position):
num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards)
other_handcards_batch = np.repeat(other_handcards[np.newaxis, :],
num_legal_actions, axis=0)
last_action = _cards2array(infoset.last_move)
last_action_batch = np.repeat(last_action[np.newaxis, :],
num_legal_actions, axis=0)
my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions):
@ -889,27 +883,15 @@ def _get_obs_general(infoset, position):
landlord_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord'], 33)
landlord_num_cards_left_batch = np.repeat(
landlord_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0)
landlord_up_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_up'], 25)
landlord_up_num_cards_left_batch = np.repeat(
landlord_up_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0)
landlord_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25)
landlord_front_num_cards_left_batch = np.repeat(
landlord_front_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0)
landlord_down_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_down'], 25)
landlord_down_num_cards_left_batch = np.repeat(
landlord_down_num_cards_left[np.newaxis, :],
num_legal_actions, axis=0)
other_handcards_left_list = []
for pos in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
@ -918,27 +900,15 @@ def _get_obs_general(infoset, position):
landlord_played_cards = _cards2array(
infoset.played_cards['landlord'])
landlord_played_cards_batch = np.repeat(
landlord_played_cards[np.newaxis, :],
num_legal_actions, axis=0)
landlord_up_played_cards = _cards2array(
infoset.played_cards['landlord_up'])
landlord_up_played_cards_batch = np.repeat(
landlord_up_played_cards[np.newaxis, :],
num_legal_actions, axis=0)
landlord_front_played_cards = _cards2array(
infoset.played_cards['landlord_front'])
landlord_front_played_cards_batch = np.repeat(
landlord_front_played_cards[np.newaxis, :],
num_legal_actions, axis=0)
landlord_down_played_cards = _cards2array(
infoset.played_cards['landlord_down'])
landlord_down_played_cards_batch = np.repeat(
landlord_down_played_cards[np.newaxis, :],
num_legal_actions, axis=0)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num)