unified_model

This commit is contained in:
zhiyang7 2022-01-04 18:15:35 +08:00
parent 2b8586303b
commit 08e05dbc83
7 changed files with 367 additions and 28 deletions

View File

@ -31,6 +31,8 @@ parser.add_argument('--load_model', action='store_true',
help='Load an existing model') help='Load an existing model')
parser.add_argument('--old_model', action='store_true', parser.add_argument('--old_model', action='store_true',
help='Use vanilla model') help='Use vanilla model')
parser.add_argument('--unified_model', action='store_true',
help='Use unified model')
parser.add_argument('--lite_model', action='store_true', parser.add_argument('--lite_model', action='store_true',
help='Use lite card model') help='Use lite card model')
parser.add_argument('--lagacy_model', action='store_true', parser.add_argument('--lagacy_model', action='store_true',

View File

@ -13,8 +13,8 @@ from torch import nn
import douzero.dmc.models import douzero.dmc.models
import douzero.env.env import douzero.env.env
from .file_writer import FileWriter from .file_writer import FileWriter
from .models import Model, OldModel from .models import Model, OldModel, UnifiedModel
from .utils import get_batch, log, create_env, create_optimizers, act, infer_logic from .utils import get_batch, log, create_optimizers, act, infer_logic
import psutil import psutil
import shutil import shutil
import requests import requests
@ -87,6 +87,8 @@ def train(flags):
# Initialize actor models # Initialize actor models
if flags.old_model: if flags.old_model:
actor_model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model) actor_model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
elif flags.unified_model:
actor_model = UnifiedModel(device="cpu", flags = flags, lite_model = flags.lite_model)
else: else:
actor_model = Model(device="cpu", flags = flags, lite_model = flags.lite_model) actor_model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
actor_model.eval() actor_model.eval()
@ -100,6 +102,8 @@ def train(flags):
# Learner model for training # Learner model for training
if flags.old_model: if flags.old_model:
learner_model = OldModel(device=flags.training_device, lite_model = flags.lite_model) learner_model = OldModel(device=flags.training_device, lite_model = flags.lite_model)
elif flags.unified_model:
learner_model = UnifiedModel(device=flags.training_device, lite_model = flags.lite_model)
else: else:
learner_model = Model(device=flags.training_device, lite_model = flags.lite_model) learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)
@ -255,6 +259,8 @@ def train(flags):
type = '' type = ''
if flags.old_model: if flags.old_model:
type += 'vanilla' type += 'vanilla'
elif flags.unified_model:
type += 'unified'
else: else:
type += 'resnet' type += 'resnet'
requests.post(flags.upload_url, data={ requests.post(flags.upload_url, data={

View File

@ -400,6 +400,74 @@ class GeneralModelLite(nn.Module):
return dict(values=out) return dict(values=out)
class UnifiedModelLite(nn.Module):
def __init__(self):
super().__init__()
self.in_planes = 30
#input 1*69*15
self.conv1 = nn.Conv1d(15, 30, kernel_size=(3,),
stride=(2,), padding=1, bias=False) #1*35*30
self.bn1 = nn.BatchNorm1d(30)
self.layer1 = self._make_layer(BasicBlock, 30, 2, stride=2)#1*18*30
self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60
self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.lstm = nn.LSTM(276, 128, batch_first=True)
self.linear1 = nn.Linear(120 * BasicBlock.expansion * 5 + 128, 2048)
self.linear2 = nn.Linear(2048, 1024)
self.linear3 = nn.Linear(1024, 512)
self.linear4 = nn.Linear(512, 256)
self.linear5 = nn.Linear(256, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def get_onnx_params(self, device=None):
return {
'args': (
torch.randn(1, 40, 69, requires_grad=True, device=device),
torch.randn(1, 56, requires_grad=True, device=device),
),
'input_names': ['z_batch','x_batch'],
'output_names': ['values'],
'dynamic_axes': {
'z_batch': {
0: "batch_size"
},
'x_batch': {
0: "batch_size"
},
'values': {
0: "batch_size"
}
}
}
def forward(self, z, x):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = out.flatten(1,2)
lstm_out, (h_n, _) = self.lstm(x)
lstm_out = lstm_out[:,-1,:]
out = torch.hstack([lstm_out, out])
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
out = F.leaky_relu_(self.linear5(out))
return dict(values=out)
class GeneralModel(nn.Module): class GeneralModel(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -675,3 +743,61 @@ class Model:
def get_models(self): def get_models(self):
return self.models return self.models
class UnifiedModel:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0, flags=None, lite_model = False):
self.onnx_models = {}
self.model = None
self.models = {}
self.flags = flags
if not device == "cpu":
device = 'cuda:' + str(device)
self.device = torch.device(device)
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if flags is not None and flags.enable_onnx:
self.model = None
else:
if lite_model:
self.model = UnifiedModelLite().to(self.device)
for position in positions:
self.models[position] = self.model
else:
self.model = GeneralModel().to(self.device)
self.onnx_model = None
def set_onnx_model(self, device='cpu'):
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
if device == 'cpu':
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
else:
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
def get_onnx_params(self, position):
self.model.get_onnx_params(self.device)
def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
return forward_logic(self, position, z, x, device, return_value, flags)
def share_memory(self):
if self.model is not None:
self.model.share_memory()
def eval(self):
if self.model is not None:
self.model.eval()
def parameters(self, position):
return self.model.parameters()
def get_model(self, position):
return self.model
def get_models(self):
return {
'uni' : self.model
}

View File

@ -73,7 +73,7 @@ log.setLevel(logging.INFO)
Buffers = typing.Dict[str, typing.List[torch.Tensor]] Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags): def create_env(flags):
return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model) return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model, flags.unified_model)
def get_batch(b_queues, position, flags, lock): def get_batch(b_queues, position, flags, lock):
""" """
@ -116,8 +116,11 @@ def create_optimizers(flags, learner_model):
def infer_logic(i, device, infer_queues, model, flags, onnx_frame): def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if not flags.enable_onnx: if not flags.enable_onnx:
for pos in positions: if flags.unified_model:
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device)))) model.model.to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
else:
for pos in positions:
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
last_onnx_frame = -1 last_onnx_frame = -1
log.info('Infer %i started.', i) log.info('Infer %i started.', i)

