Merge branch 'unified' of ssh://git@git.zaneyork.cn:2222/douzero/Douzero_Resnet.git into main

commit f0883bbf31

@@ -31,6 +31,8 @@ parser.add_argument('--load_model', action='store_true',
                     help='Load an existing model')
 parser.add_argument('--old_model', action='store_true',
                     help='Use vanilla model')
+parser.add_argument('--unified_model', action='store_true',
+                    help='Use unified model')
 parser.add_argument('--lite_model', action='store_true',
                     help='Use lite card model')
 parser.add_argument('--lagacy_model', action='store_true',
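The commit adds a third model-selection switch alongside --old_model and --lite_model. A minimal sketch of how the flags combine (a stand-alone parser for illustration only; the real definitions live in the argument module patched above):

import argparse

# Stand-alone illustration of the model-selection switches shown in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument('--old_model', action='store_true', help='Use vanilla model')
parser.add_argument('--unified_model', action='store_true', help='Use unified model')
parser.add_argument('--lite_model', action='store_true', help='Use lite card model')

flags = parser.parse_args(['--unified_model', '--lite_model'])
print(flags.unified_model, flags.old_model, flags.lite_model)  # True False True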
@@ -13,8 +13,8 @@ from torch import nn
 import douzero.dmc.models
 import douzero.env.env
 from .file_writer import FileWriter
-from .models import Model, OldModel
-from .utils import get_batch, log, create_env, create_optimizers, act, infer_logic
+from .models import Model, OldModel, UnifiedModel
+from .utils import get_batch, log, create_optimizers, act, infer_logic
 import psutil
 import shutil
 import requests

@@ -87,6 +87,8 @@ def train(flags):
     # Initialize actor models
     if flags.old_model:
         actor_model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
+    elif flags.unified_model:
+        actor_model = UnifiedModel(device="cpu", flags = flags, lite_model = flags.lite_model)
     else:
         actor_model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
     actor_model.eval()
@@ -100,6 +102,8 @@ def train(flags):
     # Learner model for training
     if flags.old_model:
         learner_model = OldModel(device=flags.training_device, lite_model = flags.lite_model)
+    elif flags.unified_model:
+        learner_model = UnifiedModel(device=flags.training_device, lite_model = flags.lite_model)
     else:
         learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)

@@ -116,13 +120,18 @@ def train(flags):
     ]
     frames, stats = 0, {k: 0 for k in stat_keys}
     position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
-    position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
+    if flags.unified_model:
+        lock = threading.Lock()
+        position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock, 'uni': lock}
+    else:
+        position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

     def sync_onnx_model(frames):
         p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
         if not os.path.exists(p_path):
             os.makedirs(p_path)
-        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
+        positions = ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
+        for position in positions:
             if flags.enable_onnx:
                 model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
                 onnx_params = learner_model.get_model(position)\
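With --unified_model all four seats train through one shared network, so a single lock object is reused for every key and the per-position learner threads serialize their updates. A small sketch of the construction used above:

import threading

lock = threading.Lock()
position_locks = {'landlord': lock, 'landlord_up': lock,
                  'landlord_front': lock, 'landlord_down': lock, 'uni': lock}

# Every entry is literally the same lock object.
assert all(l is lock for l in position_locks.values())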
@@ -149,11 +158,17 @@ def train(flags):
         checkpoint_states = torch.load(
             checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
         )
-        for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
-            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
-            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
-            if not flags.enable_onnx:
-                actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+        if flags.unified_model:
+            learner_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
+            optimizers['uni'].load_state_dict(checkpoint_states["optimizer_state_dict"]['uni'])
+            if not flags.enable_onnx:
+                actor_model.get_model('uni').load_state_dict(checkpoint_states["model_state_dict"]['uni'])
+        else:
+            for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: # ['landlord', 'landlord_up', 'landlord_down']
+                learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+                optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
+                if not flags.enable_onnx:
+                    actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
         stats = checkpoint_states["stats"]

         frames = checkpoint_states["frames"]
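A sketch of the checkpoint layout the unified branch expects when resuming (keys taken from the loads above; the nested state dicts are elided):

# Hypothetical, minimal view of checkpoint_states for a unified run.
checkpoint_states = {
    "model_state_dict": {"uni": {}},       # state_dict of the single shared network
    "optimizer_state_dict": {"uni": {}},   # state of the single optimizer
    "stats": {},
    "frames": 0,
}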
@@ -206,7 +221,7 @@ def train(flags):
         while frames < flags.total_frames:
             batch = get_batch(batch_queues, position, flags, local_lock)
             _stats = learn(position, actor_model, learner_model.get_model(position), batch,
-                           optimizers[position], flags, position_lock)
+                           optimizers['uni'], flags, position_lock)
             with lock:
                 for k in _stats:
                     stats[k] = _stats[k]
@@ -244,7 +259,7 @@ def train(flags):
         }, checkpointpath + '.new')

         # Save the weights for evaluation purpose
-        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
+        for position in ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
@@ -255,13 +270,18 @@ def train(flags):
            type = ''
            if flags.old_model:
                type += 'vanilla'
+           elif flags.unified_model:
+               type += 'unified'
            else:
                type += 'resnet'
-           requests.post(flags.upload_url, data={
-               'type': type,
-               'position': position,
-               'frame': frames
-           }, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))})
+           try:
+               requests.post(flags.upload_url, data={
+                   'type': type,
+                   'position': position,
+                   'frame': frames
+               }, files = {'model_file':('model.ckpt', open(model_weights_dir, 'rb'))})
+           except:
+               print("Model upload failed")
            os.remove(model_weights_dir)
        shutil.move(checkpointpath + '.new', checkpointpath)

@@ -400,6 +400,74 @@ class GeneralModelLite(nn.Module):
         return dict(values=out)


+
+class UnifiedModelLite(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.in_planes = 30
+        #input 1*69*15
+        self.conv1 = nn.Conv1d(15, 30, kernel_size=(3,),
+                               stride=(2,), padding=1, bias=False) #1*35*30
+
+        self.bn1 = nn.BatchNorm1d(30)
+
+        self.layer1 = self._make_layer(BasicBlock, 30, 2, stride=2)#1*18*30
+        self.layer2 = self._make_layer(BasicBlock, 60, 2, stride=2)#1*9*60
+        self.layer3 = self._make_layer(BasicBlock, 120, 2, stride=2)#1*5*120
+        # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.lstm = nn.LSTM(276, 128, batch_first=True)
+
+        self.linear1 = nn.Linear(120 * BasicBlock.expansion * 5 + 128, 2048)
+        self.linear2 = nn.Linear(2048, 1024)
+        self.linear3 = nn.Linear(1024, 512)
+        self.linear4 = nn.Linear(512, 256)
+        self.linear5 = nn.Linear(256, 1)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def get_onnx_params(self, device=None):
+        return {
+            'args': (
+                torch.randn(1, 15, 69, requires_grad=True, device=device),
+                torch.randn(1, 24, 276, requires_grad=True, device=device),
+            ),
+            'input_names': ['z_batch','x_batch'],
+            'output_names': ['values'],
+            'dynamic_axes': {
+                'z_batch': {
+                    0: "batch_size"
+                },
+                'x_batch': {
+                    0: "batch_size"
+                },
+                'values': {
+                    0: "batch_size"
+                }
+            }
+        }
+
+    def forward(self, z, x):
+        out = F.relu(self.bn1(self.conv1(z)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = out.flatten(1,2)
+        lstm_out, (h_n, _) = self.lstm(x)
+        lstm_out = lstm_out[:,-1,:]
+        out = torch.cat([lstm_out,out], dim=1)
+        out = F.leaky_relu_(self.linear1(out))
+        out = F.leaky_relu_(self.linear2(out))
+        out = F.leaky_relu_(self.linear3(out))
+        out = F.leaky_relu_(self.linear4(out))
+        out = F.leaky_relu_(self.linear5(out))
+        return dict(values=out)
+
 class GeneralModel(nn.Module):
     def __init__(self):
         super().__init__()
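A quick shape check for the new network (a sketch; it assumes UnifiedModelLite is importable from douzero.dmc.models after this commit, and that BasicBlock.expansion == 1 as in the standard ResNet basic block). The sizes follow the comments above: length 69 shrinks to 35, 18, 9 and 5 through the strided layers, and the 120*5 = 600 ResNet features are concatenated with the 128-dim last LSTM output to give the 728-wide input of linear1.

import torch
from douzero.dmc.models import UnifiedModelLite

net = UnifiedModelLite()
z = torch.randn(4, 15, 69)    # 15 encoding rows of width 69 per sample
x = torch.randn(4, 24, 276)   # 96 historical moves packed as 24 x 276
out = net.forward(z, x)
print(out['values'].shape)    # torch.Size([4, 1])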
@@ -492,6 +560,11 @@ model_dict_new_lite['landlord'] = GeneralModelLite
 model_dict_new_lite['landlord_up'] = GeneralModelLite
 model_dict_new_lite['landlord_front'] = GeneralModelLite
 model_dict_new_lite['landlord_down'] = GeneralModelLite
+model_dict_uni_lite = {}
+model_dict_uni_lite['landlord'] = UnifiedModelLite
+model_dict_uni_lite['landlord_up'] = UnifiedModelLite
+model_dict_uni_lite['landlord_front'] = UnifiedModelLite
+model_dict_uni_lite['landlord_down'] = UnifiedModelLite

 def forward_logic(self_model, position, z, x, device='cpu', return_value=False, flags=None):
     legal_count = len(z)
@@ -675,3 +748,64 @@ class Model:
     def get_models(self):
         return self.models

+
+class UnifiedModel:
+    """
+    The wrapper for the three models. We also wrap several
+    interfaces such as share_memory, eval, etc.
+    """
+    def __init__(self, device=0, flags=None, lite_model = False):
+        self.onnx_models = {}
+        self.model = None
+        self.models = {}
+        self.flags = flags
+        if not device == "cpu":
+            device = 'cuda:' + str(device)
+        self.device = torch.device(device)
+        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
+        if flags is not None and flags.enable_onnx:
+            self.model = None
+        else:
+            if lite_model:
+                self.model = UnifiedModelLite().to(self.device)
+                for position in positions:
+                    self.models[position] = self.model
+            else:
+                self.model = GeneralModel().to(self.device)
+        self.onnx_model = None
+
+    def set_onnx_model(self, device='cpu'):
+        model_path = os.path.abspath('%s/%s/model_uni.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
+        positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
+        if device == 'cpu':
+            self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
+        else:
+            self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
+        for position in positions:
+            self.onnx_models[position] = self.onnx_model
+
+    def get_onnx_params(self, position):
+        self.model.get_onnx_params(self.device)
+
+    def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
+        return forward_logic(self, position, z, x, device, return_value, flags)
+
+    def share_memory(self):
+        if self.model is not None:
+            self.model.share_memory()
+
+    def eval(self):
+        if self.model is not None:
+            self.model.eval()
+
+    def parameters(self, position):
+        return self.model.parameters()
+
+    def get_model(self, position):
+        return self.model
+
+    def get_models(self):
+        return {
+            'uni' : self.model
+        }
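Unlike the existing Model/OldModel wrappers, UnifiedModel keeps a single underlying module and hands it out for every position. A usage sketch (SimpleNamespace stands in for the real flags object; only enable_onnx is needed here):

from types import SimpleNamespace
from douzero.dmc.models import UnifiedModel

flags = SimpleNamespace(enable_onnx=False)
wrapper = UnifiedModel(device="cpu", flags=flags, lite_model=True)

assert wrapper.get_model('landlord') is wrapper.get_model('uni')   # one shared module
assert list(wrapper.get_models().keys()) == ['uni']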
@@ -73,7 +73,7 @@ log.setLevel(logging.INFO)
 Buffers = typing.Dict[str, typing.List[torch.Tensor]]

 def create_env(flags):
-    return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model)
+    return Env(flags.objective, flags.old_model, flags.lagacy_model, flags.lite_model, flags.unified_model)

 def get_batch(b_queues, position, flags, lock):
     """

@@ -104,20 +104,31 @@ def create_optimizers(flags, learner_model):
     """
     positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
     optimizers = {}
-    for position in positions:
-        optimizer = RAdam(
-            learner_model.parameters(position),
-            lr=flags.learning_rate,
-            eps=flags.epsilon)
-        optimizers[position] = optimizer
+    if flags.unified_model:
+        position = 'uni'
+        optimizer = RAdam(
+            learner_model.parameters(position),
+            lr=flags.learning_rate,
+            eps=flags.epsilon)
+        optimizers[position] = optimizer
+    else:
+        for position in positions:
+            optimizer = RAdam(
+                learner_model.parameters(position),
+                lr=flags.learning_rate,
+                eps=flags.epsilon)
+            optimizers[position] = optimizer
     return optimizers


 def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
     positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
     if not flags.enable_onnx:
-        for pos in positions:
-            model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
+        if flags.unified_model:
+            model.model.to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
+        else:
+            for pos in positions:
+                model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
     last_onnx_frame = -1
     log.info('Infer %i started.', i)
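Because UnifiedModel.parameters() ignores the position argument, the single 'uni' optimizer built in create_optimizers covers the whole shared network. A sketch of that equivalence (same SimpleNamespace stand-in as above):

from types import SimpleNamespace
from douzero.dmc.models import UnifiedModel

model = UnifiedModel(device="cpu", flags=SimpleNamespace(enable_onnx=False), lite_model=True)
p_a = list(model.parameters('landlord'))
p_b = list(model.parameters('uni'))
assert len(p_a) == len(p_b) and all(a is b for a, b in zip(p_a, p_b))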
@@ -49,7 +49,47 @@ NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
                                 12: np.array([0, 0, 1, 1, 0]),
                                 13: np.array([1, 0, 1, 1, 0]),
                                 15: np.array([1, 1, 1, 1, 0])}
+PositionInfoArray = {
+    'landlord': np.array([1, 0, 0, 0]),
+    'landlord_down': np.array([0, 1, 0, 0]),
+    'landlord_front': np.array([0, 0, 1, 0]),
+    'landlord_up': np.array([0, 0, 0, 1]),
+}
+
+FaceUpLevelArray = {
+    0x00: np.array([0, 0, 0, 0, 0, 0, 0, 0, 0]),
+    0x01: np.array([1, 0, 0, 0, 0, 0, 0, 0, 0]),
+    0x02: np.array([0, 1, 0, 0, 0, 0, 0, 0, 0]),
+    0x03: np.array([1, 1, 0, 0, 0, 0, 0, 0, 0]),
+    0x04: np.array([0, 0, 1, 0, 0, 0, 0, 0, 0]),
+    0x05: np.array([1, 0, 1, 0, 0, 0, 0, 0, 0]),
+    0x06: np.array([0, 1, 1, 0, 0, 0, 0, 0, 0]),
+    0x07: np.array([1, 1, 1, 0, 0, 0, 0, 0, 0]),
+    0x08: np.array([0, 0, 0, 1, 0, 0, 0, 0, 0]),
+    0x09: np.array([1, 0, 0, 1, 0, 0, 0, 0, 0]),
+    0x0A: np.array([0, 1, 0, 1, 0, 0, 0, 0, 0]),
+    0x0B: np.array([1, 1, 0, 1, 0, 0, 0, 0, 0]),
+    0x0C: np.array([0, 0, 1, 1, 0, 0, 0, 0, 0]),
+    0x0D: np.array([1, 0, 1, 1, 0, 0, 0, 0, 0]),
+    0x0E: np.array([0, 1, 1, 1, 0, 0, 0, 0, 0]),
+    0x0F: np.array([1, 1, 1, 1, 0, 0, 0, 0, 0]),
+    0x10: np.array([0, 0, 0, 0, 1, 0, 0, 0, 0]),
+    0x11: np.array([1, 0, 0, 0, 1, 0, 0, 0, 0]),
+    0x12: np.array([0, 1, 0, 0, 1, 0, 0, 0, 0]),
+    0x13: np.array([1, 1, 0, 0, 1, 0, 0, 0, 0]),
+    0x14: np.array([0, 0, 1, 0, 1, 0, 0, 0, 0]),
+    0x15: np.array([1, 0, 1, 0, 1, 0, 0, 0, 0]),
+    0x16: np.array([0, 1, 1, 0, 1, 0, 0, 0, 0]),
+    0x17: np.array([1, 1, 1, 0, 1, 0, 0, 0, 0]),
+    0x18: np.array([0, 0, 0, 1, 1, 0, 0, 0, 0]),
+    0x19: np.array([1, 0, 0, 1, 1, 0, 0, 0, 0]),
+    0x1A: np.array([0, 1, 0, 1, 1, 0, 0, 0, 0]),
+    0x1B: np.array([1, 1, 0, 1, 1, 0, 0, 0, 0]),
+    0x1C: np.array([0, 0, 1, 1, 1, 0, 0, 0, 0]),
+    0x1D: np.array([1, 0, 1, 1, 1, 0, 0, 0, 0]),
+    0x1E: np.array([0, 1, 1, 1, 1, 0, 0, 0, 0]),
+    0x1F: np.array([1, 1, 1, 1, 1, 0, 0, 0, 0]),
+}
+
 deck = []
 for i in range(3, 15):
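The two new lookup tables pad the per-game metadata to the same width as the compressed card encoding: 4 position bits + 9 face-up-level bits + 56 bomb bits = 69, so the base-info vector can sit as one row of the z image. A sketch of that arithmetic:

import numpy as np

position_info = np.zeros(4)    # e.g. PositionInfoArray['landlord']
face_up_level = np.zeros(9)    # e.g. FaceUpLevelArray[0x00]
bomb_num = np.zeros(56)        # _get_one_hot_bomb(..., compressed_form=True)

base_info = np.hstack((position_info, face_up_level, bomb_num))
print(base_info.shape)         # (69,)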
@@ -63,7 +103,7 @@ class Env:
     Doudizhu multi-agent wrapper
     """

-    def __init__(self, objective, old_model, legacy_model=False, lite_model = False):
+    def __init__(self, objective, old_model, legacy_model=False, lite_model = False, unified_model = False):
         """
         Objective is wp/adp/logadp. It indicates whether considers
         bomb in reward calculation. Here, we use dummy agents.

@@ -77,7 +117,8 @@ class Env:
         """
         self.objective = objective
         self.use_legacy = legacy_model
-        self.use_general = not old_model
+        self.use_unified = unified_model
+        self.use_general = not old_model and not unified_model
         self.lite_model = lite_model

         # Initialize players

@@ -107,6 +148,8 @@ class Env:
             'landlord_up': _deck[33:58],
             'landlord_front': _deck[58:83],
             'landlord_down': _deck[83:108],
+            'three_landlord_cards': _deck[25:33],
+            'three_landlord_cards_all': _deck[25:33],
         }
         for key in card_play_data:
             card_play_data[key].sort()

@@ -124,7 +167,7 @@ class Env:
             self._env.info_sets[pos].player_id = pid
         self.infoset = self._game_infoset

-        return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
+        return get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model, self.use_unified)

     def step(self, action):
         """

@@ -152,7 +195,7 @@ class Env:
             }
             obs = None
         else:
-            obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model)
+            obs = get_obs(self.infoset, self.use_general, self.use_legacy, self.lite_model, self.use_unified)
         return obs, reward, done, {}

     def _get_reward(self, pos):
@@ -250,7 +293,7 @@ class DummyAgent(object):
         self.action = action


-def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
+def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False, use_unified=False):
     """
     This function obtains observations with imperfect information
     from the infoset. It has three branches since we encode

@@ -278,6 +321,8 @@ def get_obs(infoset, use_general=True, use_legacy = False, lite_model = False):
         if infoset.player_position not in ["landlord", "landlord_up", "landlord_front", "landlord_down"]:
             raise ValueError('')
         return _get_obs_general(infoset, infoset.player_position, lite_model)
+    elif use_unified:
+        return _get_obs_unified(infoset, infoset.player_position, lite_model)
     else:
         if infoset.player_position == 'landlord':
             return _get_obs_landlord(infoset, use_legacy, lite_model)
@@ -312,6 +357,12 @@ def _get_one_hot_array(num_left_cards, max_num_cards, compress_size = 0):
     return one_hot


+def _cards2noise(list_cards, compressed_form = False):
+    if compressed_form:
+        return np.random.randint(0, 2, 69, dtype=np.int8)
+    else:
+        return np.random.randint(0, 2, 108, dtype=np.int8)
+
 def _cards2array(list_cards, compressed_form = False):
     """
     A utility function that transforms the actions, i.e.,

@@ -381,7 +432,7 @@ def _cards2array(list_cards, compressed_form = False):
 # # action_seq_array = action_seq_array.reshape(5, 162)
 # return action_seq_array

-def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False):
+def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False, use_unified = False):
     """
     A utility function to encode the historical moves.
     We encode the historical 20 actions. If there is

@@ -404,6 +455,19 @@ def _action_seq_list2array(action_seq_list, new_model=True, compressed_form = False):
         for row, list_cards in enumerate(action_seq_list):
             if list_cards != []:
                 action_seq_array[row, :108] = _cards2array(list_cards[1], compressed_form)
+    elif use_unified:
+        if compressed_form:
+            action_seq_array = np.zeros((len(action_seq_list), 69))
+            for row, list_cards in enumerate(action_seq_list):
+                if list_cards != []:
+                    action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
+            action_seq_array = action_seq_array.reshape(24, 276)
+        else:
+            action_seq_array = np.zeros((len(action_seq_list), 108))
+            for row, list_cards in enumerate(action_seq_list):
+                if list_cards != []:
+                    action_seq_array[row, :] = _cards2array(list_cards[1], compressed_form)
+            action_seq_array = action_seq_array.reshape(24, 432)
     else:
         if compressed_form:
             action_seq_array = np.zeros((len(action_seq_list), 69))
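In the unified branch the history window is 96 moves rather than 20, and each compressed move is 69 wide, so the flat array folds exactly into the 24 x 276 sequence the LSTM consumes (four consecutive moves, one full round of the four players, per step). A sketch of the packing:

import numpy as np

history = np.zeros((96, 69))         # 96 moves, 69-dim compressed encoding each
packed = history.reshape(24, 276)    # 4 consecutive moves per LSTM step
print(packed.shape)                  # (24, 276) -- matches nn.LSTM(276, 128)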
@@ -442,7 +506,7 @@ def _process_action_seq(sequence, length=20, new_model=True):
     return sequence


-def _get_one_hot_bomb(bomb_num, use_legacy = False):
+def _get_one_hot_bomb(bomb_num, use_legacy = False, compressed_form = False):
     """
     A utility function to encode the number of bombs
     into one-hot representation.

@@ -451,7 +515,7 @@ def _get_one_hot_bomb(bomb_num, use_legacy = False):
         one_hot = np.zeros(29)
         one_hot[bomb_num[0] + bomb_num[1]] = 1
     else:
-        one_hot = np.zeros(56) # 14 + 15 + 27
+        one_hot = np.zeros(56 if compressed_form else 95) # 14 + 15 + 27
         one_hot[bomb_num[0]] = 1
         one_hot[14 + bomb_num[1]] = 1
         one_hot[29 + bomb_num[2]] = 1
@@ -1000,3 +1064,134 @@ def _get_obs_general(infoset, position, compressed_form = False):
         'z': z.astype(np.int8),
     }
     return obs
+
+'''
+face_up_level 0x01: three_landlord_cards, 0x02: landlord, 0x04: landlord_up, 0x08: landlord_front, 0x10: landlord_down
+'''
+def _get_obs_unified(infoset, position, compressed_form = True, face_up_level = 0):
+    num_legal_actions = len(infoset.legal_actions)
+    my_handcards = _cards2array(infoset.player_hand_cards, compressed_form)
+    my_handcards_batch = np.repeat(my_handcards[np.newaxis, :],
+                                   num_legal_actions, axis=0)
+
+    other_handcards = _cards2array(infoset.other_hand_cards, compressed_form)
+
+    my_action_batch = np.zeros(my_handcards_batch.shape)
+    for j, action in enumerate(infoset.legal_actions):
+        my_action_batch[j, :] = _cards2array(action, compressed_form)
+
+    landlord_num_cards_left = _get_one_hot_array(
+        infoset.num_cards_left_dict['landlord'], 33, 15 if compressed_form else 0)
+
+    landlord_up_num_cards_left = _get_one_hot_array(
+        infoset.num_cards_left_dict['landlord_up'], 25, 8 if compressed_form else 0)
+
+    landlord_front_num_cards_left = _get_one_hot_array(
+        infoset.num_cards_left_dict['landlord_front'], 25, 8 if compressed_form else 0)
+
+    landlord_down_num_cards_left = _get_one_hot_array(
+        infoset.num_cards_left_dict['landlord_down'], 25, 8 if compressed_form else 0)
+
+    landlord_played_cards = _cards2array(
+        infoset.played_cards['landlord'], compressed_form)
+
+    landlord_up_played_cards = _cards2array(
+        infoset.played_cards['landlord_up'], compressed_form)
+
+    landlord_front_played_cards = _cards2array(
+        infoset.played_cards['landlord_front'], compressed_form)
+
+    landlord_down_played_cards = _cards2array(
+        infoset.played_cards['landlord_down'], compressed_form)
+
+    if (face_up_level & 0x01) > 0:
+        three_landlord_cards = _cards2array(
+            infoset.three_landlord_cards, compressed_form)
+
+        three_landlord_cards_all = _cards2array(
+            infoset.three_landlord_cards_all, compressed_form)
+    else:
+        three_landlord_cards = _cards2noise(
+            infoset.three_landlord_cards, compressed_form)
+
+        three_landlord_cards_all = _cards2noise(
+            infoset.three_landlord_cards_all, compressed_form)
+
+    if (face_up_level & 0x02) > 0:
+        landlord_cards = _cards2array(
+            infoset.all_handcards['landlord'], compressed_form)
+    else:
+        landlord_cards = _cards2noise(
+            infoset.all_handcards['landlord'], compressed_form)
+
+    if (face_up_level & 0x04) > 0:
+        landlord_up_cards = _cards2array(
+            infoset.all_handcards['landlord_up'], compressed_form)
+    else:
+        landlord_up_cards = _cards2noise(
+            infoset.all_handcards['landlord_up'], compressed_form)
+
+    if (face_up_level & 0x08) > 0:
+        landlord_front_cards = _cards2array(
+            infoset.all_handcards['landlord_front'], compressed_form)
+    else:
+        landlord_front_cards = _cards2noise(
+            infoset.all_handcards['landlord_front'], compressed_form)
+
+    if (face_up_level & 0x10) > 0:
+        landlord_down_cards = _cards2array(
+            infoset.all_handcards['landlord_down'], compressed_form)
+    else:
+        landlord_down_cards = _cards2noise(
+            infoset.all_handcards['landlord_down'], compressed_form)
+
+    bomb_num = _get_one_hot_bomb(
+        infoset.bomb_num, compressed_form=compressed_form) # 56/95
+    base_info = np.hstack((
+        PositionInfoArray[position], # 4
+        FaceUpLevelArray[face_up_level], # 9
+        bomb_num, #56
+    ))
+    num_cards_left = np.hstack((
+        landlord_num_cards_left, # 33/18
+        landlord_up_num_cards_left, # 25/17
+        landlord_front_num_cards_left, # 25/17
+        landlord_down_num_cards_left)) # 25/17
+
+    x_no_action = _action_seq_list2array(_process_action_seq(infoset.card_play_action_seq, 96), False, compressed_form, True) # 24*276 / 24*432
+
+    x_batch = np.repeat(
+        x_no_action[np.newaxis, :],
+        num_legal_actions, axis=0)
+
+    z =np.vstack((
+        base_info, # 69
+        num_cards_left, # 108 / 18+17*3=69
+        my_handcards, # 108/69
+        other_handcards, # 108/69
+        landlord_played_cards, # 108/69
+        landlord_up_played_cards, # 108/69
+        landlord_front_played_cards, # 108/69
+        landlord_down_played_cards, # 108/69
+        landlord_cards, # 108/69
+        landlord_up_cards, # 108/69
+        landlord_front_cards, # 108/69
+        landlord_down_cards, # 108/69
+        three_landlord_cards, # 108/69
+        three_landlord_cards_all, # 108/69
+    ))
+
+    _z_batch = np.repeat(
+        z[np.newaxis, :, :],
+        num_legal_actions, axis=0)
+    my_action_batch = my_action_batch[:,np.newaxis,:]
+    z_batch = np.concatenate((my_action_batch, _z_batch), axis=1)
+    obs = {
+        'position': position,
+        'x_batch': x_batch.astype(np.float32),
+        'z_batch': z_batch.astype(np.float32),
+        'legal_actions': infoset.legal_actions,
+        'x_no_action': x_no_action.astype(np.int8),
+        'z': z.astype(np.int8),
+    }
+    return obs
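_get_obs_unified stacks fourteen 69-wide rows into z (base info, card counts, own and unseen hands, four played-cards planes, four hand planes, and the two three-landlord-cards planes), and one action row is prepended per legal move, which is where the 15-channel input of UnifiedModelLite comes from. A shape-only sketch:

import numpy as np

z = np.vstack([np.zeros(69) for _ in range(14)])          # the 14 rows listed above
num_legal_actions = 7                                      # hypothetical
_z_batch = np.repeat(z[np.newaxis, :, :], num_legal_actions, axis=0)
my_action_batch = np.zeros((num_legal_actions, 1, 69))     # one action row per legal move
z_batch = np.concatenate((my_action_batch, _z_batch), axis=1)
print(z_batch.shape)                                       # (7, 15, 69)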
@@ -127,7 +127,8 @@ class GameEnv(object):

         self.card_play_action_seq = []

-        # self.three_landlord_cards = None
+        self.three_landlord_cards = None
+        self.three_landlord_cards_all = None
         self.game_over = False

         self.acting_player_position = None

@@ -185,7 +186,12 @@ class GameEnv(object):
             card_play_data['landlord_front']
         self.info_sets['landlord_down'].player_hand_cards = \
             card_play_data['landlord_down']
-        # self.three_landlord_cards = card_play_data['three_landlord_cards']
+        if 'three_landlord_cards' not in card_play_data.keys():
+            self.three_landlord_cards = card_play_data['landlord'][25:33]
+            self.three_landlord_cards_all = card_play_data['landlord'][25:33]
+        else:
+            self.three_landlord_cards = card_play_data['three_landlord_cards'][:]
+            self.three_landlord_cards_all = card_play_data['three_landlord_cards'][:]
         self.get_acting_player_position()
         self.game_infoset = self.get_infoset()

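Old deals (and previously pickled eval data) may not carry the new keys, so the deal-initialization code above falls back to slots 25:33 of the landlord hand. A sketch of the fallback with a hypothetical deal dict:

card_play_data = {'landlord': list(range(33))}   # no 'three_landlord_cards' key
three = card_play_data.get('three_landlord_cards',
                           card_play_data['landlord'][25:33])
print(three)   # [25, 26, 27, 28, 29, 30, 31, 32]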
@@ -253,15 +259,15 @@ class GameEnv(object):

         self.played_cards[self.acting_player_position] += action

-        # if self.acting_player_position == 'landlord' and \
-        #         len(action) > 0 and \
-        #         len(self.three_landlord_cards) > 0:
-        #     for card in action:
-        #         if len(self.three_landlord_cards) > 0:
-        #             if card in self.three_landlord_cards:
-        #                 self.three_landlord_cards.remove(card)
-        #         else:
-        #             break
+        if self.acting_player_position == 'landlord' and \
+                len(action) > 0 and \
+                len(self.three_landlord_cards) > 0:
+            for card in action:
+                if len(self.three_landlord_cards) > 0:
+                    if card in self.three_landlord_cards:
+                        self.three_landlord_cards.remove(card)
+                else:
+                    break

         self.game_done()
         if not self.game_over:

@@ -333,7 +339,8 @@ class GameEnv(object):
     def reset(self):
         self.card_play_action_seq = []

-        # self.three_landlord_cards = None
+        self.three_landlord_cards = None
+        self.three_landlord_cards_all = None
         self.game_over = False

         self.acting_player_position = None

@@ -397,8 +404,10 @@ class GameEnv(object):

         self.info_sets[self.acting_player_position].played_cards = \
             self.played_cards
-        # self.info_sets[self.acting_player_position].three_landlord_cards = \
-        #     self.three_landlord_cards
+        self.info_sets[self.acting_player_position].three_landlord_cards = \
+            self.three_landlord_cards
+        self.info_sets[self.acting_player_position].three_landlord_cards_all = \
+            self.three_landlord_cards_all
         self.info_sets[self.acting_player_position].card_play_action_seq = \
             self.card_play_action_seq

@@ -424,7 +433,8 @@ class InfoSet(object):
         # The number of cards left for each player. It is a dict with str-->int
         self.num_cards_left_dict = None
         # The three landload cards. A list.
-        # self.three_landlord_cards = None
+        self.three_landlord_cards = None
+        self.three_landlord_cards_all = None
         # The historical moves. It is a list of list
         self.card_play_action_seq = None
         # The union of the hand cards of the other two players for the current player
@@ -6,8 +6,8 @@ from onnxruntime.datasets import get_example

 from douzero.env.env import get_obs

-def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
-    from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite
+def _load_model(position, model_path, model_type, use_legacy, use_lite, use_unified, use_onnx=False):
+    from douzero.dmc.models import model_dict_new, model_dict_new_lite, model_dict, model_dict_legacy, model_dict_lite, model_dict_uni_lite
     model_path_onnx = model_path + '.onnx'
     if use_onnx and os.path.exists(model_path_onnx):
         return None

@@ -20,6 +20,11 @@ def _load_model(position, model_path, model_type, use_legacy, use_lite, use_onnx=False):
     else:
         if use_legacy:
             model = model_dict_legacy[position]()
+        elif use_unified:
+            if use_lite:
+                model = model_dict_uni_lite[position]()
+            else:
+                model = model_dict[position]()
         else:
             if use_lite:
                 model = model_dict_lite[position]()

@@ -52,9 +57,10 @@ class DeepAgent:

     def __init__(self, position, model_path, use_onnx=False):
         self.use_legacy = True if "legacy" in model_path else False
+        self.use_unified = True if "uni" in model_path else False
         self.lite_model = True if "lite" in model_path else False
         self.model_type = "general" if "resnet" in model_path else "old"
-        self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, use_onnx=use_onnx)
+        self.model = _load_model(position, model_path, self.model_type, self.use_legacy, self.lite_model, self.use_unified, use_onnx=use_onnx)
         self.onnx_model_path = os.path.abspath(model_path + '.onnx')
         self.use_onnx = use_onnx
         self.onnx_model = None
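DeepAgent picks the network variant purely from substrings of the checkpoint path, so a unified lite checkpoint has to keep both markers in its filename, as the new evaluate.py default baselines/lite_uni_5001600.ckpt does. A sketch of the detection:

model_path = 'baselines/lite_uni_5001600.ckpt'

use_legacy  = "legacy" in model_path    # False
use_unified = "uni" in model_path       # True
lite_model  = "lite" in model_path      # True
model_type  = "general" if "resnet" in model_path else "old"
print(use_legacy, use_unified, lite_model, model_type)   # False True True old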
@@ -71,7 +77,7 @@ class DeepAgent:
         if not with_confidence and len(infoset.legal_actions) == 1:
             return infoset.legal_actions[0]

-        obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model)
+        obs = get_obs(infoset, self.model_type == "general", self.use_legacy, self.lite_model, self.use_unified)

         if self.onnx_model is None:
             z_batch = torch.from_numpy(obs['z_batch']).float()

evaluate.py
@@ -58,15 +58,15 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         'Dou Dizhu Evaluation')
     parser.add_argument('--landlord', type=str,
-                        default='baselines/douzero_12/landlord_weights_39762328900.ckpt')
+                        default='baselines/lite_uni_5001600.ckpt')
     parser.add_argument('--landlord_up', type=str,
-                        default='baselines/douzero_12/landlord_up_weights_39762328900.ckpt')
+                        default='baselines/lite_uni_5001600.ckpt')
     parser.add_argument('--landlord_front', type=str,
-                        default='baselines/douzero_12/landlord_front_weights_39762328900.ckpt')
+                        default='baselines/lite_uni_5001600.ckpt')
     parser.add_argument('--landlord_down', type=str,
-                        default='baselines/douzero_12/landlord_down_weights_39762328900.ckpt')
+                        default='baselines/lite_uni_5001600.ckpt')
     parser.add_argument('--eval_data', type=str,
-                        default='eval_data_200.pkl')
+                        default='eval_data_200_r.pkl')
     parser.add_argument('--num_workers', type=int, default=3)
     parser.add_argument('--gpu_device', type=str, default='0')
     parser.add_argument('--output', type=bool, default=True)

@@ -80,6 +80,17 @@ if __name__ == '__main__':
     os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device

+    landlord_wp, farmer_wp, landlord_adp, farmer_adp = evaluate(args.landlord,
+                                                                args.landlord_up,
+                                                                args.landlord_front,
+                                                                args.landlord_down,
+                                                                args.eval_data,
+                                                                args.num_workers,
+                                                                args.output,
+                                                                args.title)
+    print(landlord_wp, farmer_wp, landlord_adp, farmer_adp)
+    os._exit(0)
+
     baselines = [
         {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 736107200},
         {'folder': 'baselines', 'prefix': 'legacy_general', 'frame': 479412800},
@@ -28,7 +28,7 @@ RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
 @app.route('/upload', methods=['POST'])
 def upload():
     type = request.form.get('type')
-    if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla']:
+    if type not in ['lite_resnet', 'lite_vanilla', 'legacy_vanilla', 'lite_unified']:
         return jsonify({'status': -1, 'message': 'illegal type'})
     position = request.form.get("position")
     if position not in positions:
@@ -21,7 +21,8 @@ def generate():
             'landlord_up': _deck[33:58],
             'landlord_front': _deck[58:83],
             'landlord_down': _deck[83:108],
-            # 'three_landlord_cards': _deck[25:33],
+            'three_landlord_cards': _deck[25:33],
+            'three_landlord_cards_all': _deck[25:33],
         }
         for key in card_play_data:
             card_play_data[key].sort()