324 lines
14 KiB
Python
324 lines
14 KiB
Python
import os
|
|
import queue
|
|
import threading
|
|
import typing
|
|
import logging
|
|
import traceback
|
|
import numpy as np
|
|
from collections import Counter
|
|
import time
|
|
from douzero.radam.radam import RAdam
|
|
|
|
import torch
|
|
from torch import multiprocessing as mp
|
|
|
|
from .env_utils import Environment
|
|
from douzero.env import Env
|
|
|
|
# Column of each non-joker card value in the matrix built by _cards2tensor:
# values 3..14 map to columns 0..11, and value 17 maps to column 12.
Card2Column = {rank: rank - 3 for rank in range(3, 15)}
Card2Column[17] = 12
|
|
|
|
# Column encoding for the 8-row card matrix: entry n (n = 0..8) is a
# length-8 vector whose first n cells are 1 and whose remaining cells are 0.
NumOnes2Array = {
    count: np.array([1] * count + [0] * (8 - count))
    for count in range(9)
}
|
|
|
|
# Joker column encoding for the 8-row card matrix.  Keys are the bitmasks
# assembled in _cards2tensor (bits 0-1 come from card value 20, bits 2-3 from
# card value 30); cell i of each vector is bit i of its key.
NumOnesJoker2Array = {
    mask: np.array([(mask >> bit) & 1 for bit in range(8)])
    for mask in (0, 1, 3, 4, 5, 7, 12, 13, 15)
}
|
|
|
|
# Column encoding for the 5-row (compressed) card matrix.  A count n = 0..8 is
# split as n = 4 * high + low: the first four cells hold `low` ones followed
# by zeros, and the fifth cell holds `high`.  (high=1, low=0) is skipped
# because the count 4 is already represented as (high=0, low=4).
NumOnes2ArrayCompressed = {
    4 * high + low: np.array([1] * low + [0] * (4 - low) + [high])
    for high in (0, 1)
    for low in range(5)
    if high == 0 or low > 0
}
|
|
|
|
# Joker column encoding for the 5-row (compressed) card matrix.  Keys are the
# bitmasks assembled in _cards2tensor; cell i of each vector is bit i of its
# key, so the fifth cell is always 0 for the masks listed here.
NumOnesJoker2ArrayCompressed = {
    mask: np.array([(mask >> bit) & 1 for bit in range(5)])
    for mask in (0, 1, 3, 4, 5, 7, 12, 13, 15)
}
|
|
|
|
# Module logger: INFO level, detached from the root logger, writing through a
# single stream handler that prefixes each record with level, process id,
# module, line number and timestamp.
log = logging.getLogger('doudzero')
log.setLevel(logging.INFO)
log.propagate = False

shandle = logging.StreamHandler()
_record_format = logging.Formatter(
    '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
    '%(message)s')
shandle.setFormatter(_record_format)
log.addHandler(shandle)
|
|
|
|
# Type alias for the buffers used to transfer data between actor processes
# and learner processes: a mapping from field name to a list of tensors.
# NOTE(review): the original comment describes these as shared tensors on
# the GPU; the alias is not referenced in the visible part of this module,
# so confirm the device placement against the code that uses it.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
|
|
|
def create_env(flags):
    """
    Build an ``Env`` from the run flags.

    The flags ``objective``, ``old_model``, ``lagacy_model``, ``lite_model``
    and ``unified_model`` are forwarded positionally, in that order.
    """
    env_args = (
        flags.objective,
        flags.old_model,
        flags.lagacy_model,
        flags.lite_model,
        flags.unified_model,
    )
    return Env(*env_args)