209
douzero/env/env.py vendored
View File

@ -49,7 +49,47 @@ NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
12: np.array([0, 0, 1, 1, 0]), 12: np.array([0, 0, 1, 1, 0]),
13: np.array([1, 0, 1, 1, 0]), 13: np.array([1, 0, 1, 1, 0]),
15: np.array([1, 1, 1, 1, 0])} 15: np.array([1, 1, 1, 1, 0])}
PositionInfoArray = {
'landlord': np.array([1, 0, 0, 0]),
'landlord_down': np.array([0, 1, 0, 0]),
'landlord_front': np.array([0, 0, 1, 0]),
'landlord_up': np.array([0, 0, 0, 1]),
}
FaceUpLevelArray = {
0x00: np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]),
0x01: np.array([1, 0, 0, 0, 0, 0, 0, 0, 0]),
0x02: np.array([0, 1, 0, 0, 0, 0, 0, 0, 0]),
0x03: np.array([1, 1, 0, 0, 0, 0, 0, 0, 0]),
0x04: np.array([0, 0, 1, 0, 0, 0, 0, 0, 0]),
0x05: np.array([1, 0, 1, 0, 0, 0, 0, 0, 0]),
0x06: np.array([0, 1, 1, 0, 0, 0, 0, 0, 0]),
0x07: np.array([1, 1, 1, 0, 0, 0, 0, 0, 0]),
0x08: np.array([0, 0, 0, 1, 0, 0, 0, 0, 0]),
0x09: np.array([1, 0, 0, 1, 0, 0, 0, 0, 0]),
0x0A: np.array([0, 1, 0, 1, 0, 0, 0, 0, 0]),
0x0B: np.array([1, 1, 0, 1, 0, 0, 0, 0, 0]),
0x0C: np.array([0, 0, 1, 1, 0, 0, 0, 0, 0]),
0x0D: np.array([1, 0, 1, 1, 0, 0, 0, 0, 0]),
0x0E: np.array([0, 1, 1, 1, 0, 0, 0, 0, 0]),
0x0F: np.array([1, 1, 1, 1, 0, 0, 0, 0, 0]),
0x10: np.array([0, 0, 0, 0, 1, 0, 0, 0, 0]),
0x11: np.array([1, 0, 0, 0, 1, 0, 0, 0, 0]),
0x12: np.array([0, 1, 0, 0, 1, 0, 0, 0, 0]),
0x13: np.array([1, 1, 0, 0, 1, 0, 0, 0, 0]),
0x14: np.array([0, 0, 1, 0, 1, 0, 0, 0, 0]),
0x15: np.array([1, 0, 1, 0, 1, 0, 0, 0, 0]),
0x16: np.array([0, 1, 1, 0, 1, 0, 0, 0, 0]),
0x17: np.array([1, 1, 1, 0, 1, 0, 0, 0, 0]),
0x18: np.array([0, 0, 0, 1, 1, 0, 0, 0, 0]),
0x19: np.array([1, 0, 0, 1, 1, 0, 0, 0, 0]),
0x1A: np.array([0, 1, 0, 1, 1, 0, 0, 0, 0]),
0x1B: np.array([1, 1, 0, 1, 1, 0, 0, 0, 0]),
0x1C: np.array([0, 0, 1, 1, 1, 0, 0, 0, 0]),
0x1D: np.array([1, 0, 1, 1, 1, 0, 0, 0, 0]),
0x1E: np.array([0, 1, 1, 1, 1, 0, 0, 0, 0]),
0x1F: np.array([1, 1, 1, 1, 1, 0, 0, 0, 0]),
}
deck = [] deck = []
for i in range(3, 15): for i in range(3, 15):
@ -63,7 +103,7 @@ class Env:
Doudizhu multi-agent wrapper Doudizhu multi-agent wrapper
""" """
def __init__(self, objective, old_model, legacy_model=False, lite_model = False): def __init__(self, objective, old_model, legacy_model=False, lite_model = False, unified_model = False):
""" """
Objective is wp/adp/logadp. It indicates whether considers Objective is wp/adp/logadp. It indicates whether considers
bomb in reward calculation. Here, we use dummy agents. bomb in reward calculation. Here, we use dummy agents.
@ -77,7 +117,8 @@ class Env:
""" """
self.objective = objective self.objective = objective
self.use_legacy = legacy_model self.use_legacy = legacy_model
self.use_general = not old_model self.use_unified = unified_model
self.use_general = not old_model and not unified_model
self.lite_model = lite_model self.lite_model = lite_model
# Initialize players # Initialize players
@ -107,6 +148,8 @@ class Env:
'landlord_up': _deck[33:58], 'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83], 'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108], 'landlord_down': _deck[83:108],
'three_landlord_cards': _deck[25:33],
'three_landlord_cards_all': _deck[25:33],
} }
for key in card_play_data: for key in card_play_data:
card_play_data[key].sort() card_play_data[key].sort()
@ -124,7 +167,7 @@ class Env:
self._env.info_sets[pos].player_id = pid self._env.info_sets[pos].player_id = pid
self.infoset = self._game_infoset self.infoset = self._game_infoset
return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model) return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model, self.use_unified)
def step(self, action): def step(self, action):
""" """
@ -250,7 +293,7 @@ class DummyAgent(object):
self.action = action self.action = action
def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False): def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False, use_unified=False):
""" """
This function obtains observations with imperfect information This function obtains observations with imperfect information
from the infoset. It has three branches since we encode from the infoset. It has three branches since we encode
@ -278,6 +321,8 @@ def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]: if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
raise ValueError('') raise ValueError('')
return _get_obs_general(infoset, infoset.player_position, lite_model) return _get_obs_general(infoset, infoset.player_position, lite_model)
elif use_unified:
return _get_obs_unified(infoset, infoset.player_position, lite_model)
else: else:
if infoset.player_position == 'landlord': if infoset.player_position == 'landlord':
return _get_obs_landlord(infoset, use_legacy, lite_model) return _get_obs_landlord(infoset, use_legacy, lite_model)
@ -312,6 +357,12 @@ def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
return one_hot return one_hot
def _cards2noise(list_cards, compressed_form = False):
if compressed_form:
return np.random.randint(0, 2, 69, dtype=np.int8)
else:
return np.random.randint(0, 2, 108, dtype=np.int8)
def _cards2array(list_cards, compressed_form = False): def _cards2array(list_cards, compressed_form = False):
""" """
A utility function that transforms the actions, i.e., A utility function that transforms the actions, i.e.,
@ -381,7 +432,7 @@ def _cards2array(list_cards, compressed_form = False):
# # action_seq_array = action_seq_array.reshape(5, 162) # # action_seq_array = action_seq_array.reshape(5, 162)
# return action_seq_array # return action_seq_array
def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False): def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False, use_unified = False):
""" """
A utility function to encode the historical moves. A utility function to encode the historical moves.
We encode the historical 20 actions. If there is We encode the historical 20 actions. If there is
@ -404,6 +455,19 @@ def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = Fa
for row, list_cards in enumerate(action_seq_list): for row, list_cards in enumerate(action_seq_list):
if list_cards != []: if list_cards != []:
action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form) action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
elif use_unified:
if compressed_form:
action_seq_array = np.zeros((len(action_seq_list), 69))
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
action_seq_array = action_seq_array.reshape(24, 276)
else:
action_seq_array = np.zeros((len(action_seq_list), 108))
for row, list_cards in enumerate(action_seq_list):
if list_cards != []:
action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
action_seq_array = action_seq_array.reshape(24, 432)
else: else:
if compressed_form: if compressed_form:
action_seq_array = np.zeros((len(action_seq_list), 69)) action_seq_array = np.zeros((len(action_seq_list), 69))
@ -442,7 +506,7 @@ def _process_action_seq(sequence, length=20, new_model=True):
return sequence return sequence
def _get_one_hot_bomb(bomb_num, use_legacy = False): def _get_one_hot_bomb(bomb_num, use_legacy = False, compressed_form = False):
""" """
A utility function to encode the number of bombs A utility function to encode the number of bombs
into one-hot representation. into one-hot representation.
@ -451,7 +515,7 @@ def _get_one_hot_bomb(bomb_num, use_legacy = False):
one_hot = np.zeros(29) one_hot = np.zeros(29)
one_hot[bomb_num[0] + bomb_num[1]] = 1 one_hot[bomb_num[0] + bomb_num[1]] = 1
else: else:
one_hot = np.zeros(56) # 14 + 15 + 27 one_hot = np.zeros(56 if compressed_form else 95) # 14 + 15 + 27
one_hot[bomb_num[0]] = 1 one_hot[bomb_num[0]] = 1
one_hot[14 + bomb_num[1]] = 1 one_hot[14 + bomb_num[1]] = 1
one_hot[29 + bomb_num[2]] = 1 one_hot[29 + bomb_num[2]] = 1
@ -1000,3 +1064,134 @@ def _get_obs_general(infoset, position, compressed_form = False):
'z': z.astype(np.int8), 'z': z.astype(np.int8),
} }
return obs return obs
'''
face_up_level 0x01: three_landlord_cards, 0x02: landlord, 0x04: landlord_up, 0x08: landlord_front, 0x10: landlord_down
'''
def _get_obs_unified(infoset, position, compressed_form = True, face_up_level = 0):
num_legal_actions = len(infoset.legal_actions)
my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
num_legal_actions, axis=0)
other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
my_action_batch = np.zeros(my_handcards_batch.shape)
for j, action in enumerate(infoset.legal_actions):
my_action_batch[j, :] = _cards2array(action, compressed_form)
landlord_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord'], 33, 15 if compressed_form else 0)
landlord_up_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_up'], 25, 8 if compressed_form else 0)
landlord_front_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_front'], 25, 8 if compressed_form else 0)
landlord_down_num_cards_left = _get_one_hot_array(
infoset.num_cards_left_dict['landlord_down'], 25, 8 if compressed_form else 0)
landlord_played_cards = _cards2array(
infoset.played_cards['landlord'], compressed_form)
landlord_up_played_cards = _cards2array(
infoset.played_cards['landlord_up'], compressed_form)
landlord_front_played_cards = _cards2array(
infoset.played_cards['landlord_front'], compressed_form)
landlord_down_played_cards = _cards2array(
infoset.played_cards['landlord_down'], compressed_form)
if (face_up_level & 0x01) > 0:
three_landlord_cards = _cards2array(
infoset.three_landlord_cards, compressed_form)
three_landlord_cards_all = _cards2array(
infoset.three_landlord_cards_all, compressed_form)
else:
three_landlord_cards = _cards2noise(
infoset.three_landlord_cards, compressed_form)
three_landlord_cards_all = _cards2noise(
infoset.three_landlord_cards_all, compressed_form)
if (face_up_level & 0x02) > 0:
landlord_cards = _cards2array(
infoset.all_handcards['landlord'], compressed_form)
else:
landlord_cards = _cards2noise(
infoset.all_handcards['landlord'], compressed_form)
if (face_up_level & 0x04) > 0:
landlord_up_cards = _cards2array(
infoset.all_handcards['landlord_up'], compressed_form)
else:
landlord_up_cards = _cards2noise(
infoset.all_handcards['landlord_up'], compressed_form)
if (face_up_level & 0x08) > 0:
landlord_front_cards = _cards2array(
infoset.all_handcards['landlord_front'], compressed_form)
else:
landlord_front_cards = _cards2noise(
infoset.all_handcards['landlord_front'], compressed_form)
if (face_up_level & 0x10) > 0:
landlord_down_cards = _cards2array(
infoset.all_handcards['landlord_down'], compressed_form)
else:
landlord_down_cards = _cards2noise(
infoset.all_handcards['landlord_down'], compressed_form)
bomb_num = _get_one_hot_bomb(
infoset.bomb_num, compressed_form=compressed_form) # 56/95
base_info = np.hstack((
PositionInfoArray[position], # 4
FaceUpLevelArray[face_up_level], # 9
bomb_num, #56
))
num_cards_left = np.hstack((
landlord_num_cards_left, # 33/18
landlord_up_num_cards_left, # 25/17
landlord_front_num_cards_left, # 25/17
landlord_down_num_cards_left)) # 25/17
x_no_action = _action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 96), False, compressed_form, True) # 24*276 / 24*432
x_batch = np.repeat(
x_no_action[np.newaxis, :],
num_legal_actions, axis=0)
z =np.vstack((
base_info, # 69
num_cards_left, # 108 / 18+17*3=69
my_handcards, # 108/69
other_handcards, # 108/69
landlord_played_cards, # 108/69
landlord_up_played_cards, # 108/69
landlord_front_played_cards, # 108/69
landlord_down_played_cards, # 108/69
landlord_cards, # 108/69
landlord_up_cards, # 108/69
landlord_front_cards, # 108/69
landlord_down_cards, # 108/69
three_landlord_cards, # 108/69
three_landlord_cards_all, # 108/69
))
_z_batch = np.repeat(
z[np.newaxis, :, :],
num_legal_actions, axis=0)
my_action_batch = my_action_batch[:,np.newaxis,:]
z_batch = np.concatenate((my_action_batch, _z_batch), axis=1)
obs = {
'position': position,
'x_batch': x_batch.astype(np.float32),
'z_batch': z_batch.astype(np.float32),
'legal_actions': infoset.legal_actions,
'x_no_action': x_no_action.astype(np.int8),
'z': z.astype(np.int8),
}
return obs

