Douzero_Resnet/douzero/env/env.py

1018 lines
41 KiB
Python
Raw Normal View History

2021-09-07 17:19:25 +08:00
from collections import Counter
import numpy as np
import random
import torch
from douzero.env.game import GameEnv
env_version = "3.2"
env_url = "http://od.vcccz.com/hechuan/env.py"

# Column of each non-joker card value in the card matrix built by
# _cards2array: values 3..14 map to columns 0..11 and 17 maps to column 12.
# Column 13 is reserved for the two joker values (20 and 30).
Card2Column = {card: column for column, card in
               enumerate([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17])}

# Full-size count encoding: a count n (0..8) becomes n leading ones in an
# 8-slot column ("thermometer" code).
NumOnes2Array = {n: np.array([int(slot < n) for slot in range(8)])
                 for n in range(9)}

# Joker column, full size. The key is a bitmask built in _cards2array:
# bits 0/1 are set for one/two copies of card 20, bits 2/3 for one/two
# copies of card 30. Slot i simply mirrors bit i of the mask; the last four
# slots are always zero.
NumOnesJoker2Array = {code: np.array([(code >> bit) & 1 for bit in range(4)] + [0] * 4)
                      for code in (0, 1, 3, 4, 5, 7, 12, 13, 15)}

# Compressed count encoding (5 slots): counts 1..4 are a 4-slot thermometer
# with the fifth slot clear; counts 5..8 reuse the thermometer for n - 4 and
# set the fifth slot as an "over four" flag.
NumOnes2ArrayCompressed = {
    n: np.array([int(slot < n - 4 * (n > 4)) for slot in range(4)] + [int(n > 4)])
    for n in range(9)
}

# Compressed joker column: same bitmask layout as above, one trailing zero.
NumOnesJoker2ArrayCompressed = {code: np.array([(code >> bit) & 1 for bit in range(4)] + [0])
                                for code in (0, 1, 3, 4, 5, 7, 12, 13, 15)}

# The full card list: eight copies of each value 3..14 and of 17, plus two
# copies each of 20 and 30 -- 108 cards in total.
deck = []
for i in range(3, 15):
    deck.extend([i] * 8)
