206 lines
7.6 KiB
Python
206 lines
7.6 KiB
Python
import os
|
|
import typing
|
|
import logging
|
|
import traceback
|
|
import numpy as np
|
|
from collections import Counter
|
|
import time
|
|
|
|
import torch
|
|
from torch import multiprocessing as mp
|
|
|
|
from .env_utils import Environment
|
|
from douzero.env import Env
|
|
|
|
# Column index in the 4x13 card matrix for each of the 13 non-joker card
# values; the jokers (values 20 and 30) are encoded separately by
# _cards2tensor.
Card2Column = {
    card_value: column
    for column, card_value in enumerate(
        [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17])
}
|
|
|
|
# Column pattern for a card value held `n` times (0..4): the first `n` of the
# 4 entries are 1, the rest 0.
NumOnes2Array = {
    n: np.array([1] * n + [0] * (4 - n))
    for n in range(5)
}
|
|
|
|
# Module logger: INFO level, not propagated to the root logger, with a single
# stream handler whose format includes level, pid, module and line number.
log = logging.getLogger('doudzero')
log.setLevel(logging.INFO)
log.propagate = False

shandle = logging.StreamHandler()
shandle.setFormatter(logging.Formatter(
    '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
    '%(message)s'))
log.addHandler(shandle)

# Buffers move rollout data from actor processes to learner processes.
# Each key maps to a list of tensors (one per buffer slot); create_buffers()
# places them on CUDA devices.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
|
|
|
def create_env(flags):
    """Return a new ``Env`` built for ``flags.objective``."""
    objective = flags.objective
    return Env(objective)
|
|
|
|
def get_batch(free_queue,
              full_queue,
              buffers,
              flags,
              lock):
    """
    Sample a batch from the buffers.

    ``flags.batch_size`` slot indices are taken from ``full_queue``. For every
    key in ``buffers``, the tensors stored at those indices are stacked along
    dim 1. The indices are then returned to ``free_queue`` so the slots can be
    refilled.

    Args:
        free_queue: queue that receives the consumed slot indices.
        full_queue: queue of indices of slots that hold data ready to train on.
        buffers: mapping from key to a list of tensors, indexed by slot.
        flags: configuration; only ``batch_size`` is read.
        lock: held while the indices are pulled, so one caller's
            ``batch_size`` gets are not interleaved with another caller's.

    Returns:
        dict mapping each key of ``buffers`` to the stacked tensor. For buffers
        from create_buffers(), whose leading dim is the unroll length T, each
        entry has shape (T, batch_size, ...).
    """
    with lock:
        indices = [full_queue.get() for _ in range(flags.batch_size)]
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }
    # torch.stack copies the data, so the slots can be released for reuse now.
    for m in indices:
        free_queue.put(m)
    return batch
|
|
|
|
def create_optimizers(flags, learner_model):
    """
    Build one RMSprop optimizer per position.

    Each optimizer updates only ``learner_model.parameters(position)``. All
    three share the hyper-parameters read from ``flags``: ``learning_rate``,
    ``momentum``, ``epsilon`` and ``alpha``.

    Returns:
        dict mapping 'landlord', 'landlord_up' and 'landlord_down' (in that
        order) to their ``torch.optim.RMSprop`` instance.
    """
    rmsprop_kwargs = dict(lr=flags.learning_rate,
                          momentum=flags.momentum,
                          eps=flags.epsilon,
                          alpha=flags.alpha)
    return {
        position: torch.optim.RMSprop(learner_model.parameters(position),
                                      **rmsprop_kwargs)
        for position in ('landlord', 'landlord_up', 'landlord_down')
    }
|
|
|
|
def create_buffers(flags):
    """
    Allocate the rollout buffers: one set per visible CUDA device, and within
    each set, separate buffers for each of the three positions.

    Args:
        flags: configuration; reads ``unroll_length`` (T) and ``num_buffers``.

    Returns:
        list indexed by CUDA device ordinal. Each element maps a position name
        to a ``Buffers`` dict (key -> list of ``num_buffers`` uninitialized
        tensors of the shapes/dtypes in ``specs``). The list is empty when
        ``torch.cuda.device_count()`` is 0.
    """
    T = flags.unroll_length
    positions = ['landlord', 'landlord_up', 'landlord_down']
    buffers = []
    for device in range(torch.cuda.device_count()):
        cuda_device = torch.device('cuda:' + str(device))
        buffers.append({})
        for position in positions:
            # The landlord's observation vector is shorter than the farmers'.
            x_dim = 319 if position == 'landlord' else 430
            specs = dict(
                done=dict(size=(T,), dtype=torch.bool),
                episode_return=dict(size=(T,), dtype=torch.float32),
                target=dict(size=(T,), dtype=torch.float32),
                obs_x_no_action=dict(size=(T, x_dim), dtype=torch.int8),
                # 54 = length of the _cards2tensor() encoding of an action.
                obs_action=dict(size=(T, 54), dtype=torch.int8),
                obs_z=dict(size=(T, 5, 162), dtype=torch.int8),
            )
            _buffers: Buffers = {key: [] for key in specs}
            for _ in range(flags.num_buffers):
                for key in _buffers:
                    # Allocate directly on the target device. The previous
                    # version allocated on the CPU and then copied with .to().
                    _buffer = torch.empty(**specs[key], device=cuda_device).share_memory_()
                    _buffers[key].append(_buffer)
            buffers[device][position] = _buffers
    return buffers