36
douzero/env/game.py vendored
View File

@ -127,7 +127,8 @@ class GameEnv(object):
self.card_play_action_seq = [] self.card_play_action_seq = []
# self.three_landlord_cards = None self.three_landlord_cards = None
self.three_landlord_cards_all = None
self.game_over = False self.game_over = False
self.acting_player_position = None self.acting_player_position = None
@ -185,7 +186,8 @@ class GameEnv(object):
card_play_data['landlord_front'] card_play_data['landlord_front']
self.info_sets['landlord_down'].player_hand_cards = \ self.info_sets['landlord_down'].player_hand_cards = \
card_play_data['landlord_down'] card_play_data['landlord_down']
# self.three_landlord_cards = card_play_data['three_landlord_cards'] self.three_landlord_cards = card_play_data['three_landlord_cards']
self.three_landlord_cards_all = card_play_data['three_landlord_cards']
self.get_acting_player_position() self.get_acting_player_position()
self.game_infoset = self.get_infoset() self.game_infoset = self.get_infoset()
@ -253,15 +255,15 @@ class GameEnv(object):
self.played_cards[self.acting_player_position] += action self.played_cards[self.acting_player_position] += action
# if self.acting_player_position == 'landlord' and \ if self.acting_player_position == 'landlord' and \
# len(action) > 0 and \ len(action) > 0 and \
# len(self.three_landlord_cards) > 0: len(self.three_landlord_cards) > 0:
# for card in action: for card in action:
# if len(self.three_landlord_cards) > 0: if len(self.three_landlord_cards) > 0:
# if card in self.three_landlord_cards: if card in self.three_landlord_cards:
# self.three_landlord_cards.remove(card) self.three_landlord_cards.remove(card)
# else: else:
# break break
self.game_done() self.game_done()
if not self.game_over: if not self.game_over:
@ -333,7 +335,8 @@ class GameEnv(object):
def reset(self): def reset(self):
self.card_play_action_seq = [] self.card_play_action_seq = []
# self.three_landlord_cards = None self.three_landlord_cards = None
self.three_landlord_cards_all = None
self.game_over = False self.game_over = False
self.acting_player_position = None self.acting_player_position = None
@ -397,8 +400,10 @@ class GameEnv(object):
self.info_sets[self.acting_player_position].played_cards = \ self.info_sets[self.acting_player_position].played_cards = \
self.played_cards self.played_cards
# self.info_sets[self.acting_player_position].three_landlord_cards = \ self.info_sets[self.acting_player_position].three_landlord_cards = \
# self.three_landlord_cards self.three_landlord_cards
self.info_sets[self.acting_player_position].three_landlord_cards_all = \
self.three_landlord_cards_all
self.info_sets[self.acting_player_position].card_play_action_seq = \ self.info_sets[self.acting_player_position].card_play_action_seq = \
self.card_play_action_seq self.card_play_action_seq
@ -424,7 +429,8 @@ class InfoSet(object):
# The number of cards left for each player. It is a dict with str-->int # The number of cards left for each player. It is a dict with str-->int
self.num_cards_left_dict = None self.num_cards_left_dict = None
# The three landload cards. A list. # The three landload cards. A list.
# self.three_landlord_cards = None self.three_landlord_cards = None
self.three_landlord_cards_all = None
# The historical moves. It is a list of list # The historical moves. It is a list of list
self.card_play_action_seq = None self.card_play_action_seq = None
# The union of the hand cards of the other two players for the current player # The union of the hand cards of the other two players for the current player

View File

@ -21,7 +21,8 @@ def generate():
'landlord_up': _deck[33:58], 'landlord_up': _deck[33:58],
'landlord_front': _deck[58:83], 'landlord_front': _deck[58:83],
'landlord_down': _deck[83:108], 'landlord_down': _deck[83:108],
# 'three_landlord_cards': _deck[25:33], 'three_landlord_cards': _deck[25:33],
'three_landlord_cards_all': _deck[25:33],
} }
for key in card_play_data: for key in card_play_data:
card_play_data[key].sort() card_play_data[key].sort()