deck.extend([17] * 8)
deck.extend([20, 20, 30, 30])
2021-09-07 17:19:25 +08:00
class Env:
    """
    Doudizhu multi-agent wrapper for the four-seat game
    (landlord, landlord_up, landlord_front, landlord_down).
    """
    def __init__(self, objective, old_model, legacy_model=False, lite_model = False):
        """
        Objective is wp/adp/logadp and selects how ``_get_reward`` scores a
        finished game (wp ignores bombs, adp/logadp scale by bombs played).
        Here, we use dummy agents. This is because, in the original game,
        the players are `in` the game. Here, we want to isolate players and
        environments to have a more gym style interface. To achieve this,
        we use dummy players to play. For each move, we tell the
        corresponding dummy player which action to play, then the player
        will perform the actual action in the game engine.

        Args:
            objective: 'wp', 'adp' or 'logadp'.
            old_model: when True the per-seat observation builders are used
                instead of the general one (``use_general = not old_model``).
            legacy_model: forwarded to ``get_obs`` as ``use_legacy``.
            lite_model: forwarded to ``get_obs`` as ``lite_model``.
        """
        self.objective = objective
        self.use_legacy = legacy_model
        self.use_general = not old_model
        self.lite_model = lite_model
        # One dummy player per seat; each replays the action step() gives it.
        self.players = {}
        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
            self.players[position] = DummyAgent(position)
        # Initialize the internal environment
        self._env = GameEnv(self.players)
        self.total_round = 0
        self.infoset = None

    @staticmethod
    def _deal_cards():
        """Shuffle a copy of the deck and deal 33/25/25/25 sorted hands."""
        _deck = deck.copy()
        np.random.shuffle(_deck)
        card_play_data = {'landlord': _deck[:33],
                          'landlord_up': _deck[33:58],
                          'landlord_front': _deck[58:83],
                          'landlord_down': _deck[83:108],
                          }
        for hand in card_play_data.values():
            hand.sort()
        return card_play_data

    def reset(self, model, device, flags=None):
        """
        Every time reset is called, the environment
        will be re-initialized with a new deck of cards.
        This function is usually called when a game is over.

        When ``model`` is not None the round counter is incremented and each
        seat's infoset is given a numeric ``player_id``. ``device`` and
        ``flags`` are accepted for interface compatibility but unused here.

        Returns:
            The observation for the first acting player.
        """
        self._env.reset()
        card_play_data = self._deal_cards()
        self._env.card_play_init(card_play_data)
        if model is not None:
            self.total_round += 1
            player_ids = {
                'landlord': 0,
                'landlord_down': 1,
                'landlord_front': 2,
                'landlord_up': 3,
            }
            for pos, pid in player_ids.items():
                self._env.info_sets[pos].player_id = pid
        self.infoset = self._game_infoset
        # Both reset paths must forward use_legacy *and* lite_model; passing
        # lite_model positionally in the use_legacy slot silently picked the
        # wrong bomb encoding and dropped the lite (compressed) features.
        return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)

    def step(self, action):
        """
        Apply ``action`` (a list of card values) for the acting player.

        Returns:
            (obs, reward, done, {}) where ``obs`` is the next observation
            (None once the game is over), ``reward`` is 0.0 while the game
            is running and ``{"play": {position: reward}}`` when it ends,
            and ``done`` tells whether the game is finished. The empty dict
            is reserved for extra information.
        """
        assert action in self.infoset.legal_actions
        self.players[self._acting_player_position].set_action(action)
        self._env.step()
        self.infoset = self._game_infoset
        done = False
        reward = 0.0
        if self._game_over:
            done = True
            reward = {
                "play": {
                    "landlord": self._get_reward("landlord"),
                    "landlord_up": self._get_reward("landlord_up"),
                    "landlord_front": self._get_reward("landlord_front"),
                    "landlord_down": self._get_reward("landlord_down")
                }
            }
            obs = None
        else:
            obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
        return obs, reward, done, {}

    def _get_reward(self, pos):
        """
        Score a finished game for ``pos``.

        The magnitude shrinks by 0.0005 per step played. 'adp' multiplies
        by 1.3 ** bomb_num[0] * 1.95 ** bomb_num[1] / 8 (only the first two
        entries of the bomb counter are used), 'logadp' by
        1.3 ** (bombs played by ``pos``) / 4, and 'wp' applies no bomb factor.
        The result is positive when the landlord won and negative otherwise.

        NOTE(review): the sign depends only on whether the landlord won, not
        on ``pos``; presumably callers flip it for the peasant seats -- confirm.
        """
        winner = self._game_winner
        bomb_num = self._game_bomb_num
        self_bomb_num = self._env.pos_bomb_num[pos]
        if winner == 'landlord':
            if self.objective == 'adp':
                return (1.1 - self._env.step_count * 0.0005) * (1.3 ** bomb_num[0]) * (1.95 ** bomb_num[1]) / 8
            elif self.objective == 'logadp':
                return (1.0 - self._env.step_count * 0.0005) * 1.3**self_bomb_num / 4
            else:
                return 1.0 - self._env.step_count * 0.0005
        else:
            if self.objective == 'adp':
                return (-1.1 + self._env.step_count * 0.0005) * (1.3 ** bomb_num[0]) * (1.95 ** bomb_num[1]) / 8
            elif self.objective == 'logadp':
                return (-1.0 + self._env.step_count * 0.0005) * 1.3**(self_bomb_num) / 4
            else:
                return -1.0 + self._env.step_count * 0.0005

    @property
    def _game_infoset(self):
        """
        The engine's current infoset. It contains perfect information
        (including every player's hand cards and the move history); the
        observation builders extract what the acting player may see.
        """
        return self._env.game_infoset

    @property
    def _game_bomb_num(self):
        """
        The bomb counter reported by the engine's ``get_bomb_num()``; used
        as a network feature and to compute the ADP reward.
        """
        return self._env.get_bomb_num()

    @property
    def _game_winner(self):
        """ The winner reported by the engine ('landlord' when the landlord won).
        """
        return self._env.get_winner()

    @property
    def _acting_player_position(self):
        """
        The player that is active: landlord, landlord_up,
        landlord_front or landlord_down.
        """
        return self._env.acting_player_position

    @property
    def _game_over(self):
        """ Returns a Boolean
        """
        return self._env.game_over
