Vincentzyx 2021-09-07 16:38:34 +08:00
parent 5fbacd142e
commit e1e727a2f3
23 changed files with 2189 additions and 0 deletions

80
BidModel.py Normal file

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Created by: Vincentzyx
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import time
def EnvToOnehot(cards):
Env2IdxMap = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8,12:9,13:10,14:11,17:12,20:13,30:14}
cards = [Env2IdxMap[i] for i in cards]
Onehot = torch.zeros((4,15))
for i in range(0, 15):
Onehot[:cards.count(i),i] = 1
return Onehot
def RealToOnehot(cards):
RealCard2EnvCard = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4,
'8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9,
'K': 10, 'A': 11, '2': 12, 'X': 13, 'D': 14}
cards = [RealCard2EnvCard[c] for c in cards]
Onehot = torch.zeros((4,15))
for i in range(0, 15):
Onehot[:cards.count(i),i] = 1
return Onehot
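# A quick worked example of the encoding (annotation, not part of the original
# file): RealToOnehot("33345") returns a 4x15 count matrix with rows 0-2 of
# column 0 set (three 3s) and row 0 of columns 1 and 2 set (one 4, one 5);
# flattened to length 60, it matches Net.fc1's input size below.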
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(60, 512)
self.fc2 = nn.Linear(512, 512)
self.fc3 = nn.Linear(512, 512)
self.fc4 = nn.Linear(512, 512)
self.fc5 = nn.Linear(512, 512)
self.fc6 = nn.Linear(512, 1)
self.dropout5 = nn.Dropout(0.5)
self.dropout3 = nn.Dropout(0.3)
self.dropout1 = nn.Dropout(0.1)
def forward(self, input):
x = self.fc1(input)
x = torch.relu(self.dropout1(self.fc2(x)))
x = torch.relu(self.dropout3(self.fc3(x)))
x = torch.relu(self.dropout5(self.fc4(x)))
x = torch.relu(self.dropout5(self.fc5(x)))
x = self.fc6(x)
return x
UseGPU = False
device = torch.device('cuda:0')
net = Net()
net.eval()
if UseGPU:
net = net.to(device)
if os.path.exists("./bid_weights.pkl"):
if torch.cuda.is_available():
net.load_state_dict(torch.load('./bid_weights.pkl'))
else:
net.load_state_dict(torch.load('./bid_weights.pkl', map_location=torch.device("cpu")))
def predict(cards):
input = RealToOnehot(cards)
if UseGPU:
input = input.to(device)
input = torch.flatten(input)
win_rate = net(input)
return win_rate[0].item() * 100
def predict_env(cards):
input = EnvToOnehot(cards)
if UseGPU:
input = input.to(device)
input = torch.flatten(input)
win_rate = net(input)
return win_rate[0].item() * 100
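A minimal usage sketch for this module (my addition, assuming bid_weights.pkl sits next to the script; the hand string uses the T/J/Q/K/A/2/X/D notation from RealToOnehot):

if __name__ == "__main__":
    hand = "22AKKQJTT98765433"  # a 17-card hand, every rank count <= 4
    print("estimated win rate: %.1f%%" % predict(hand))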

BIN
bid_weights.pkl Normal file

Binary file not shown.

0
douzero/__init__.py Normal file

2
douzero/dmc/__init__.py Normal file

@@ -0,0 +1,2 @@
from .dmc import train
from .arguments import parser

55
douzero/dmc/arguments.py Normal file

@@ -0,0 +1,55 @@
import argparse
parser = argparse.ArgumentParser(description='DouZero: PyTorch DouDizhu AI')
# General Settings
parser.add_argument('--xpid', default='douzero',
help='Experiment id (default: douzero)')
parser.add_argument('--save_interval', default=10, type=int,
help='Time interval (in minutes) at which to save the model')
parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp', 'logadp'],
                    help='Objective to use as the reward: adp, wp, or logadp (default: adp)')
# Training settings
parser.add_argument('--actor_device_cpu', action='store_true',
help='Use CPU as actor device')
parser.add_argument('--gpu_devices', default='0', type=str,
                    help='Which GPUs to use for training')
parser.add_argument('--num_actor_devices', default=1, type=int,
help='The number of devices used for simulation')
parser.add_argument('--num_actors', default=2, type=int,
help='The number of actors for each simulation device')
parser.add_argument('--training_device', default='0', type=str,
                    help='The index of the GPU used for training models. `cpu` means using the CPU')
parser.add_argument('--load_model', action='store_true',
help='Load an existing model')
parser.add_argument('--disable_checkpoint', action='store_true',
help='Disable saving checkpoint')
parser.add_argument('--savedir', default='douzero_checkpoints',
help='Root dir where experiment data will be saved')
# Hyperparameters
parser.add_argument('--total_frames', default=100000000000, type=int,
help='Total environment frames to train for')
parser.add_argument('--exp_epsilon', default=0.01, type=float,
help='The probability for exploration')
parser.add_argument('--batch_size', default=16, type=int,
help='Learner batch size')
parser.add_argument('--unroll_length', default=100, type=int,
help='The unroll length (time dimension)')
parser.add_argument('--num_buffers', default=50, type=int,
help='Number of shared-memory buffers')
parser.add_argument('--num_threads', default=1, type=int,
                    help='Number of learner threads')
parser.add_argument('--max_grad_norm', default=40., type=float,
help='Max norm of gradients')
# Optimizer settings
parser.add_argument('--learning_rate', default=0.0001, type=float,
help='Learning rate')
parser.add_argument('--alpha', default=0.99, type=float,
help='RMSProp smoothing constant')
parser.add_argument('--momentum', default=0, type=float,
help='RMSProp momentum')
parser.add_argument('--epsilon', default=1e-8, type=float,
help='RMSProp epsilon')
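A hedged example of how this parser is consumed (the flag values are illustrative only):

if __name__ == "__main__":
    flags = parser.parse_args(
        ["--actor_device_cpu", "--training_device", "cpu", "--batch_size", "8"])
    print(flags.objective, flags.batch_size)  # adp 8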

252
douzero/dmc/dmc.py Normal file

@@ -0,0 +1,252 @@
import os
import threading
import time
import timeit
import pprint
from collections import deque
import numpy as np
import torch
from torch import multiprocessing as mp
from torch import nn
import douzero.dmc.models
import douzero.env.env
from .file_writer import FileWriter
from .models import Model, OldModel
from .utils import get_batch, log, create_env, create_optimizers, act
mean_episode_return_buf = {p:deque(maxlen=100) for p in ['landlord', 'landlord_up', 'landlord_down', 'bidding']}
def compute_loss(logits, targets):
loss = ((logits.squeeze(-1) - targets)**2).mean()
return loss
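# Worked example (annotation, not in the original): with logits of shape
# (N, 1) and targets of shape (N,),
#   compute_loss(torch.tensor([[1.0], [2.0]]), torch.tensor([0.0, 2.0]))
# gives ((1-0)^2 + (2-2)^2) / 2 = 0.5.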
def compute_loss_for_bid(outputs, reward):
    # Unused stub: bidding currently reuses compute_loss above.
    pass
def learn(position, actor_models, model, batch, optimizer, flags, lock):
"""Performs a learning (optimization) step."""
position_index = {"landlord": 31, "landlord_up": 32, "landlord_down": 33}
print("Learn", position)
if flags.training_device != "cpu":
device = torch.device('cuda:'+str(flags.training_device))
else:
device = torch.device('cpu')
obs_x = batch["obs_x_batch"]
obs_x = torch.flatten(obs_x, 0, 1).to(device)
obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
target = torch.flatten(batch['target'].to(device), 0, 1)
if position != "bidding":
episode_returns = batch['episode_return'][batch['done'] & (batch["obs_type"] == position_index[position])]
else:
episode_returns = batch['episode_return'][batch['done'] & ((batch["obs_type"] == 41) | (batch["obs_type"] == 42) | (batch["obs_type"] == 43))]
if len(episode_returns) > 0:
mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))
with lock:
learner_outputs = model(obs_z, obs_x, return_value=True)
if position == "bidding":
pass
else:
loss = compute_loss(learner_outputs['values'], target)
stats = {
'mean_episode_return_'+position: torch.mean(torch.stack([_r for _r in mean_episode_return_buf[position]])).item(),
'loss_'+position: loss.item(),
}
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
optimizer.step()
for actor_model in actor_models.values():
actor_model.get_model(position).load_state_dict(model.state_dict())
return stats
def train(flags):
"""
This is the main funtion for training. It will first
initilize everything, such as buffers, optimizers, etc.
Then it will start subprocesses as actors. Then, it will call
learning function with multiple threads.
"""
if not flags.actor_device_cpu or flags.training_device != 'cpu':
if not torch.cuda.is_available():
raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
plogger = FileWriter(
xpid=flags.xpid,
xp_args=flags.__dict__,
rootdir=flags.savedir,
)
checkpointpath = os.path.expandvars(
os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))
T = flags.unroll_length
B = flags.batch_size
if flags.actor_device_cpu:
device_iterator = ['cpu']
else:
device_iterator = range(flags.num_actor_devices)
        assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices cannot exceed the number of available devices'
# Initialize actor models
models = {}
for device in device_iterator:
model = Model(device="cpu")
model.share_memory()
model.eval()
models[device] = model
# Initialize queues
actor_processes = []
ctx = mp.get_context('spawn')
batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue(), "bidding": ctx.SimpleQueue()}
# Learner model for training
learner_model = Model(device=flags.training_device)
# Create optimizers
optimizers = create_optimizers(flags, learner_model)
# Stat Keys
stat_keys = [
'mean_episode_return_landlord',
'loss_landlord',
'mean_episode_return_landlord_up',
'loss_landlord_up',
'mean_episode_return_landlord_down',
'loss_landlord_down',
'mean_episode_return_bidding',
'loss_bidding',
]
frames, stats = 0, {k: 0 for k in stat_keys}
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_down':0, 'bidding': 0}
# Load models if any
if flags.load_model and os.path.exists(checkpointpath):
checkpoint_states = torch.load(
checkpointpath, map_location=("cuda:"+str(flags.training_device) if flags.training_device != "cpu" else "cpu")
)
for k in ['landlord', 'landlord_up', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_down']
learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
for device in device_iterator:
models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
stats = checkpoint_states["stats"]
if not 'mean_episode_return_bidding' in stats:
stats.update({"mean_episode_return_bidding": 0})
if not 'loss_bidding' in stats:
stats.update({"loss_bidding": 0})
frames = checkpoint_states["frames"]
position_frames = checkpoint_states["position_frames"]
if not "bidding" in position_frames:
position_frames.update({"bidding": 0})
log.info(f"Resuming preempted job, current stats:\n{stats}")
# Starting actor processes
for device in device_iterator:
num_actors = flags.num_actors
for i in range(flags.num_actors):
actor = mp.Process(
target=act,
args=(i, device, batch_queues, models[device], flags))
# actor.setDaemon(True)
actor.start()
actor_processes.append(actor)
def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
"""Thread target for the learning process."""
nonlocal frames, position_frames, stats
while frames < flags.total_frames:
batch = get_batch(batch_queues, position, flags, local_lock)
_stats = learn(position, models, learner_model.get_model(position), batch,
optimizers[position], flags, position_lock)
with lock:
for k in _stats:
stats[k] = _stats[k]
to_log = dict(frames=frames)
to_log.update({k: stats[k] for k in stat_keys})
plogger.log(to_log)
frames += T * B
position_frames[position] += T * B
threads = []
locks = {}
for device in device_iterator:
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_down': threading.Lock(), 'bidding': threading.Lock()}
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_down': threading.Lock(), 'bidding': threading.Lock()}
for device in device_iterator:
for i in range(flags.num_threads):
for position in ['landlord', 'landlord_up', 'landlord_down', 'bidding']:
thread = threading.Thread(
target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,device,position,locks[device][position],position_locks[position]))
thread.start()
threads.append(thread)
def checkpoint(frames):
if flags.disable_checkpoint:
return
log.info('Saving checkpoint to %s', checkpointpath)
_models = learner_model.get_models()
torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},  # e.g. {"general": _models["landlord"].state_dict()}
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},  # e.g. {"general": optimizers["landlord"].state_dict()}
"stats": stats,
'flags': vars(flags),
'frames': frames,
'position_frames': position_frames
}, checkpointpath)
# Save the weights for evaluation purpose
for position in ['landlord', 'landlord_up', 'landlord_down', 'bidding']: # ['landlord', 'landlord_up', 'landlord_down']
model_weights_dir = os.path.expandvars(os.path.expanduser(
'%s/%s/%s' % (flags.savedir, flags.xpid, "general_"+position+'_'+str(frames)+'.ckpt')))
torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
fps_log = []
timer = timeit.default_timer
try:
last_checkpoint_time = timer() - flags.save_interval * 60
while frames < flags.total_frames:
start_frames = frames
position_start_frames = {k: position_frames[k] for k in position_frames}
start_time = timer()
time.sleep(5)
if timer() - last_checkpoint_time > flags.save_interval * 60:
checkpoint(frames)
last_checkpoint_time = timer()
end_time = timer()
fps = (frames - start_frames) / (end_time - start_time)
fps_log.append(fps)
if len(fps_log) > 24:
fps_log = fps_log[1:]
fps_avg = np.mean(fps_log)
position_fps = {k:(position_frames[k]-position_start_frames[k])/(end_time-start_time) for k in position_frames}
log.info('After %i (L:%i U:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f D:%.1f) Stats:\n%s',
frames,
position_frames['landlord'],
position_frames['landlord_up'],
position_frames['landlord_down'],
fps,
fps_avg,
position_fps['landlord'],
position_fps['landlord_up'],
position_fps['landlord_down'],
pprint.pformat(stats))
except KeyboardInterrupt:
return
else:
for thread in threads:
thread.join()
log.info('Learning finished after %d frames.', frames)
checkpoint(frames)
plogger.close()

89
douzero/dmc/env_utils.py Normal file

@@ -0,0 +1,89 @@
"""
Here, we wrap the original environment to make it easier
to use. When a game is finished, instead of mannualy reseting
the environment, we do it automatically.
"""
import numpy as np
import torch
def _format_observation(obs, device):
"""
A utility function to process observations and
move them to CUDA.
"""
position = obs['position']
if not device == "cpu":
device = 'cuda:' + str(device)
device = torch.device(device)
x_batch = torch.from_numpy(obs['x_batch']).to(device)
z_batch = torch.from_numpy(obs['z_batch']).to(device)
x_no_action = torch.from_numpy(obs['x_no_action'])
z = torch.from_numpy(obs['z'])
obs = {'x_batch': x_batch,
'z_batch': z_batch,
'legal_actions': obs['legal_actions'],
}
return position, obs, x_no_action, z
class Environment:
def __init__(self, env, device):
""" Initialzie this environment wrapper
"""
self.env = env
self.device = device
self.episode_return = None
def initial(self, model, device, flags=None):
obs, buf = self.env.reset(model, device, flags=flags)
initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device)
initial_reward = torch.zeros(1, 1)
self.episode_return = torch.zeros(1, 1)
initial_done = torch.ones(1, 1, dtype=torch.bool)
if buf is None:
return initial_position, initial_obs, dict(
done=initial_done,
episode_return=self.episode_return,
obs_x_no_action=x_no_action,
obs_z=z,
)
else:
return initial_position, initial_obs, dict(
done=initial_done,
episode_return=self.episode_return,
obs_x_no_action=x_no_action,
obs_z=z,
begin_buf=buf
)
def step(self, action, model, device, flags=None):
obs, reward, done, _ = self.env.step(action)
self.episode_return = reward
episode_return = self.episode_return
buf = None
if done:
obs, buf = self.env.reset(model, device, flags=flags)
self.episode_return = torch.zeros(1, 1)
position, obs, x_no_action, z = _format_observation(obs, self.device)
# reward = torch.tensor(reward).view(1, 1)
done = torch.tensor(done).view(1, 1)
if buf is None:
return position, obs, dict(
done=done,
episode_return=episode_return,
obs_x_no_action=x_no_action,
obs_z=z,
)
else:
return position, obs, dict(
done=done,
episode_return=episode_return,
obs_x_no_action=x_no_action,
obs_z=z,
begin_buf=buf
)
def close(self):
self.env.close()
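A sketch of the rollout loop this wrapper enables (my addition; `model` and `flags` are assumed to come from the surrounding DMC code, and actions are chosen at random only for illustration):

import random

def random_rollout(model, device, flags, num_steps=100):
    from douzero.dmc.utils import create_env
    env = Environment(create_env(flags), device)
    position, obs, env_output = env.initial(model, device, flags=flags)
    for _ in range(num_steps):
        action = random.choice(obs['legal_actions'])
        position, obs, env_output = env.step(action, model, device, flags=flags)
        # on done, env.step has already reset the game internally and
        # env_output['episode_return'] holds the final reward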

188
douzero/dmc/file_writer.py Normal file

@@ -0,0 +1,188 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import datetime
import csv
import json
import logging
import os
import time
from typing import Dict
# import git
# def gather_metadata() -> Dict:
# date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
# # gathering git metadata
# try:
# repo = git.Repo(search_parent_directories=True)
# git_sha = repo.commit().hexsha
# git_data = dict(
# commit=git_sha,
# branch=repo.active_branch.name,
# is_dirty=repo.is_dirty(),
# path=repo.git_dir,
# )
# except git.InvalidGitRepositoryError:
# git_data = None
# # gathering slurm metadata
# if 'SLURM_JOB_ID' in os.environ:
# slurm_env_keys = [k for k in os.environ if k.startswith('SLURM')]
# slurm_data = {}
# for k in slurm_env_keys:
# d_key = k.replace('SLURM_', '').replace('SLURMD_', '').lower()
# slurm_data[d_key] = os.environ[k]
# else:
# slurm_data = None
# return dict(
# date_start=date_start,
# date_end=None,
# successful=False,
# git=git_data,
# slurm=slurm_data,
# env=os.environ.copy(),
# )
class FileWriter:
def __init__(self,
xpid: str = None,
xp_args: dict = None,
rootdir: str = '~/palaas'):
if not xpid:
# make unique id
xpid = '{proc}_{unixtime}'.format(
proc=os.getpid(), unixtime=int(time.time()))
self.xpid = xpid
self._tick = 0
# metadata gathering
if xp_args is None:
xp_args = {}
# self.metadata = gather_metadata()
# # we need to copy the args, otherwise when we close the file writer
# # (and rewrite the args) we might have non-serializable objects (or
# # other nasty stuff).
# self.metadata['args'] = copy.deepcopy(xp_args)
# self.metadata['xpid'] = self.xpid
formatter = logging.Formatter('%(message)s')
self._logger = logging.getLogger('palaas/out')
# to stdout handler
shandle = logging.StreamHandler()
shandle.setFormatter(formatter)
self._logger.addHandler(shandle)
self._logger.setLevel(logging.INFO)
rootdir = os.path.expandvars(os.path.expanduser(rootdir))
# to file handler
self.basepath = os.path.join(rootdir, self.xpid)
if not os.path.exists(self.basepath):
self._logger.info('Creating log directory: %s', self.basepath)
os.makedirs(self.basepath, exist_ok=True)
else:
self._logger.info('Found log directory: %s', self.basepath)
        # NOTE: the 'latest' symlink was removed because it creates errors on SLURM:
        # multiple jobs try to write to 'latest' but cannot find it.
# Add 'latest' as symlink unless it exists and is no symlink.
# symlink = os.path.join(rootdir, 'latest')
# if os.path.islink(symlink):
# os.remove(symlink)
# if not os.path.exists(symlink):
# os.symlink(self.basepath, symlink)
# self._logger.info('Symlinked log directory: %s', symlink)
self.paths = dict(
msg='{base}/out.log'.format(base=self.basepath),
logs='{base}/logs.csv'.format(base=self.basepath),
fields='{base}/fields.csv'.format(base=self.basepath),
meta='{base}/meta.json'.format(base=self.basepath),
)
self._logger.info('Saving arguments to %s', self.paths['meta'])
if os.path.exists(self.paths['meta']):
self._logger.warning('Path to meta file already exists. '
'Not overriding meta.')
else:
self._save_metadata()
self._logger.info('Saving messages to %s', self.paths['msg'])
if os.path.exists(self.paths['msg']):
self._logger.warning('Path to message file already exists. '
'New data will be appended.')
fhandle = logging.FileHandler(self.paths['msg'])
fhandle.setFormatter(formatter)
self._logger.addHandler(fhandle)
self._logger.info('Saving logs data to %s', self.paths['logs'])
self._logger.info('Saving logs\' fields to %s', self.paths['fields'])
if os.path.exists(self.paths['logs']):
self._logger.warning('Path to log file already exists. '
'New data will be appended.')
with open(self.paths['fields'], 'r') as csvfile:
reader = csv.reader(csvfile)
self.fieldnames = list(reader)[0]
else:
self.fieldnames = ['_tick', '_time']
def log(self, to_log: Dict, tick: int = None,
verbose: bool = False) -> None:
if tick is not None:
raise NotImplementedError
else:
to_log['_tick'] = self._tick
self._tick += 1
to_log['_time'] = time.time()
old_len = len(self.fieldnames)
for k in to_log:
if k not in self.fieldnames:
self.fieldnames.append(k)
if old_len != len(self.fieldnames):
with open(self.paths['fields'], 'w') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(self.fieldnames)
self._logger.info('Updated log fields: %s', self.fieldnames)
if to_log['_tick'] == 0:
# print("\ncreating logs file ")
with open(self.paths['logs'], 'a') as f:
f.write('# %s\n' % ','.join(self.fieldnames))
if verbose:
self._logger.info('LOG | %s', ', '.join(
['{}: {}'.format(k, to_log[k]) for k in sorted(to_log)]))
with open(self.paths['logs'], 'a') as f:
writer = csv.DictWriter(f, fieldnames=self.fieldnames)
writer.writerow(to_log)
# print("\nadded to log file")
def close(self, successful: bool = True) -> None:
# self.metadata['date_end'] = datetime.datetime.now().strftime(
# '%Y-%m-%d %H:%M:%S.%f')
# self.metadata['successful'] = successful
self._save_metadata()
def _save_metadata(self) -> None:
with open(self.paths['meta'], 'w') as jsonfile:
pass
# json.dump(self.metadata, jsonfile, indent=4, sort_keys=True)
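A minimal, self-contained use of FileWriter (my addition; the rootdir is illustrative):

if __name__ == "__main__":
    plogger = FileWriter(xpid="demo", xp_args={}, rootdir="/tmp/palaas")
    plogger.log({"frames": 0, "loss_landlord": 0.5})  # appends to logs.csv
    plogger.close()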

480
douzero/dmc/models.py Normal file

@@ -0,0 +1,480 @@
"""
This file includes the torch models. We wrap the three
models into one class for convenience.
"""
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
class LandlordLstmModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(162, 128, batch_first=True)
self.dense1 = nn.Linear(373 + 128, 512)
self.dense2 = nn.Linear(512, 512)
self.dense3 = nn.Linear(512, 512)
self.dense4 = nn.Linear(512, 512)
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
class FarmerLstmModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(162, 128, batch_first=True)
self.dense1 = nn.Linear(484 + 128, 512)
self.dense2 = nn.Linear(512, 512)
self.dense3 = nn.Linear(512, 512)
self.dense4 = nn.Linear(512, 512)
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
class LandlordLstmNewModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(162, 128, batch_first=True)
self.dense1 = nn.Linear(373 + 128, 512)
self.dense2 = nn.Linear(512, 512)
self.dense3 = nn.Linear(512, 512)
self.dense4 = nn.Linear(512, 512)
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
class FarmerLstmNewModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(162, 128, batch_first=True)
self.dense1 = nn.Linear(484 + 128, 512)
self.dense2 = nn.Linear(512, 512)
self.dense3 = nn.Linear(512, 512)
self.dense4 = nn.Linear(512, 512)
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None):
lstm_out, (h_n, _) = self.lstm(z)
lstm_out = lstm_out[:,-1,:]
x = torch.cat([lstm_out,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action)
class GeneralModel1(nn.Module):
def __init__(self):
super().__init__()
# input: B * 32 * 57
# self.lstm = nn.LSTM(162, 512, batch_first=True)
self.conv_z_1 = torch.nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(1,57)), # B * 1 * 64 * 32
nn.ReLU(inplace=True),
nn.BatchNorm2d(64),
)
# Squeeze(-1) B * 64 * 16
self.conv_z_2 = torch.nn.Sequential(
nn.Conv1d(64, 128, kernel_size=(5,), padding=2), # 128 * 16
nn.ReLU(inplace=True),
nn.BatchNorm1d(128),
)
self.conv_z_3 = torch.nn.Sequential(
nn.Conv1d(128, 256, kernel_size=(3,), padding=1), # 256 * 8
nn.ReLU(inplace=True),
nn.BatchNorm1d(256),
)
self.conv_z_4 = torch.nn.Sequential(
nn.Conv1d(256, 512, kernel_size=(3,), padding=1), # 512 * 4
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
)
self.dense1 = nn.Linear(519 + 1024, 1024)
self.dense2 = nn.Linear(1024, 512)
self.dense3 = nn.Linear(512, 512)
self.dense4 = nn.Linear(512, 512)
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None, debug=False):
z = z.unsqueeze(1)
z = self.conv_z_1(z)
z = z.squeeze(-1)
z = torch.max_pool1d(z, 2)
z = self.conv_z_2(z)
z = torch.max_pool1d(z, 2)
z = self.conv_z_3(z)
z = torch.max_pool1d(z, 2)
z = self.conv_z_4(z)
z = torch.max_pool1d(z, 2)
z = z.flatten(1,2)
x = torch.cat([z,x], dim=-1)
x = self.dense1(x)
x = torch.relu(x)
x = self.dense2(x)
x = torch.relu(x)
x = self.dense3(x)
x = torch.relu(x)
x = self.dense4(x)
x = torch.relu(x)
x = self.dense5(x)
x = torch.relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action, max_value=torch.max(x))
# The residual block used by ResNet-18/34: two 3x3 convolutions
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=(3,),
stride=(stride,), padding=1, bias=False)
self.bn1 = nn.BatchNorm1d(planes)
self.conv2 = nn.Conv1d(planes, planes, kernel_size=(3,),
stride=(1,), padding=1, bias=False)
self.bn2 = nn.BatchNorm1d(planes)
self.shortcut = nn.Sequential()
        # The processed x must match the input x in dimensions (length and channels);
        # if they differ, add a 1x1 conv + BN on the shortcut to project to the same shape
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv1d(in_planes, self.expansion * planes,
kernel_size=(1,), stride=(stride,), bias=False),
nn.BatchNorm1d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class GeneralModel(nn.Module):
def __init__(self):
super().__init__()
self.in_planes = 80
        # input: B * 40 * 54
        self.conv1 = nn.Conv1d(40, 80, kernel_size=(3,),
                               stride=(2,), padding=1, bias=False)  # B * 80 * 27
self.bn1 = nn.BatchNorm1d(80)
self.layer1 = self._make_layer(BasicBlock, 80, 2, stride=2)#1*14*80
self.layer2 = self._make_layer(BasicBlock, 160, 2, stride=2)#1*7*160
self.layer3 = self._make_layer(BasicBlock, 320, 2, stride=2)#1*4*320
# self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear1 = nn.Linear(320 * BasicBlock.expansion * 4 + 15 * 4, 1024)
self.linear2 = nn.Linear(1024, 512)
self.linear3 = nn.Linear(512, 256)
self.linear4 = nn.Linear(256, 1)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, z, x, return_value=False, flags=None, debug=False):
out = F.relu(self.bn1(self.conv1(z)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = out.flatten(1,2)
out = torch.cat([x,x,x,x,out], dim=-1)
out = F.leaky_relu_(self.linear1(out))
out = F.leaky_relu_(self.linear2(out))
out = F.leaky_relu_(self.linear3(out))
out = F.leaky_relu_(self.linear4(out))
if return_value:
return dict(values=out)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(out.shape[0], (1,))[0]
else:
action = torch.argmax(out,dim=0)[0]
return dict(action=action, max_value=torch.max(out))
class BidModel(nn.Module):
def __init__(self):
super().__init__()
self.dense1 = nn.Linear(114, 512)
self.dense2 = nn.Linear(512, 512)
self.dense3 = nn.Linear(512, 512)
self.dense4 = nn.Linear(512, 512)
self.dense5 = nn.Linear(512, 512)
self.dense6 = nn.Linear(512, 1)
def forward(self, z, x, return_value=False, flags=None, debug=False):
x = self.dense1(x)
x = F.leaky_relu(x)
# x = F.relu(x)
x = self.dense2(x)
x = F.leaky_relu(x)
# x = F.relu(x)
x = self.dense3(x)
x = F.leaky_relu(x)
# x = F.relu(x)
x = self.dense4(x)
x = F.leaky_relu(x)
# x = F.relu(x)
x = self.dense5(x)
# x = F.relu(x)
x = F.leaky_relu(x)
x = self.dense6(x)
if return_value:
return dict(values=x)
else:
if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
action = torch.randint(x.shape[0], (1,))[0]
else:
action = torch.argmax(x,dim=0)[0]
return dict(action=action, max_value=torch.max(x))
# Model dict is only used in evaluation but not training
model_dict = {}
model_dict['landlord'] = LandlordLstmModel
model_dict['landlord_up'] = FarmerLstmModel
model_dict['landlord_down'] = FarmerLstmModel
model_dict_new = {}
model_dict_new['landlord'] = GeneralModel
model_dict_new['landlord_up'] = GeneralModel
model_dict_new['landlord_down'] = GeneralModel
model_dict_new['bidding'] = BidModel
model_dict_lstm = {}
model_dict_lstm['landlord'] = GeneralModel
model_dict_lstm['landlord_up'] = GeneralModel
model_dict_lstm['landlord_down'] = GeneralModel
class General_Model:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0):
self.models = {}
if not device == "cpu":
device = 'cuda:' + str(device)
# model = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel1().to(torch.device(device))
self.models['landlord_up'] = GeneralModel1().to(torch.device(device))
self.models['landlord_down'] = GeneralModel1().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False):
model = self.models[position]
return model.forward(z, x, training, flags, debug)
def share_memory(self):
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_down'].share_memory()
self.models['bidding'].share_memory()
def eval(self):
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_down'].eval()
self.models['bidding'].eval()
def parameters(self, position):
return self.models[position].parameters()
def get_model(self, position):
return self.models[position]
def get_models(self):
return self.models
class OldModel:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0):
self.models = {}
if not device == "cpu":
device = 'cuda:' + str(device)
self.models['landlord'] = LandlordLstmModel().to(torch.device(device))
self.models['landlord_up'] = FarmerLstmModel().to(torch.device(device))
self.models['landlord_down'] = FarmerLstmModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None):
model = self.models[position]
return model.forward(z, x, training, flags)
def share_memory(self):
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_down'].share_memory()
def eval(self):
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_down'].eval()
def parameters(self, position):
return self.models[position].parameters()
def get_model(self, position):
return self.models[position]
def get_models(self):
return self.models
class Model:
"""
The wrapper for the three models. We also wrap several
interfaces such as share_memory, eval, etc.
"""
def __init__(self, device=0):
self.models = {}
if not device == "cpu":
device = 'cuda:' + str(device)
# model = GeneralModel().to(torch.device(device))
self.models['landlord'] = GeneralModel().to(torch.device(device))
self.models['landlord_up'] = GeneralModel().to(torch.device(device))
self.models['landlord_down'] = GeneralModel().to(torch.device(device))
self.models['bidding'] = BidModel().to(torch.device(device))
def forward(self, position, z, x, training=False, flags=None, debug=False):
model = self.models[position]
return model.forward(z, x, training, flags, debug)
def share_memory(self):
self.models['landlord'].share_memory()
self.models['landlord_up'].share_memory()
self.models['landlord_down'].share_memory()
self.models['bidding'].share_memory()
def eval(self):
self.models['landlord'].eval()
self.models['landlord_up'].eval()
self.models['landlord_down'].eval()
self.models['bidding'].eval()
def parameters(self, position):
return self.models[position].parameters()
def get_model(self, position):
return self.models[position]
def get_models(self):
return self.models
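A shape sanity check for the wrapper (my addition; z of shape (batch, 40, 54) and x of shape (batch, 15) are inferred from GeneralModel's conv1/linear1 sizes, not documented anywhere in this file):

if __name__ == "__main__":
    m = Model(device="cpu")
    m.eval()  # BatchNorm layers need eval() for a batch of one
    z, x = torch.randn(1, 40, 54), torch.randn(1, 15)
    out = m.forward("landlord", z, x, training=True)  # training=True -> return_value
    print(out["values"].shape)  # torch.Size([1, 1])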

199
douzero/dmc/utils.py Normal file

@@ -0,0 +1,199 @@
import os
import typing
import logging
import traceback
import numpy as np
from collections import Counter
import time
from douzero.radam.radam import RAdam
import torch
from torch import multiprocessing as mp
from .env_utils import Environment
from douzero.env import Env
Card2Column = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7,
11: 8, 12: 9, 13: 10, 14: 11, 17: 12}
NumOnes2Array = {0: np.array([0, 0, 0, 0]),
1: np.array([1, 0, 0, 0]),
2: np.array([1, 1, 0, 0]),
3: np.array([1, 1, 1, 0]),
4: np.array([1, 1, 1, 1])}
shandle = logging.StreamHandler()
shandle.setFormatter(
logging.Formatter(
'[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
'%(message)s'))
log = logging.getLogger('doudzero')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)
# Buffers are used to transfer data between actor processes
# and learner processes. They are shared tensors placed on the GPU.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags):
return Env(flags.objective)
def get_batch(b_queues, position, flags, lock):
"""
This function will sample a batch from the buffers based
on the indices received from the full queue. It will also
free the indices by sending it to full_queue.
"""
b_queue = b_queues[position]
buffer = []
while len(buffer) < flags.batch_size:
buffer.append(b_queue.get())
batch = {
key: torch.stack([m[key] for m in buffer], dim=1)
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"]
}
del buffer
return batch
def create_optimizers(flags, learner_model):
"""
Create three optimizers for the three positions
"""
positions = ['landlord', 'landlord_up', 'landlord_down', 'bidding']
optimizers = {}
for position in positions:
optimizer = RAdam(
learner_model.parameters(position),
lr=flags.learning_rate,
eps=flags.epsilon)
optimizers[position] = optimizer
return optimizers
def act(i, device, batch_queues, model, flags):
positions = ['landlord', 'landlord_up', 'landlord_down', 'bidding']
for pos in positions:
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
try:
T = flags.unroll_length
log.info('Device %s Actor %i started.', str(device), i)
env = create_env(flags)
env = Environment(env, device)
done_buf = {p: [] for p in positions}
episode_return_buf = {p: [] for p in positions}
target_buf = {p: [] for p in positions}
obs_z_buf = {p: [] for p in positions}
size = {p: 0 for p in positions}
type_buf = {p: [] for p in positions}
obs_x_batch_buf = {p: [] for p in positions}
position_index = {"landlord": 31, "landlord_up": 32, "landlord_down": 33}
bid_type_index = {"landlord": 41, "landlord_up": 42, "landlord_down": 43}
bid_type_map = {41: "landlord", 42: "landlord_up", 43: "landlord_down"}
position, obs, env_output = env.initial(model, device, flags=flags)
bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"]
multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"]
while True:
# print("posi", position)
for bid_obs in bid_obs_buffer:
obs_z_buf["bidding"].append(bid_obs['z_batch'])
obs_x_batch_buf["bidding"].append(bid_obs["x_batch"])
type_buf["bidding"].append(bid_type_index[bid_obs["position"]])
size["bidding"] += 1
for mul_obs in multiply_obs_buffer:
obs_z_buf[mul_obs["position"]].append(mul_obs['z_batch'])
obs_x_batch_buf[mul_obs["position"]].append(mul_obs["x_batch"])
type_buf[mul_obs["position"]].append(2)
size[mul_obs["position"]] += 1
while True:
with torch.no_grad():
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
_action_idx = int(agent_output['action'].cpu().detach().numpy())
action = obs['legal_actions'][_action_idx]
obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float())
# x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
x_batch = env_output['obs_x_no_action'].float()
obs_x_batch_buf[position].append(x_batch)
type_buf[position].append(position_index[position])
position, obs, env_output = env.step(action, model, device, flags=flags)
size[position] += 1
if env_output['done']:
bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"]
multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"]
for p in positions:
diff = size[p] - len(target_buf[p])
# print(p, diff)
if diff > 0:
done_buf[p].extend([False for _ in range(diff-1)])
done_buf[p].append(True)
if p != "bidding":
episode_return = env_output['episode_return']["play"][p] if p == 'landlord' else -env_output['episode_return']["play"][p]
episode_return_buf[p].extend([0.0 for _ in range(diff-1)])
episode_return_buf[p].append(episode_return)
target_buf[p].extend([episode_return for _ in range(diff)])
else:
offset = len(target_buf[p])
for index in range(diff):
pos = type_buf[p][index+offset]
if pos == 41:
episode_return = env_output['episode_return']["bid"]["landlord"]
else:
episode_return = -env_output['episode_return']["bid"][bid_type_map[pos]]
episode_return_buf[p].append(episode_return)
# print(p, episode_return)
target_buf[p].append(episode_return)
break
for p in positions:
if size[p] > T:
# print(p, "epr", torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),)
batch_queues[p].put({
"done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
"episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
"target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
"obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
"obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
"obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
})
done_buf[p] = done_buf[p][T:]
episode_return_buf[p] = episode_return_buf[p][T:]
target_buf[p] = target_buf[p][T:]
obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
obs_z_buf[p] = obs_z_buf[p][T:]
type_buf[p] = type_buf[p][T:]
size[p] -= T
except KeyboardInterrupt:
pass
except Exception as e:
log.error('Exception in worker process %i', i)
traceback.print_exc()
print()
raise e
def _cards2tensor(list_cards):
"""
Convert a list of integers to the tensor
representation
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
"""
if len(list_cards) == 0:
return torch.zeros(54, dtype=torch.int8)
matrix = np.zeros([4, 13], dtype=np.int8)
jokers = np.zeros(2, dtype=np.int8)
counter = Counter(list_cards)
for card, num_times in counter.items():
if card < 20:
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
elif card == 20:
jokers[0] = 1
elif card == 30:
jokers[1] = 1
matrix = np.concatenate((matrix.flatten('F'), jokers))
matrix = torch.from_numpy(matrix)
return matrix
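A worked example of the card encoding (my addition; indices follow the column-major flatten plus the two joker slots):

if __name__ == "__main__":
    t = _cards2tensor([3, 4, 4, 20])
    # rank 3 -> column 0, row 0 (index 0); two 4s -> column 1, rows 0-1
    # (indices 4 and 5); 20 (black joker) -> index 52
    print(t.nonzero(as_tuple=True)[0].tolist())  # [0, 4, 5, 52]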

62
douzero/evaluation/deep_agent.py Normal file

@@ -0,0 +1,62 @@
import torch
import numpy as np
from douzero.env.env import get_obs
def _load_model(position, model_path, model_type):
from douzero.dmc.models import model_dict_new, model_dict
model = None
if model_type == "general":
model = model_dict_new[position]()
else:
model = model_dict[position]()
model_state_dict = model.state_dict()
if torch.cuda.is_available():
pretrained = torch.load(model_path, map_location='cuda:0')
else:
pretrained = torch.load(model_path, map_location='cpu')
pretrained = {k: v for k, v in pretrained.items() if k in model_state_dict}
model_state_dict.update(pretrained)
model.load_state_dict(model_state_dict)
# torch.save(model.state_dict(), model_path.replace(".ckpt", "_nobn.ckpt"))
if torch.cuda.is_available():
model.cuda()
model.eval()
return model
class DeepAgent:
def __init__(self, position, model_path):
self.model_type = "general" if "resnet" in model_path else "old"
self.model = _load_model(position, model_path, self.model_type)
self.EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
def act(self, infoset):
if len(infoset.legal_actions) == 1:
return infoset.legal_actions[0]
obs = get_obs(infoset, self.model_type == "general")
z_batch = torch.from_numpy(obs['z_batch']).float()
x_batch = torch.from_numpy(obs['x_batch']).float()
if torch.cuda.is_available():
z_batch, x_batch = z_batch.cuda(), x_batch.cuda()
y_pred = self.model.forward(z_batch, x_batch, return_value=True)['values']
y_pred = y_pred.detach().cpu().numpy()
best_action_index = np.argmax(y_pred, axis=0)[0]
best_action = infoset.legal_actions[best_action_index]
# action_list = []
# output = ""
# for i, action in enumerate(y_pred):
# action_list.append((y_pred[i].item(), "".join([self.EnvCard2RealCard[ii] for ii in infoset.legal_actions[i]]) if len(infoset.legal_actions[i]) != 0 else "Pass"))
# action_list.sort(key=lambda x: x[0], reverse=True)
# value_list = []
# for action in action_list:
# output += str(round(action[0],3)) + " " + action[1] + "\n"
# value_list.append(action[0])
# # print(value_list)
# print(output)
# print("--------------------\n")
return best_action
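Construction is keyed off the checkpoint filename ("resnet" selects model_dict_new); a hedged sketch with a hypothetical path:

# agent = DeepAgent('landlord', 'baselines/resnet_landlord.ckpt')  # hypothetical checkpoint
# action = agent.act(infoset)  # infoset is provided by GameEnv during play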

9
douzero/evaluation/random_agent.py Normal file

@@ -0,0 +1,9 @@
import random
class RandomAgent():
def __init__(self):
self.name = 'Random'
def act(self, infoset):
return random.choice(infoset.legal_actions)

183
douzero/evaluation/rlcard_agent.py Normal file

@@ -0,0 +1,183 @@
import random
from rlcard.games.doudizhu.utils import CARD_TYPE
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'B', 30: 'R'}
RealCard2EnvCard = {'3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
'8': 8, '9': 9, 'T': 10, 'J': 11, 'Q': 12,
'K': 13, 'A': 14, '2': 17, 'B': 20, 'R': 30}
INDEX = {'3': 0, '4': 1, '5': 2, '6': 3, '7': 4,
'8': 5, '9': 6, 'T': 7, 'J': 8, 'Q': 9,
'K': 10, 'A': 11, '2': 12, 'B': 13, 'R': 14}
class RLCardAgent(object):
def __init__(self, position):
self.name = 'RLCard'
self.position = position
def act(self, infoset):
try:
# Hand cards
hand_cards = infoset.player_hand_cards
for i, c in enumerate(hand_cards):
hand_cards[i] = EnvCard2RealCard[c]
hand_cards = ''.join(hand_cards)
# Last move
last_move = infoset.last_move.copy()
for i, c in enumerate(last_move):
last_move[i] = EnvCard2RealCard[c]
last_move = ''.join(last_move)
# Last two moves
last_two_cards = infoset.last_two_moves
for i in range(2):
for j, c in enumerate(last_two_cards[i]):
last_two_cards[i][j] = EnvCard2RealCard[c]
last_two_cards[i] = ''.join(last_two_cards[i])
# Last pid
last_pid = infoset.last_pid
action = None
            # rule for leading a new round
if last_two_cards[0] == '' and last_two_cards[1] == '':
chosen_action = None
comb = combine_cards(hand_cards)
min_card = hand_cards[0]
for _, acs in comb.items():
for ac in acs:
if min_card in ac:
chosen_action = ac
action = [char for char in chosen_action]
for i, c in enumerate(action):
action[i] = RealCard2EnvCard[c]
#print('lead action:', action)
            # rule for following the previous move
else:
the_type = CARD_TYPE[0][last_move][0][0]
chosen_action = ''
rank = 1000
for ac in infoset.legal_actions:
_ac = ac.copy()
for i, c in enumerate(_ac):
_ac[i] = EnvCard2RealCard[c]
_ac = ''.join(_ac)
if _ac != '' and the_type == CARD_TYPE[0][_ac][0][0]:
if int(CARD_TYPE[0][_ac][0][1]) < rank:
rank = int(CARD_TYPE[0][_ac][0][1])
chosen_action = _ac
if chosen_action != '':
action = [char for char in chosen_action]
for i, c in enumerate(action):
action[i] = RealCard2EnvCard[c]
#print('action:', action)
elif last_pid != 'landlord' and self.position != 'landlord':
action = []
if action is None:
action = random.choice(infoset.legal_actions)
except:
action = random.choice(infoset.legal_actions)
#import traceback
#traceback.print_exc()
assert action in infoset.legal_actions
return action
def card_str2list(hand):
hand_list = [0 for _ in range(15)]
for card in hand:
hand_list[INDEX[card]] += 1
return hand_list
def list2card_str(hand_list):
card_str = ''
cards = [card for card in INDEX]
for index, count in enumerate(hand_list):
card_str += cards[index] * count
return card_str
def pick_chain(hand_list, count):
chains = []
str_card = [card for card in INDEX]
hand_list = [str(card) for card in hand_list]
hand = ''.join(hand_list[:12])
chain_list = hand.split('0')
add = 0
for index, chain in enumerate(chain_list):
if len(chain) > 0:
if len(chain) >= 5:
start = index + add
min_count = int(min(chain)) // count
if min_count != 0:
str_chain = ''
for num in range(len(chain)):
str_chain += str_card[start+num]
hand_list[start+num] = int(hand_list[start+num]) - int(min(chain))
for _ in range(min_count):
chains.append(str_chain)
add += len(chain)
hand_list = [int(card) for card in hand_list]
return (chains, hand_list)
def combine_cards(hand):
    '''Get a greedy (near-optimal) decomposition of the cards in hand into combinations
    '''
comb = {'rocket': [], 'bomb': [], 'trio': [], 'trio_chain': [],
'solo_chain': [], 'pair_chain': [], 'pair': [], 'solo': []}
# 1. pick rocket
if hand[-2:] == 'BR':
comb['rocket'].append('BR')
hand = hand[:-2]
# 2. pick bomb
hand_cp = hand
for index in range(len(hand_cp) - 3):
if hand_cp[index] == hand_cp[index+3]:
bomb = hand_cp[index: index+4]
comb['bomb'].append(bomb)
hand = hand.replace(bomb, '')
# 3. pick trio and trio_chain
hand_cp = hand
for index in range(len(hand_cp) - 2):
if hand_cp[index] == hand_cp[index+2]:
trio = hand_cp[index: index+3]
if len(comb['trio']) > 0 and INDEX[trio[-1]] < 12 and (INDEX[trio[-1]]-1) == INDEX[comb['trio'][-1][-1]]:
comb['trio'][-1] += trio
else:
comb['trio'].append(trio)
hand = hand.replace(trio, '')
only_trio = []
only_trio_chain = []
for trio in comb['trio']:
if len(trio) == 3:
only_trio.append(trio)
else:
only_trio_chain.append(trio)
comb['trio'] = only_trio
comb['trio_chain'] = only_trio_chain
# 4. pick solo chain
hand_list = card_str2list(hand)
chains, hand_list = pick_chain(hand_list, 1)
comb['solo_chain'] = chains
# 5. pick par_chain
chains, hand_list = pick_chain(hand_list, 2)
comb['pair_chain'] = chains
hand = list2card_str(hand_list)
# 6. pick pair and solo
index = 0
while index < len(hand) - 1:
if hand[index] == hand[index+1]:
comb['pair'].append(hand[index] + hand[index+1])
index += 2
else:
comb['solo'].append(hand[index])
index += 1
if index == (len(hand) - 1):
comb['solo'].append(hand[index])
return comb
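Two hand-checkable examples of the decomposition (my addition):

if __name__ == "__main__":
    print(combine_cards('345678')['solo_chain'])  # ['345678']
    print(combine_cards('334455')['pair'])        # ['33', '44', '55']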

162
douzero/evaluation/simulation.py Normal file

@@ -0,0 +1,162 @@
import multiprocessing as mp
import pickle
import douzero.env.env
from douzero.dmc.models import Model
from douzero.env.game import GameEnv
import torch
import numpy as np
import BidModel
def load_card_play_models(card_play_model_path_dict):
players = {}
for position in ['landlord', 'landlord_up', 'landlord_down']:
if card_play_model_path_dict[position] == 'rlcard':
from .rlcard_agent import RLCardAgent
players[position] = RLCardAgent(position)
elif card_play_model_path_dict[position] == 'random':
from .random_agent import RandomAgent
players[position] = RandomAgent()
else:
from .deep_agent import DeepAgent
players[position] = DeepAgent(position, card_play_model_path_dict[position])
return players
def mp_simulate(card_play_data_list, card_play_model_path_dict, q, output, bid_output, title):
players = load_card_play_models(card_play_model_path_dict)
EnvCard2RealCard = {3: '3', 4: '4', 5: '5', 6: '6', 7: '7',
8: '8', 9: '9', 10: 'T', 11: 'J', 12: 'Q',
13: 'K', 14: 'A', 17: '2', 20: 'X', 30: 'D'}
env = GameEnv(players)
bid_model = None
if bid_output:
model = Model(device=0)
bid_model = model.get_model("bidding")
bid_model_path = card_play_model_path_dict["landlord"].replace("landlord", "bidding")
weights = torch.load(bid_model_path)
bid_model.load_state_dict(weights)
bid_model.eval()
for idx, card_play_data in enumerate(card_play_data_list):
env.card_play_init(card_play_data)
if bid_output:
output = True
bid_results = []
bid_values = []
bid_info_list = [
np.array([[-1,-1,-1],
[-1,-1,-1],
[-1,-1,-1],
[-1,-1,-1]]),
np.array([[0,0,0],
[-1,-1,-1],
[-1,-1,-1],
[-1,-1,-1]]),
np.array([[1,0,0],
[-1,-1,-1],
[-1,-1,-1],
[-1,-1,-1]]),
np.array([[0,0,0],
[0,0,0],
[-1,-1,-1],
[-1,-1,-1]]),
np.array([[0,0,1],
[1,0,0],
[-1,-1,-1],
[-1,-1,-1]]),
np.array([[0,1,0],
[0,0,1],
[1,0,0],
[-1,-1,-1]]),
]
for bid_info in bid_info_list:
bid_obs = douzero.env.env._get_obs_for_bid(1, bid_info, card_play_data["landlord"])
result = bid_model.forward(torch.tensor(bid_obs["z_batch"], device=torch.device("cuda:0")), torch.tensor(bid_obs["x_batch"], device=torch.device("cuda:0")), True)
values = result["values"]
bid = 1 if values[1] > values[0] else 0
bid_results.append(bid)
bid_values.append(values[bid])
result2 = BidModel.predict_env(card_play_data["landlord"])
print("".join([EnvCard2RealCard[c] for c in card_play_data["landlord"]]), end="")
print(" bid: %i|%i%i|%i%i|%i (%.3f %.3f %.3f %.3f %.3f %.3f) %.1f" % (bid_results[0],bid_results[1],bid_results[2],bid_results[3],bid_results[4],bid_results[5],bid_values[0],bid_values[1],bid_values[2],bid_values[3],bid_values[4],bid_values[5], result2))
if output and not bid_output:
print("\nStart ------- " + title)
print ("".join([EnvCard2RealCard[c] for c in card_play_data["landlord"]]))
print ("".join([EnvCard2RealCard[c] for c in card_play_data["landlord_down"]]))
print ("".join([EnvCard2RealCard[c] for c in card_play_data["landlord_up"]]))
# print(card_play_data)
count = 0
while not env.game_over and not bid_output:
action = env.step()
if output:
if count % 3 == 2:
end = "\n"
else:
end = " "
if len(action) == 0:
print("Pass", end=end)
else:
print("".join([EnvCard2RealCard[c] for c in action]), end=end)
count+=1
if idx % 10 == 0 and not bid_output:
print("\nindex", idx)
# print("End -------")
env.reset()
q.put((env.num_wins['landlord'],
env.num_wins['farmer'],
env.num_scores['landlord'],
env.num_scores['farmer']
))
def data_allocation_per_worker(card_play_data_list, num_workers):
card_play_data_list_each_worker = [[] for k in range(num_workers)]
for idx, data in enumerate(card_play_data_list):
card_play_data_list_each_worker[idx % num_workers].append(data)
return card_play_data_list_each_worker
def evaluate(landlord, landlord_up, landlord_down, eval_data, num_workers, output, output_bid, title):
with open(eval_data, 'rb') as f:
card_play_data_list = pickle.load(f)
card_play_data_list_each_worker = data_allocation_per_worker(
card_play_data_list, num_workers)
del card_play_data_list
card_play_model_path_dict = {
'landlord': landlord,
'landlord_up': landlord_up,
'landlord_down': landlord_down}
num_landlord_wins = 0
num_farmer_wins = 0
num_landlord_scores = 0
num_farmer_scores = 0
ctx = mp.get_context('spawn')
q = ctx.SimpleQueue()
processes = []
    for card_play_data in card_play_data_list_each_worker:
        p = ctx.Process(
            target=mp_simulate,
            args=(card_play_data, card_play_model_path_dict, q, output, output_bid, title))
p.start()
processes.append(p)
for p in processes:
p.join()
for i in range(num_workers):
result = q.get()
num_landlord_wins += result[0]
num_farmer_wins += result[1]
num_landlord_scores += result[2]
num_farmer_scores += result[3]
num_total_wins = num_landlord_wins + num_farmer_wins
print('WP results:')
print('landlord : Farmers - {} : {}'.format(num_landlord_wins / num_total_wins, num_farmer_wins / num_total_wins))
print('ADP results:')
print('landlord : Farmers - {} : {}'.format(num_landlord_scores / num_total_wins, 2 * num_farmer_scores / num_total_wins))
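The eval_data pickle is a list of pre-dealt games; a hedged sketch of a generator (the dict keys are inferred from card_play_init/mp_simulate usage above and may differ):

# import pickle
# deals = [{'landlord': [...], 'landlord_up': [...], 'landlord_down': [...],
#           'three_landlord_cards': [...]}]  # env-card integers: 3..14, 17, 20, 30
# with open('eval_data.pkl', 'wb') as f:
#     pickle.dump(deals, f)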

1
douzero/radam/__init__.py Normal file

@@ -0,0 +1 @@
from .radam import RAdam, PlainRAdam, AdamW

244
douzero/radam/radam.py Normal file

@@ -0,0 +1,244 @@
import math
import torch
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
state['step'] += 1
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state['step'])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
p.data.copy_(p_data_fp32)
return loss
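
# For context: RAdam applies the adaptive update only once the variance of the
# adaptive learning rate is tractable (N_sma >= 5); before that it takes an
# SGD-style step when degenerated_to_sgd is set, and otherwise skips the update
# (step_size = -1). PlainRAdam below recomputes the rectification term every
# step instead of using the 10-slot buffer cache. A minimal usage sketch with a
# toy model (illustrative, not part of this commit):
#
#     import torch
#     from torch import nn
#     from douzero.radam import RAdam
#
#     model = nn.Linear(60, 1)
#     optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)
#     x, y = torch.randn(32, 60), torch.randn(32, 1)
#     for _ in range(10):
#         optimizer.zero_grad()
#         loss = nn.functional.mse_loss(model(x), y)  # toy objective
#         loss.backward()
#         optimizer.step()
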
class PlainRAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(PlainRAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
# more conservative since it's an approximated value
if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32)
return loss
class AdamW(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, warmup = warmup)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
if group['warmup'] > state['step']:
scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
else:
scheduled_lr = group['lr']
step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)
                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
p.data.copy_(p_data_fp32)
return loss
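
Note that this AdamW variant layers a linear learning-rate warmup on top of decoupled weight decay: for the first `warmup` steps the effective rate ramps from near zero up to `lr`. A small sketch of the schedule implemented above (same arithmetic, illustrative names):

def scheduled_lr_sketch(step, lr=1e-3, warmup=500):
    # Mirrors the branch above: linear ramp while step < warmup, constant lr after.
    if warmup > step:
        return 1e-8 + step * lr / warmup
    return lr

print(scheduled_lr_sketch(1))    # ~2e-06, near zero right after the first step
print(scheduled_lr_sketch(500))  # 0.001 once warmup is over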

91
evaluate.py Normal file
View File

@ -0,0 +1,91 @@
import os
import argparse
from douzero.evaluation.simulation import evaluate
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'Dou Dizhu Evaluation')
parser.add_argument('--landlord', type=str,
default='baselines/douzero_12/landlord_weights_39762328900.ckpt')
parser.add_argument('--landlord_up', type=str,
default='baselines/douzero_12/landlord_up_weights_39762328900.ckpt')
parser.add_argument('--landlord_down', type=str,
default='baselines/douzero_12/landlord_down_weights_39762328900.ckpt')
parser.add_argument('--eval_data', type=str,
default='eval_data_1000.pkl')
parser.add_argument('--num_workers', type=int, default=5)
parser.add_argument('--gpu_device', type=str, default='0')
    # Note: argparse's type=bool is truthy for any non-empty string (even "False");
    # both flags are overridden unconditionally after parsing, so it is harmless here.
    parser.add_argument('--output', type=bool, default=True)
    parser.add_argument('--bid', type=bool, default=True)
parser.add_argument('--title', type=str, default='New')
args = parser.parse_args()
args.output = True
args.bid = False
if args.output or args.bid:
args.num_workers = 1
    # `t` selects which landlord/farmer model combination to evaluate, and
    # `frame`/`adp_frame` pick the checkpoint frames. The assignments directly
    # below are defaults that the t-switch then overrides.
    t = 3
    frame = 3085177900
    adp_frame = 2511184300
    # args.landlord = 'baselines/resnet_landlord_%i.ckpt' % frame
    args.landlord_up = 'baselines/resnet_landlord_up_%i.ckpt' % frame
    args.landlord_down = 'baselines/resnet_landlord_down_%i.ckpt' % frame
args.landlord = 'baselines/douzero_ADP/landlord.ckpt'
# args.landlord_up = 'baselines/douzero_ADP/landlord_up.ckpt'
# args.landlord_down = 'baselines/douzero_ADP/landlord_down.ckpt'
if t == 1:
args.landlord = 'baselines/resnet_landlord_%i.ckpt' % frame
args.landlord_up = 'baselines/douzero_ADP/landlord_up.ckpt'
args.landlord_down = 'baselines/douzero_ADP/landlord_down.ckpt'
elif t == 2:
args.landlord = 'baselines/douzero_ADP/landlord.ckpt'
args.landlord_up = 'baselines/resnet_landlord_up_%i.ckpt' % frame
args.landlord_down = 'baselines/resnet_landlord_down_%i.ckpt' % frame
elif t == 3:
args.landlord = 'baselines/resnet_landlord_%i.ckpt' % frame
args.landlord_up = 'baselines/resnet_landlord_up_%i.ckpt' % frame
args.landlord_down = 'baselines/resnet_landlord_down_%i.ckpt' % frame
elif t == 4:
args.landlord = 'baselines/douzero_ADP/landlord.ckpt'
args.landlord_up = 'baselines/douzero_ADP/landlord_up.ckpt'
args.landlord_down = 'baselines/douzero_ADP/landlord_down.ckpt'
elif t == 5:
args.landlord = 'baselines/douzero_WP/landlord.ckpt'
args.landlord_up = 'baselines/douzero_WP/landlord_up.ckpt'
args.landlord_down = 'baselines/douzero_WP/landlord_down.ckpt'
elif t == 6:
args.landlord = 'baselines/resnet_landlord_%i.ckpt' % frame
args.landlord_up = 'baselines/douzero_ADP/landlord_up_weights_%i.ckpt' % adp_frame
args.landlord_down = 'baselines/douzero_ADP/landlord_down_weights_%i.ckpt' % adp_frame
elif t == 7:
args.landlord = 'baselines/douzero_ADP/landlord_weights_%i.ckpt' % adp_frame
args.landlord_up = 'baselines/resnet_landlord_up_%i.ckpt' % frame
args.landlord_down = 'baselines/resnet_landlord_down_%i.ckpt' % frame
elif t == 8:
args.landlord = 'baselines/douzero_ADP/landlord_weights_%i.ckpt' % adp_frame
args.landlord_up = 'baselines/douzero_ADP/landlord_up_weights_%i.ckpt' % adp_frame
args.landlord_down = 'baselines/douzero_ADP/landlord_down_weights_%i.ckpt' % adp_frame
elif t == 9:
args.landlord = 'baselines/resnet_landlord_%i.ckpt' % frame
args.landlord_up = 'baselines/resnet_landlord_up_%i.ckpt' % adp_frame
args.landlord_down = 'baselines/resnet_landlord_down_%i.ckpt' % adp_frame
elif t == 10:
# landlord_down_weights_10777798400
args.landlord = 'baselines/douzero_ADP/landlord.ckpt'
args.landlord_up = 'baselines/douzero_ADP/landlord_up_weights_%i.ckpt' % adp_frame
args.landlord_down = 'baselines/douzero_ADP/landlord_down_weights_%i.ckpt' % adp_frame
elif t == 11:
args.landlord = 'baselines/douzero_ADP/landlord_weights_%i.ckpt' % adp_frame
args.landlord_up = 'baselines/douzero_ADP/landlord_up.ckpt'
args.landlord_down = 'baselines/douzero_ADP/landlord_down.ckpt'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
evaluate(args.landlord,
args.landlord_up,
args.landlord_down,
args.eval_data,
args.num_workers,
args.output,
args.bid,
args.title)
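
The same entry point can also be driven from Python directly. A hedged example using the positional signature of the call above (checkpoint paths are placeholders; the flag semantics live in douzero.evaluation.simulation, which is outside this hunk):

from douzero.evaluation.simulation import evaluate

evaluate('baselines/douzero_ADP/landlord.ckpt',       # landlord model
         'baselines/douzero_ADP/landlord_up.ckpt',    # landlord_up model
         'baselines/douzero_ADP/landlord_down.ckpt',  # landlord_down model
         'eval_data_1000.pkl',                        # pre-generated deals
         5,                                           # num_workers
         False,                                       # output (the script forces num_workers=1 when set)
         False,                                       # bid (likewise single-worker)
         'demo')                                      # title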

47
generate_eval_data.py Normal file
View File

@ -0,0 +1,47 @@
import argparse
import pickle
import numpy as np
deck = []
for i in range(3, 15):
    deck.extend([i for _ in range(4)])  # ranks 3 through A (env values 3-14), four copies each
deck.extend([17 for _ in range(4)])  # the four 2s (env value 17)
deck.extend([20, 30])  # black joker (20) and red joker (30)
def get_parser():
parser = argparse.ArgumentParser(description='DouZero: random data generator')
parser.add_argument('--output', default='eval_data', type=str)
parser.add_argument('--num_games', default=10000, type=int)
return parser
def generate():
_deck = deck.copy()
np.random.shuffle(_deck)
    card_play_data = {'landlord': _deck[:20],  # 17 dealt cards plus the 3 public cards
                      'landlord_up': _deck[20:37],  # each farmer gets 17 cards
                      'landlord_down': _deck[37:54],
                      'three_landlord_cards': _deck[17:20],  # the public cards, already part of the landlord's 20
                      }
for key in card_play_data:
card_play_data[key].sort()
return card_play_data
if __name__ == '__main__':
flags = get_parser().parse_args()
output_pickle = flags.output + '.pkl'
print("output_pickle:", output_pickle)
print("generating data...")
data = []
for _ in range(flags.num_games):
data.append(generate())
print("saving pickle file...")
    with open(output_pickle, 'wb') as g:
        pickle.dump(data, g, pickle.HIGHEST_PROTOCOL)
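
A quick sanity check of the generated file, assuming the default --output name (illustrative):

import pickle

with open('eval_data.pkl', 'rb') as f:
    games = pickle.load(f)
print(len(games))                        # num_games deals, 10000 by default
print(len(games[0]['landlord']))         # 20 cards: 17 dealt plus the 3 public cards
print(games[0]['three_landlord_cards'])  # 3 public cards in env encoding, e.g. [5, 9, 17]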

BIN
imgs/douzero_logo.jpg Normal file

Binary file not shown.

4
requirements.txt Normal file
View File

@ -0,0 +1,4 @@
torch>=1.6.0
GitPython
gitdb2
rlcard

33
setup.py Normal file
View File

@ -0,0 +1,33 @@
import setuptools
VERSION = '1.1.0'
with open("README.md", "r", encoding="utf8") as fh:
long_description = fh.read()
setuptools.setup(
name="douzero",
version=VERSION,
author="Daochen Zha",
author_email="daochen.zha@tamu.edu",
description="DouZero DouDizhu AI",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/kwai/DouZero",
license='Apache License 2.0',
    keywords=["DouDizhu", "AI", "Reinforcement Learning", "RL", "Torch", "Poker"],
packages=setuptools.find_packages(),
install_requires=[
'torch',
'rlcard'
],
    python_requires='>=3.6',
classifiers=[
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.6",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
)

8
train.py Normal file
View File

@ -0,0 +1,8 @@
import os
from douzero.dmc import parser, train
if __name__ == '__main__':
flags = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = flags.gpu_devices
train(flags)