|
|
|
|
def act(i, device, free_queue, full_queue, model, buffers, flags):
    """
    Actor loop: play games forever and publish trajectories to the buffers.

    At each step, the acting position's observation is recorded, the model
    picks an index into ``obs['legal_actions']``, and the environment
    advances. When a game ends, each position's steps from that game receive:
      - ``done`` flags, True only on the last step;
      - an ``episode_return`` stored on the last step only, negated for
        positions other than 'landlord';
      - a ``target`` equal to that return on every step.
    Whenever a position holds more than ``flags.unroll_length`` (T)
    unpublished steps, a slot index is taken from ``free_queue[position]``.
    The first T steps are copied into ``buffers[position][key][index]``, and
    the index is pushed to ``full_queue[position]``.

    Args:
        i: actor id, used for logging.
        device: device ordinal (formatted with %i) passed to ``Environment``.
        free_queue: dict of per-position queues of free slot indices; a
            ``None`` index stops publishing for that position this round.
        full_queue: dict of per-position queues that receive filled slot
            indices.
        model: object whose ``forward(position, z_batch, x_batch, flags=...)``
            returns a dict with a single-element ``'action'`` tensor.
        buffers: per-position buffer dicts, presumably one device's entry of
            create_buffers() -- confirm with caller.
        flags: configuration; ``unroll_length`` is read here.

    Returns only on KeyboardInterrupt; any other exception is logged and
    re-raised.
    """
    positions = ['landlord', 'landlord_up', 'landlord_down']
    try:
        T = flags.unroll_length
        log.info('Device %i Actor %i started.', device, i)

        env = create_env(flags)
        env = Environment(env, device)

        # Per-position staging lists; entries at the same index describe the
        # same step of that position.
        done_buf = {p: [] for p in positions}
        episode_return_buf = {p: [] for p in positions}
        target_buf = {p: [] for p in positions}
        obs_x_no_action_buf = {p: [] for p in positions}
        obs_action_buf = {p: [] for p in positions}
        obs_z_buf = {p: [] for p in positions}
        # Number of recorded-but-unpublished steps per position.
        size = {p: 0 for p in positions}

        position, obs, env_output = env.initial()

        while True:
            # Play one complete game.
            while True:
                obs_x_no_action_buf[position].append(env_output['obs_x_no_action'])
                obs_z_buf[position].append(env_output['obs_z'])
                with torch.no_grad():
                    agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
                # .item() works for any single-element tensor (CPU or CUDA) and
                # avoids NumPy's deprecated ndim>0 array-to-scalar conversion.
                _action_idx = int(agent_output['action'].item())
                action = obs['legal_actions'][_action_idx]
                obs_action_buf[position].append(_cards2tensor(action))
                # Count the step for the position that just acted, BEFORE
                # env.step() rebinds `position` to the next player.
                size[position] += 1
                position, obs, env_output = env.step(action)
                if env_output['done']:
                    for p in positions:
                        # Steps this position took in the game that just ended.
                        diff = size[p] - len(target_buf[p])
                        if diff > 0:
                            done_buf[p].extend([False for _ in range(diff - 1)])
                            done_buf[p].append(True)

                            episode_return = env_output['episode_return'] if p == 'landlord' else -env_output['episode_return']
                            # Only the last step of the game carries the return...
                            episode_return_buf[p].extend([0.0 for _ in range(diff - 1)])
                            episode_return_buf[p].append(episode_return)
                            # ...but every step gets it as its target.
                            target_buf[p].extend([episode_return for _ in range(diff)])
                    break

            # Publish complete T-step chunks. `while` (not `if`) drains any
            # backlog when a position took more than T steps in one game.
            for p in positions:
                while size[p] > T:
                    index = free_queue[p].get()
                    if index is None:
                        break
                    for t in range(T):
                        buffers[p]['done'][index][t, ...] = done_buf[p][t]
                        buffers[p]['episode_return'][index][t, ...] = episode_return_buf[p][t]
                        buffers[p]['target'][index][t, ...] = target_buf[p][t]
                        buffers[p]['obs_x_no_action'][index][t, ...] = obs_x_no_action_buf[p][t]
                        buffers[p]['obs_action'][index][t, ...] = obs_action_buf[p][t]
                        buffers[p]['obs_z'][index][t, ...] = obs_z_buf[p][t]
                    full_queue[p].put(index)
                    done_buf[p] = done_buf[p][T:]
                    episode_return_buf[p] = episode_return_buf[p][T:]
                    target_buf[p] = target_buf[p][T:]
                    obs_x_no_action_buf[p] = obs_x_no_action_buf[p][T:]
                    obs_action_buf[p] = obs_action_buf[p][T:]
                    obs_z_buf[p] = obs_z_buf[p][T:]
                    size[p] -= T

    except KeyboardInterrupt:
        pass
    except Exception:
        log.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise
|
|
|
|
def _cards2tensor(list_cards):
    """
    Encode a list of card values as a flat 54-element int8 tensor.

    The first 52 entries are a 4x13 matrix flattened column by column, so the
    4 entries for each card value are contiguous. The column for a card value
    (see Card2Column) has its first `count` rows set to 1, taken from
    NumOnes2Array, which covers 0-4 copies. The last 2 entries flag the cards
    20 and 30 (the jokers). Other values >= 20 are ignored.
    See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
    """
    if not len(list_cards):
        return torch.zeros(54, dtype=torch.int8)

    rank_matrix = np.zeros((4, 13), dtype=np.int8)
    joker_flags = np.zeros(2, dtype=np.int8)
    for card_value, copies in Counter(list_cards).items():
        if card_value == 20:
            joker_flags[0] = 1
        elif card_value == 30:
            joker_flags[1] = 1
        elif card_value < 20:
            rank_matrix[:, Card2Column[card_value]] = NumOnes2Array[copies]

    # Column-major ('F') flattening keeps each card value's 4 bits together.
    encoded = np.concatenate((rank_matrix.flatten('F'), joker_flags))
    return torch.from_numpy(encoded)
|