class DummyAgent:
    """
    Stand-in player that lets the environment drive the game engine.

    The environment first stores the desired move with ``set_action``; the
    engine then calls ``act`` and receives exactly that move. This keeps
    the agent logic outside the engine, giving a gym-like interface.
    """
    def __init__(self, position):
        self.position = position
        # Move to return on the next act() call; None until set_action runs.
        self.action = None

    def set_action(self, action):
        """Store the move that the next ``act`` call will hand back."""
        self.action = action

    def act(self, infoset):
        """
        Return the stored move, asserting it is among
        ``infoset.legal_actions``.
        """
        chosen = self.action
        assert chosen in infoset.legal_actions
        return chosen
2021-12-22 21:19:10 +08:00
def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
    """
    Build the acting player's imperfect-information observation.

    With ``use_general`` every seat is encoded by ``_get_obs_general``;
    otherwise a seat-specific builder is chosen by
    ``infoset.player_position`` and receives ``use_legacy`` and
    ``lite_model`` (the latter as its compressed-form flag).

    The returned ``obs`` dict feeds model training:
    `position` is landlord/landlord_down/landlord_front/landlord_up;
    `x_batch` holds the non-historical features plus the candidate action,
    one row per legal action; `z_batch` holds the historical moves, one copy
    per legal action; `legal_actions` lists the legal moves; `x_no_action`
    is the non-historical feature vector without the action part and
    without the batch dimension; `z` is `z_batch` without the batch
    dimension.

    Raises:
        ValueError: if the player position is not one of the four seats.
    """
    position = infoset.player_position
    if use_general:
        if position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
            raise ValueError('')
        return _get_obs_general(infoset, position, lite_model)

    seat_builders = {
        'landlord': _get_obs_landlord,
        'landlord_up': _get_obs_landlord_up,
        'landlord_front': _get_obs_landlord_front,
        'landlord_down': _get_obs_landlord_down,
    }
    builder = seat_builders.get(position)
    if builder is None:
        raise ValueError('')
    return builder(infoset, use_legacy, lite_model)
2021-12-22 21:19:10 +08:00
def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
2021-09-07 17:19:25 +08:00
"""
A utility function to obtain one-hot endoding
"""
2021-12-22 21:19:10 +08:00
if compress_size > 0:
assert compress_size <= max_num_cards / 2
array_size = max_num_cards - compress_size
one_hot = np.zeros(array_size)
if num_left_cards >= array_size:
one_hot[-1] = 1
num_left_cards -= array_size
if num_left_cards > 0:
one_hot[num_left_cards - 1] = 1
else:
one_hot = np.zeros(max_num_cards)
if num_left_cards > 0:
one_hot[num_left_cards - 1] = 1
2021-09-07 17:19:25 +08:00
return one_hot
2021-12-22 21:19:10 +08:00
def _cards2array(list_cards, compressed_form = False):
    """
    Encode a list of card values as a flat int8 feature vector.

    Each non-joker value owns one column (see ``Card2Column``) holding its
    count encoding from ``NumOnes2Array`` (8 rows) or, when
    ``compressed_form`` is set, ``NumOnes2ArrayCompressed`` (5 rows). The two
    joker values (20 and 30) share column 13, encoded from a bitmask via
    ``NumOnesJoker2Array`` / ``NumOnesJoker2ArrayCompressed``. The matrix is
    flattened column by column and the trailing always-zero entries of the
    joker column are dropped, giving 108 values (69 when compressed).
    """
    if compressed_form:
        rows, rank_codes, joker_codes, trailing_zeros = (
            5, NumOnes2ArrayCompressed, NumOnesJoker2ArrayCompressed, 1)
    else:
        rows, rank_codes, joker_codes, trailing_zeros = (
            8, NumOnes2Array, NumOnesJoker2Array, 4)

    if len(list_cards) == 0:
        return np.zeros(rows * 14 - trailing_zeros, dtype=np.int8)

    matrix = np.zeros([rows, 14], dtype=np.int8)
    # Bits 0/1: one/two copies of card 20; bits 2/3: one/two copies of card 30.
    joker_mask = 0
    for card, count in Counter(list_cards).items():
        if card < 20:
            matrix[:, Card2Column[card]] = rank_codes[count]
        elif card == 20:
            joker_mask |= 0b11 if count == 2 else 0b01
        elif card == 30:
            joker_mask |= 0b1100 if count == 2 else 0b0100
    matrix[:, 13] = joker_codes[joker_mask]
    return matrix.flatten('F')[:-trailing_zeros]