|
|
|
|
def get_batch(b_queues, position, flags, lock):
    """
    Assemble one training batch for ``position``.

    Blocks on ``b_queues[position]`` until ``flags.batch_size`` items have
    been received, then stacks every field of those items along dim 1.

    Args:
        b_queues: mapping from position to a queue whose items are dicts of
            tensors keyed by field name.
        position: key selecting which queue to read from.
        flags: must provide ``batch_size`` and ``old_model``.  When
            ``old_model`` is truthy the batch additionally carries the
            ``"obs_x_no_action"`` field.
        lock: unused; kept so existing callers keep working.

    Returns:
        dict mapping each field name to ``torch.stack(..., dim=1)`` of that
        field over the received items.
    """
    b_queue = b_queues[position]
    items = []
    while len(items) < flags.batch_size:
        items.append(b_queue.get())

    # Both model variants share every field except "obs_x_no_action"; the key
    # order matches what the learner has always received.
    keys = ["done", "episode_return", "target", "obs_z", "obs_x_batch"]
    if flags.old_model:
        keys.append("obs_x_no_action")
    keys.append("obs_type")

    return {
        key: torch.stack([item[key] for item in items], dim=1)
        for key in keys
    }
|
|
|
|
def create_optimizers(flags, learner_model):
    """
    Build the RAdam optimizer(s) for the learner model.

    With ``flags.unified_model`` set, a single optimizer is created under the
    key ``'uni'``; otherwise one optimizer is created for each of the four
    positions.  Every optimizer is constructed from
    ``learner_model.parameters(<key>)`` with ``lr=flags.learning_rate`` and
    ``eps=flags.epsilon``.

    Returns:
        dict mapping each key to its optimizer.
    """
    if flags.unified_model:
        names = ['uni']
    else:
        names = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    return {
        name: RAdam(
            learner_model.parameters(name),
            lr=flags.learning_rate,
            eps=flags.epsilon)
        for name in names
    }
