DouZero_For_HLDDZ_FullAuto/douzero/dmc/utils.py

206 lines
7.6 KiB
Python
Raw Permalink Normal View History

2021-07-28 19:47:43 +08:00
import os
import typing
import logging
import traceback
import numpy as np
from collections import Counter
import time
import torch
from torch import multiprocessing as mp
from .env_utils import Environment
from douzero.env import Env
Card2Column = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7,
11: 8, 12: 9, 13: 10, 14: 11, 17: 12}
NumOnes2Array = {0: np.array([0, 0, 0, 0]),
1: np.array([1, 0, 0, 0]),
2: np.array([1, 1, 0, 0]),
3: np.array([1, 1, 1, 0]),
4: np.array([1, 1, 1, 1])}
shandle = logging.StreamHandler()
shandle.setFormatter(
logging.Formatter(
'[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
'%(message)s'))
log = logging.getLogger('doudzero')
2021-07-28 19:47:43 +08:00
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)
# Buffers are used to transfer data between actor processes
# and learner processes. They are shared tensors in GPU
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags):
return Env(flags.objective)
def get_batch(free_queue,
full_queue,
buffers,
flags,
lock):
"""
This function will sample a batch from the buffers based
on the indices received from the full queue. It will also
free the indices by sending it to full_queue.
"""
with lock:
indices = [full_queue.get() for _ in range(flags.batch_size)]
batch = {
key: torch.stack([buffers[key][m] for m in indices], dim=1)
for key in buffers
}
for m in indices:
free_queue.put(m)
return batch
def create_optimizers(flags, learner_model):
"""
Create three optimizers for the three positions
"""
positions = ['landlord', 'landlord_up', 'landlord_down']
optimizers = {}
for position in positions:
optimizer = torch.optim.RMSprop(
learner_model.parameters(position),
lr=flags.learning_rate,
momentum=flags.momentum,
eps=flags.epsilon,
alpha=flags.alpha)
optimizers[position] = optimizer
return optimizers
def create_buffers(flags):
"""
We create buffers for different positions as well as
for different devices (i.e., GPU). That is, each device
will have three buffers for the three positions.
"""
T = flags.unroll_length
positions = ['landlord', 'landlord_up', 'landlord_down']
buffers = []
for device in range(torch.cuda.device_count()):
buffers.append({})
for position in positions:
x_dim = 319 if position == 'landlord' else 430
specs = dict(
done=dict(size=(T,), dtype=torch.bool),
episode_return=dict(size=(T,), dtype=torch.float32),
target=dict(size=(T,), dtype=torch.float32),
obs_x_no_action=dict(size=(T, x_dim), dtype=torch.int8),
obs_action=dict(size=(T, 54), dtype=torch.int8),
obs_z=dict(size=(T, 5, 162), dtype=torch.int8),
)
_buffers: Buffers = {key: [] for key in specs}
for _ in range(flags.num_buffers):
for key in _buffers:
_buffer = torch.empty(**specs[key]).to(torch.device('cuda:'+str(device))).share_memory_()
_buffers[key].append(_buffer)
buffers[device][position] = _buffers
return buffers
def act(i, device, free_queue, full_queue, model, buffers, flags):
"""
This function will run forever until we stop it. It will generate
data from the environment and send the data to buffer. It uses
a free queue and full queue to syncup with the main process.
"""
positions = ['landlord', 'landlord_up', 'landlord_down']
try:
T = flags.unroll_length
log.info('Device %i Actor %i started.', device, i)
env = create_env(flags)
env = Environment(env, device)
done_buf = {p: [] for p in positions}
episode_return_buf = {p: [] for p in positions}
target_buf = {p: [] for p in positions}
obs_x_no_action_buf = {p: [] for p in positions}
obs_action_buf = {p: [] for p in positions}
obs_z_buf = {p: [] for p in positions}
size = {p: 0 for p in positions}
position, obs, env_output = env.initial()
while True:
while True:
obs_x_no_action_buf[position].append(env_output['obs_x_no_action'])
obs_z_buf[position].append(env_output['obs_z'])
with torch.no_grad():
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
_action_idx = int(agent_output['action'].cpu().detach().numpy())
action = obs['legal_actions'][_action_idx]
obs_action_buf[position].append(_cards2tensor(action))
position, obs, env_output = env.step(action)
size[position] += 1
if env_output['done']:
for p in positions:
diff = size[p] - len(target_buf[p])
if diff > 0:
done_buf[p].extend([False for _ in range(diff-1)])
done_buf[p].append(True)
episode_return = env_output['episode_return'] if p == 'landlord' else -env_output['episode_return']
episode_return_buf[p].extend([0.0 for _ in range(diff-1)])
episode_return_buf[p].append(episode_return)
target_buf[p].extend([episode_return for _ in range(diff)])
break
for p in positions:
if size[p] > T:
index = free_queue[p].get()
if index is None:
break
for t in range(T):
buffers[p]['done'][index][t, ...] = done_buf[p][t]
buffers[p]['episode_return'][index][t, ...] = episode_return_buf[p][t]
buffers[p]['target'][index][t, ...] = target_buf[p][t]
buffers[p]['obs_x_no_action'][index][t, ...] = obs_x_no_action_buf[p][t]
buffers[p]['obs_action'][index][t, ...] = obs_action_buf[p][t]
buffers[p]['obs_z'][index][t, ...] = obs_z_buf[p][t]
full_queue[p].put(index)
done_buf[p] = done_buf[p][T:]
episode_return_buf[p] = episode_return_buf[p][T:]
target_buf[p] = target_buf[p][T:]
obs_x_no_action_buf[p] = obs_x_no_action_buf[p][T:]
obs_action_buf[p] = obs_action_buf[p][T:]
obs_z_buf[p] = obs_z_buf[p][T:]
size[p] -= T
except KeyboardInterrupt:
pass
except Exception as e:
log.error('Exception in worker process %i', i)
traceback.print_exc()
print()
raise e
def _cards2tensor(list_cards):
"""
Convert a list of integers to the tensor
representation
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
"""
if len(list_cards) == 0:
return torch.zeros(54, dtype=torch.int8)
matrix = np.zeros([4, 13], dtype=np.int8)
jokers = np.zeros(2, dtype=np.int8)
counter = Counter(list_cards)
for card, num_times in counter.items():
if card < 20:
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
elif card == 20:
jokers[0] = 1
elif card == 30:
jokers[1] = 1
matrix = np.concatenate((matrix.flatten('F'), jokers))
matrix = torch.from_numpy(matrix)
return matrix