2021-09-07 17:19:25 +08:00
# def _action_seq_list2array(action_seq_list):
# """
# A utility function to encode the historical moves.
# We encode the historical 15 actions. If there is
# no 15 actions, we pad the features with 0. Since
# three moves is a round in DouDizhu, we concatenate
# the representations for each consecutive three moves.
# Finally, we obtain a 5x162 matrix, which will be fed
# into LSTM for encoding.
# """
2021-12-05 12:03:30 +08:00
# action_seq_array = np.zeros((len(action_seq_list), 108))
2021-09-07 17:19:25 +08:00
# for row, list_cards in enumerate(action_seq_list):
# action_seq_array[row, :] = _cards2array(list_cards)
# # action_seq_array = action_seq_array.reshape(5, 162)
# return action_seq_array
2021-12-22 21:19:10 +08:00
def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False):
2021-09-07 17:19:25 +08:00
"""
A utility function to encode the historical moves.
2021-12-05 12:03:30 +08:00
We encode the historical 20 actions. If there is
no 20 actions, we pad the features with 0. Since
2021-09-07 17:19:25 +08:00
three moves is a round in DouDizhu, we concatenate
the representations for each consecutive three moves.
2021-12-05 12:03:30 +08:00
Finally, we obtain a 5x432 matrix, which will be fed
2021-09-07 17:19:25 +08:00
into LSTM for encoding.
"""
if new_model:
2021-12-05 12:03:30 +08:00
# position_map = {"landlord": 0, "landlord_up": 1, "landlord_front": 2, "landlord_down": 3}
2021-12-22 21:19:10 +08:00
if compressed_form:
action_seq_array = np.full((len(action_seq_list), 69), -1) # Default Value -1 for not using area
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :69] = _cards2array(list_cards[1], compressed_form)
else:
action_seq_array = np.full((len(action_seq_list), 108), -1) # Default Value -1 for not using area
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
2021-09-07 17:19:25 +08:00
else:
2021-12-22 21:19:10 +08:00
if compressed_form:
action_seq_array = np.zeros((len(action_seq_list), 69))
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
action_seq_array = action_seq_array.reshape(5, 276)
else:
action_seq_array = np.zeros((len(action_seq_list), 108))
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
action_seq_array = action_seq_array.reshape(5, 432)
2021-09-07 17:19:25 +08:00
return action_seq_array
# action_seq_array = np.zeros((len(action_seq_list), 54))
# for row, list_cards in enumerate(action_seq_list):
# if list_cards != []:
# action_seq_array[row, :] = _cards2array(list_cards[1])
# return action_seq_array
2021-12-05 12:03:30 +08:00
def _process_action_seq(sequence, length=20, new_model=True):
2021-09-07 17:19:25 +08:00
"""
A utility function encoding historical moves. We
2021-12-05 12:03:30 +08:00
encode 20 moves. If there is no 20 moves, we pad
2021-09-07 17:19:25 +08:00
with zeros.
"""
sequence = sequence[-length:].copy()
if new_model:
sequence = sequence[::-1]
if len(sequence) < length:
empty_sequence = [[] for _ in range(length - len(sequence))]
empty_sequence.extend(sequence)
sequence = empty_sequence
return sequence
2021-12-20 10:02:55 +08:00
def _get_one_hot_bomb(bomb_num, use_legacy = False):
2021-09-07 17:19:25 +08:00
"""
A utility function to encode the number of bombs
into one-hot representation.
"""
2021-12-20 10:02:55 +08:00
if use_legacy:
one_hot = np.zeros(29)
one_hot[bomb_num[0] + bomb_num[1]] = 1
else:
one_hot = np.zeros(56) # 14 + 15 + 27
one_hot[bomb_num[0]] = 1
one_hot[14 + bomb_num[1]] = 1
one_hot[29 + bomb_num[2]] = 1
2021-09-07 17:19:25 +08:00
return one_hot
2021-12-22 21:19:10 +08:00
def _get_obs_landlord(infoset, use_legacy = False, compressed_form = False):
    """
    Obtain the landlord features. See Table 4 in
    https://arxiv.org/pdf/2106.06135.pdf

    State features, in order: own hand, other players' cards, last move,
    cards played by landlord_up / landlord_front / landlord_down, their
    remaining-card counts, and the bomb encoding.

    Args:
        infoset: the acting player's infoset.
        use_legacy: forwarded to ``_get_one_hot_bomb``.
        compressed_form: use the compressed card and count encodings.

    Returns:
        dict with 'position', 'x_batch' (float32; one row per legal action,
        state features followed by that action's encoding), 'z_batch'
        (float32), 'legal_actions', 'x_no_action' (int8) and 'z' (int8).
    """
    legal_actions = infoset.legal_actions
    num_actions = len(legal_actions)
    count_compress = 8 if compressed_form else 0

    def encode_cards(cards):
        return _cards2array(cards, compressed_form)

    def encode_left(seat):
        return _get_one_hot_array(infoset.num_cards_left_dict[seat], 25, count_compress)

    hand = encode_cards(infoset.player_hand_cards)
    state_parts = (
        hand,
        encode_cards(infoset.other_hand_cards),
        encode_cards(infoset.last_move),
        encode_cards(infoset.played_cards['landlord_up']),
        encode_cards(infoset.played_cards['landlord_front']),
        encode_cards(infoset.played_cards['landlord_down']),
        encode_left('landlord_up'),
        encode_left('landlord_front'),
        encode_left('landlord_down'),
        _get_one_hot_bomb(infoset.bomb_num, use_legacy),
    )
    x_no_action = np.hstack(state_parts)

    action_rows = np.zeros((num_actions, hand.shape[0]))
    for row, action in enumerate(legal_actions):
        action_rows[row, :] = encode_cards(action)
    # Every candidate action shares the same state features: tile the state
    # vector once per action and append that action's encoding.
    x_batch = np.hstack((np.repeat(x_no_action[np.newaxis, :], num_actions, axis=0),
                         action_rows))

    z = _action_seq_list2array(_process_action_seq(
        infoset.card_play_action_seq, 20, False), False, compressed_form)
    z_batch = np.repeat(z[np.newaxis, :, :], num_actions, axis=0)

    return {
        'position': 'landlord',
        'x_batch': x_batch.astype(np.float32),
        'z_batch': z_batch.astype(np.float32),
        'legal_actions': legal_actions,
        'x_no_action': x_no_action.astype(np.int8),
        'z': z.astype(np.int8),
    }