|
|
|
|
|
|
def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
    """
    Inference server loop: answer actor requests forever.

    Each entry of ``infer_queues`` is a dict with an ``'input'`` and an
    ``'output'`` queue.  The loop polls every input queue without blocking,
    runs ``model.forward`` on each task found and puts
    ``{'values': result['values']}`` on the matching output queue.  When a
    full pass finds no task it sleeps for 10 ms.

    Args:
        i: index of this inference worker, used for logging only.
        device: ``"cpu"``, or a value that is turned into ``"cuda:<device>"``.
        infer_queues: iterable of ``{'input': queue, 'output': queue}`` dicts.
        model: object providing ``forward`` and, depending on the flags,
            ``set_onnx_model``, ``model`` or ``models``.
        flags: must provide ``enable_onnx`` and ``unified_model``; it is also
            forwarded to ``model.forward``.
        onnx_frame: object with a ``value`` attribute; only read when
            ``flags.enable_onnx`` is set.  A change of ``value`` triggers
            ``model.set_onnx_model(device)``.

    This function never returns.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    device = device if device == "cpu" else ("cuda:" + str(device))
    if not flags.enable_onnx:
        target_device = torch.device(device)
        if flags.unified_model:
            model.model.to(target_device)
        else:
            for pos in positions:
                model.models[pos].to(target_device)
    last_onnx_frame = -1
    log.info('Infer %i started.', i)

    while True:
        if flags.enable_onnx:
            # Read the counter once so the value compared and the value
            # remembered are the same.
            current_frame = onnx_frame.value
            if current_frame != last_onnx_frame:
                last_onnx_frame = current_frame
                model.set_onnx_model(device)

        all_empty = True
        for infer_queue in infer_queues:
            # Only the non-blocking get may legitimately raise queue.Empty;
            # keeping the forward pass outside the try stops an unrelated
            # queue.Empty from being swallowed silently.
            try:
                task = infer_queue['input'].get_nowait()
            except queue.Empty:
                continue
            with torch.no_grad():
                result = model.forward(task['position'], task['z_batch'], task['x_batch'], device=device, return_value=True, flags=flags)
            infer_queue['output'].put({
                'values': result['values']
            })
            all_empty = False

        if all_empty:
            # Nothing to do on this pass: back off briefly instead of spinning.
            time.sleep(0.01)
|
|
|
|
def act_queue(i, infer_queue, batch_queues, flags):
    """
    Actor loop: play games forever and push fixed-length rollouts to the
    per-position batch queues.

    For every decision with more than one legal action, the observation is
    sent to ``infer_queue['input']`` and the reply is read from
    ``infer_queue['output']``; the action with the highest returned value is
    played.  Transitions are buffered per position.  When a game ends, the
    done flags, episode returns and targets for the new transitions are
    filled in.  After each game, every position whose ``size`` exceeds
    ``flags.unroll_length`` has its oldest ``unroll_length`` transitions
    stacked into tensors and put on ``batch_queues[position]``.

    Args:
        i: worker index, used for logging only.
        infer_queue: dict with ``'input'`` and ``'output'`` queues.
        batch_queues: mapping from position to the queue receiving rollouts.
        flags: must provide ``unroll_length``, ``old_model`` and
            ``lite_model``; it is also forwarded to ``create_env`` and to the
            environment calls.

    ``KeyboardInterrupt`` ends the loop quietly.  Any other exception is
    logged, its traceback printed, and re-raised.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    try:
        T = flags.unroll_length
        log.info('Actor %i started.', i)

        env = create_env(flags)
        device = 'cpu'
        env = Environment(env, device)

        # Per-position transition buffers.  `size` counts transitions per
        # position and drives both the end-of-game fill and the flush below.
        done_buf = {p: [] for p in positions}
        episode_return_buf = {p: [] for p in positions}
        target_buf = {p: [] for p in positions}
        obs_z_buf = {p: [] for p in positions}
        size = {p: 0 for p in positions}
        type_buf = {p: [] for p in positions}
        obs_x_batch_buf = {p: [] for p in positions}
        if flags.old_model:
            # The old model keeps the action encoding and the action-free
            # features in separate buffers.
            obs_action_buf = {p: [] for p in positions}
            obs_x_no_action = {p: [] for p in positions}

        # Integer tag stored in type_buf (sent as "obs_type") for each position.
        position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}

        position, obs, env_output = env.initial(flags=flags)
        while True:
            # Inner loop: one iteration per move, until the game is done.
            while True:
                if len(obs['legal_actions']) > 1:
                    infer_queue['input'].put({
                        'position': position,
                        'z_batch': obs['z_batch'],
                        'x_batch': obs['x_batch']
                    })
                    result = infer_queue['output'].get()
                    # Index of the largest value along axis 0, first column.
                    # assumes result['values'] has one row per legal action — TODO confirm
                    action = np.argmax(result['values'], axis=0)[0]
                    _action_idx = int(action)
                    action = obs['legal_actions'][_action_idx]
                else:
                    # Single legal action: no inference round-trip needed.
                    action = obs['legal_actions'][0]
                if flags.old_model:
                    obs_action_buf[position].append(_cards2tensor(action, flags.lite_model))
                    obs_x_no_action[position].append(env_output['obs_x_no_action'])
                    obs_z_buf[position].append(env_output['obs_z'])
                else:
                    # The encoded action is stacked on top of obs_z as an extra first row.
                    obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float())
                    x_batch = env_output['obs_x_no_action'].float()
                    obs_x_batch_buf[position].append(x_batch)
                type_buf[position].append(position_index[position])
                position, obs, env_output = env.step(action, flags=flags)
                # NOTE(review): `position` has just been reassigned by
                # env.step, so this increments the counter of the returned
                # position rather than of the position whose transition was
                # buffered above — confirm this accounting is intended.
                size[position] += 1
                if env_output['done']:
                    for p in positions:
                        # Transitions counted for p that have no target yet.
                        diff = size[p] - len(target_buf[p])
                        if diff > 0:
                            # Only the last new transition is flagged done and
                            # carries the return; the earlier ones get
                            # False / 0.0.  Every new transition gets the
                            # return as its target.
                            done_buf[p].extend([False for _ in range(diff-1)])
                            done_buf[p].append(True)
                            # The return is negated for every position other than 'landlord'.
                            episode_return = env_output['episode_return']["play"][p] if p == 'landlord' else -env_output['episode_return']["play"][p]
                            episode_return_buf[p].extend([0.0 for _ in range(diff-1)])
                            episode_return_buf[p].append(episode_return)
                            target_buf[p].extend([episode_return for _ in range(diff)])
                    break
            # Flush: emit the oldest T transitions per position while more
            # than T are counted.
            for p in positions:
                while size[p] > T:
                    if flags.old_model:
                        # In old-model mode "obs_x_batch" is filled from the
                        # action buffer.
                        batch_queues[p].put({
                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_action_buf[p][:T]]),
                            "obs_x_no_action": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_no_action[p][:T]]),
                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
                        })
                        obs_action_buf[p] = obs_action_buf[p][T:]
                        obs_x_no_action[p] = obs_x_no_action[p][T:]
                    else:
                        batch_queues[p].put({
                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
                        })
                        obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
                    # Drop the emitted prefix from the buffers shared by both modes.
                    done_buf[p] = done_buf[p][T:]
                    episode_return_buf[p] = episode_return_buf[p][T:]
                    target_buf[p] = target_buf[p][T:]
                    obs_z_buf[p] = obs_z_buf[p][T:]
                    type_buf[p] = type_buf[p][T:]
                    size[p] -= T

    except KeyboardInterrupt:
        pass
    except Exception as e:
        log.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