2021-12-22 21:19:10 +08:00
def _get_obs_landlord_up(infoset, use_legacy = False, compressed_form = False):
    """
    Obtain the landlord_up features. See Table 5 in
    https://arxiv.org/pdf/2106.06135.pdf

    State features, in order: own hand, other players' cards, cards played
    by landlord / landlord_down / landlord_front, the last move, the last
    moves of landlord / landlord_down / landlord_front, their remaining-card
    counts, and the bomb encoding.

    Args:
        infoset: the acting player's infoset.
        use_legacy: forwarded to ``_get_one_hot_bomb``.
        compressed_form: use the compressed card and count encodings.

    Returns:
        dict with 'position', 'x_batch' (float32; one row per legal action,
        state features followed by that action's encoding), 'z_batch'
        (float32), 'legal_actions', 'x_no_action' (int8) and 'z' (int8).
    """
    legal_actions = infoset.legal_actions
    num_actions = len(legal_actions)

    def encode_cards(cards):
        return _cards2array(cards, compressed_form)

    hand = encode_cards(infoset.player_hand_cards)
    state_parts = (
        hand,
        encode_cards(infoset.other_hand_cards),
        encode_cards(infoset.played_cards['landlord']),
        encode_cards(infoset.played_cards['landlord_down']),
        encode_cards(infoset.played_cards['landlord_front']),
        encode_cards(infoset.last_move),
        encode_cards(infoset.last_move_dict['landlord']),
        encode_cards(infoset.last_move_dict['landlord_down']),
        encode_cards(infoset.last_move_dict['landlord_front']),
        _get_one_hot_array(infoset.num_cards_left_dict['landlord'],
                           33, 15 if compressed_form else 0),
        _get_one_hot_array(infoset.num_cards_left_dict['landlord_down'],
                           25, 8 if compressed_form else 0),
        _get_one_hot_array(infoset.num_cards_left_dict['landlord_front'],
                           25, 8 if compressed_form else 0),
        _get_one_hot_bomb(infoset.bomb_num, use_legacy),
    )
    x_no_action = np.hstack(state_parts)

    action_rows = np.zeros((num_actions, hand.shape[0]))
    for row, action in enumerate(legal_actions):
        action_rows[row, :] = encode_cards(action)
    # Every candidate action shares the same state features: tile the state
    # vector once per action and append that action's encoding.
    x_batch = np.hstack((np.repeat(x_no_action[np.newaxis, :], num_actions, axis=0),
                         action_rows))

    z = _action_seq_list2array(_process_action_seq(
        infoset.card_play_action_seq, 20, False), False, compressed_form)
    z_batch = np.repeat(z[np.newaxis, :, :], num_actions, axis=0)

    return {
        'position': 'landlord_up',
        'x_batch': x_batch.astype(np.float32),
        'z_batch': z_batch.astype(np.float32),
        'legal_actions': legal_actions,
        'x_no_action': x_no_action.astype(np.int8),
        'z': z.astype(np.int8),
    }
2021-12-22 21:19:10 +08:00
def _get_obs_landlord_front(infoset, use_legacy = False, compressed_form = False):
    """
    Obtain the landlord_front features. See Table 5 in
    https://arxiv.org/pdf/2106.06135.pdf

    State features, in order: own hand, other players' cards, cards played
    by landlord / landlord_down / landlord_front, the last move, the last
    moves of landlord / landlord_down / landlord_front, their remaining-card
    counts, and the bomb encoding.

    ``x_batch`` is built by tiling ``x_no_action`` so the two always agree.
    Previously the batch used landlord_down's played cards where the
    landlord_front slot belonged, so ``x_batch`` and ``x_no_action`` differed.

    NOTE(review): this seat *is* landlord_front, so the "teammate_front"
    features below describe the acting player itself and landlord_up's
    played cards / last move / count never appear. 'landlord_up' was
    probably intended; changing it alters the model input, so confirm first.

    Args:
        infoset: the acting player's infoset.
        use_legacy: forwarded to ``_get_one_hot_bomb``.
        compressed_form: use the compressed card and count encodings.

    Returns:
        dict with 'position', 'x_batch' (float32; one row per legal action,
        state features followed by that action's encoding), 'z_batch'
        (float32), 'legal_actions', 'x_no_action' (int8) and 'z' (int8).
    """
    legal_actions = infoset.legal_actions
    num_actions = len(legal_actions)

    def encode_cards(cards):
        return _cards2array(cards, compressed_form)

    hand = encode_cards(infoset.player_hand_cards)
    state_parts = (
        hand,
        encode_cards(infoset.other_hand_cards),
        encode_cards(infoset.played_cards['landlord']),
        encode_cards(infoset.played_cards['landlord_down']),
        encode_cards(infoset.played_cards['landlord_front']),
        encode_cards(infoset.last_move),
        encode_cards(infoset.last_move_dict['landlord']),
        encode_cards(infoset.last_move_dict['landlord_down']),
        encode_cards(infoset.last_move_dict['landlord_front']),
        _get_one_hot_array(infoset.num_cards_left_dict['landlord'],
                           33, 15 if compressed_form else 0),
        _get_one_hot_array(infoset.num_cards_left_dict['landlord_down'],
                           25, 8 if compressed_form else 0),
        _get_one_hot_array(infoset.num_cards_left_dict['landlord_front'],
                           25, 8 if compressed_form else 0),
        _get_one_hot_bomb(infoset.bomb_num, use_legacy),
    )
    x_no_action = np.hstack(state_parts)

    action_rows = np.zeros((num_actions, hand.shape[0]))
    for row, action in enumerate(legal_actions):
        action_rows[row, :] = encode_cards(action)
    # Every candidate action shares the same state features: tile the state
    # vector once per action and append that action's encoding.
    x_batch = np.hstack((np.repeat(x_no_action[np.newaxis, :], num_actions, axis=0),
                         action_rows))

    z = _action_seq_list2array(_process_action_seq(
        infoset.card_play_action_seq, 20, False), False, compressed_form)
    z_batch = np.repeat(z[np.newaxis, :, :], num_actions, axis=0)

    return {
        'position': 'landlord_front',
        'x_batch': x_batch.astype(np.float32),
        'z_batch': z_batch.astype(np.float32),
        'legal_actions': legal_actions,
        'x_no_action': x_no_action.astype(np.int8),
        'z': z.astype(np.int8),
    }