|
|
|
|
def act(i, infer_queues, batch_queues, flags):
    """
    Run one ``act_queue`` worker thread per inference queue and wait for them.

    Args:
        i: index of this actor process; only used in the thread names.
        infer_queues: indexable collection of inference queue pairs, one per
            worker thread.
        batch_queues: passed through unchanged to every ``act_queue`` worker.
        flags: passed through unchanged to every ``act_queue`` worker.

    Each worker receives its own thread index ``x`` (not ``i``) as its first
    argument.  The call returns once every worker thread has finished.
    """
    threads = []
    for x in range(len(infer_queues)):
        # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
        thread = threading.Thread(
            target=act_queue, name='act_queue-%d-%d' % (i, x),
            args=(x, infer_queues[x], batch_queues, flags),
            daemon=True)
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
|
|
|
|
def _cards2tensor(list_cards, compress_form = False):
|
|
"""
|
|
Convert a list of integers to the tensor
|
|
representation
|
|
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
|
|
"""
|
|
if compress_form:
|
|
if len(list_cards) == 0:
|
|
return torch.zeros(69, dtype=torch.int8)
|
|
|
|
matrix = np.zeros([5, 14], dtype=np.int8)
|
|
counter = Counter(list_cards)
|
|
joker_cnt = 0
|
|
for card, num_times in counter.items():
|
|
if card < 20:
|
|
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
|
|
elif card == 20:
|
|
if num_times == 2:
|
|
joker_cnt |= 0b11
|
|
else:
|
|
joker_cnt |= 0b01
|
|
elif card == 30:
|
|
if num_times == 2:
|
|
joker_cnt |= 0b1100
|
|
else:
|
|
joker_cnt |= 0b0100
|
|
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
|
|
matrix = matrix.flatten('F')[:-1]
|
|
matrix = torch.from_numpy(matrix)
|
|
return matrix
|
|
else:
|
|
if len(list_cards) == 0:
|
|
return torch.zeros(108, dtype=torch.int8)
|
|
|
|
matrix = np.zeros([8, 14], dtype=np.int8)
|
|
counter = Counter(list_cards)
|
|
joker_cnt = 0
|
|
for card, num_times in counter.items():
|
|
if card < 20:
|
|
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
|
elif card == 20:
|
|
if num_times == 2:
|
|
joker_cnt |= 0b11
|
|
else:
|
|
joker_cnt |= 0b01
|
|
elif card == 30:
|
|
if num_times == 2:
|
|
joker_cnt |= 0b1100
|
|
else:
|
|
joker_cnt |= 0b0100
|
|
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
|
|
matrix = matrix.flatten('F')[:-4]
|
|
matrix = torch.from_numpy(matrix)
|
|
return matrix
|