2021-12-22 21:19:10 +08:00
def _get_obs_landlord_down(infoset, use_legacy=False, compressed_form=False):
    """
    Build the observation for the ``landlord_down`` seat. See Table 5 in
    https://arxiv.org/pdf/2106.06135.pdf

    Args:
        infoset: information set of the acting player. Reads
            ``legal_actions``, ``player_hand_cards``, ``other_hand_cards``,
            ``last_move``, ``last_move_dict``, ``num_cards_left_dict``,
            ``played_cards``, ``bomb_num`` and ``card_play_action_seq``.
        use_legacy: forwarded unchanged to ``_get_one_hot_bomb``.
        compressed_form: forwarded to the card / action-sequence encoders;
            it also selects the offset passed to ``_get_one_hot_array``
            (15 for the landlord, 8 for the other seats, 0 when False).

    Returns:
        dict with ``position`` ('landlord_down'), ``x_batch`` and
        ``z_batch`` (float32, one leading row per legal action),
        ``legal_actions``, and the action-independent ``x_no_action`` and
        ``z`` (int8).
    """
    n_actions = len(infoset.legal_actions)

    def tile(features):
        # One copy of an action-independent vector per legal action.
        return np.repeat(features[np.newaxis, :], n_actions, axis=0)

    def encode(cards):
        return _cards2array(cards, compressed_form)

    landlord_size, landlord_offset = 33, (15 if compressed_form else 0)
    peasant_size, peasant_offset = 25, (8 if compressed_form else 0)

    hand = encode(infoset.player_hand_cards)
    others = encode(infoset.other_hand_cards)
    prev_move = encode(infoset.last_move)

    # Row i holds the encoding of legal action i.
    candidates = np.zeros((n_actions,) + hand.shape)
    for row, candidate in enumerate(infoset.legal_actions):
        candidates[row, :] = encode(candidate)

    last_moves = infoset.last_move_dict
    left = infoset.num_cards_left_dict
    played = infoset.played_cards

    lord_last = encode(last_moves['landlord'])
    lord_left = _get_one_hot_array(left['landlord'],
                                   landlord_size, landlord_offset)
    lord_played = encode(played['landlord'])

    up_last = encode(last_moves['landlord_up'])
    up_left = _get_one_hot_array(left['landlord_up'],
                                 peasant_size, peasant_offset)
    up_played = encode(played['landlord_up'])

    front_last = encode(last_moves['landlord_front'])
    front_left = _get_one_hot_array(left['landlord_front'],
                                    peasant_size, peasant_offset)
    front_played = encode(played['landlord_front'])

    bombs = _get_one_hot_bomb(infoset.bomb_num, use_legacy)

    # Order matters: x_no_action is exactly one x_batch row without the
    # trailing candidate-action encoding.
    context = [hand, others, lord_played, up_played, front_played,
               prev_move, lord_last, up_last, front_last,
               lord_left, up_left, front_left, bombs]

    x_batch = np.hstack([tile(part) for part in context] + [candidates])
    x_no_action = np.hstack(context)

    z = _action_seq_list2array(
        _process_action_seq(infoset.card_play_action_seq, 20, False),
        False, compressed_form)
    z_batch = np.repeat(z[np.newaxis, :, :], n_actions, axis=0)

    return {
        'position': 'landlord_down',
        'x_batch': x_batch.astype(np.float32),
        'z_batch': z_batch.astype(np.float32),
        'legal_actions': infoset.legal_actions,
        'x_no_action': x_no_action.astype(np.int8),
        'z': z.astype(np.int8),
    }
2021-12-22 21:19:10 +08:00
def _get_obs_general(infoset, position, compressed_form=False):
    """
    Build the seat-agnostic observation used for every position.

    ``z`` stacks, row-wise:
    - the concatenated cards-left encodings of all four seats;
    - the player's hand;
    - the other players' cards;
    - each seat's played cards;
    - the encoded action history (``_process_action_seq`` is called with 32
      as its length argument).

    ``x`` carries only the bomb-count encoding.

    Args:
        infoset: information set of the acting player.
        position: copied verbatim into the result's ``position`` entry.
        compressed_form: forwarded to the card / action-sequence encoders;
            it also selects the offset passed to ``_get_one_hot_array``.

    Returns:
        dict with ``position``, ``x_batch`` and ``z_batch`` (float32, one
        leading row per legal action; row 0 of each ``z_batch`` entry is
        that action's encoding), ``legal_actions``, and the
        action-independent ``x_no_action`` and ``z`` (int8).
    """
    n_actions = len(infoset.legal_actions)
    seats = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')

    hand = _cards2array(infoset.player_hand_cards, compressed_form)
    others = _cards2array(infoset.other_hand_cards, compressed_form)

    # Row i holds the encoding of legal action i.
    candidates = np.zeros((n_actions,) + hand.shape)
    for row, candidate in enumerate(infoset.legal_actions):
        candidates[row, :] = _cards2array(candidate, compressed_form)

    def cards_left(seat):
        # The landlord uses size 33 (offset 15 when compressed); every
        # other seat uses size 25 (offset 8 when compressed).
        if seat == 'landlord':
            size, offset = 33, 15
        else:
            size, offset = 25, 8
        return _get_one_hot_array(infoset.num_cards_left_dict[seat], size,
                                  offset if compressed_form else 0)

    left_rows = [cards_left(seat) for seat in seats]
    played_rows = [_cards2array(infoset.played_cards[seat], compressed_form)
                   for seat in seats]

    bombs = _get_one_hot_bomb(infoset.bomb_num)
    num_cards_left = np.hstack(left_rows)

    x_batch = np.hstack((np.repeat(bombs[np.newaxis, :], n_actions, axis=0),))
    x_no_action = np.hstack((bombs,))

    history = _action_seq_list2array(
        _process_action_seq(infoset.card_play_action_seq, 32),
        True, compressed_form)
    z = np.vstack([num_cards_left, hand, others] + played_rows + [history])

    # Each legal action is prepended as an extra first row of its own z copy.
    z_batch = np.concatenate(
        (candidates[:, np.newaxis, :],
         np.repeat(z[np.newaxis, :, :], n_actions, axis=0)),
        axis=1)

    return {
        'position': position,
        'x_batch': x_batch.astype(np.float32),
        'z_batch': z_batch.astype(np.float32),
        'legal_actions': infoset.legal_actions,
        'x_no_action': x_no_action.astype(np.int8),
        'z': z.astype(np.int8),
    }