# Douzero_Resnet/douzero/dmc/utils.py
# (Python source; viewer metadata at time of copy: 315 lines, 14 KiB)

import os
import queue
import threading
import typing
import logging
import traceback
import numpy as np
from collections import Counter
import time
from douzero.radam.radam import RAdam
import torch
from torch import multiprocessing as mp
from .env_utils import Environment
from douzero.env import Env
# Card code -> column index of the card matrix built by _cards2tensor.
# Codes 3..14 map to columns 0..11 and code 17 maps to column 12; the joker
# codes (20 and 30) are packed separately into column 13.
Card2Column = {code: code - 3 for code in range(3, 15)}
Card2Column[17] = 12

# Count of one card code (0..8) -> 8-slot code with the first `count` slots set.
NumOnes2Array = {
    count: np.array([1] * count + [0] * (8 - count))
    for count in range(9)
}

# Joker bitmask -> 8-slot code where slot k is bit k of the mask (slots 4..7
# stay zero). _cards2tensor ORs 0b01 / 0b11 for code 20 (one / two copies)
# and 0b0100 / 0b1100 for code 30, so these keys are every reachable mask.
NumOnesJoker2Array = {
    mask: np.array([(mask >> bit) & 1 for bit in range(4)] + [0] * 4)
    for mask in (0, 1, 3, 4, 5, 7, 12, 13, 15)
}

# Compressed 5-slot count code: the first four slots hold the same prefix-of-
# ones code for `count` (or `count - 4` when count > 4), and the fifth slot
# is set exactly when count > 4.
NumOnes2ArrayCompressed = {
    count: np.array(
        NumOnes2Array[count - 4 if count > 4 else count][:4].tolist()
        + [1 if count > 4 else 0]
    )
    for count in range(9)
}

# Compressed joker code: the four mask bits followed by one unused zero slot.
NumOnesJoker2ArrayCompressed = {
    mask: np.array([(mask >> bit) & 1 for bit in range(4)] + [0])
    for mask in (0, 1, 3, 4, 5, 7, 12, 13, 15)
}
# Module logger: INFO level, not propagated to the root logger, writing to
# stderr through a single stream handler with process/module/line context.
log = logging.getLogger('doudzero')
log.setLevel(logging.INFO)
log.propagate = False

shandle = logging.StreamHandler()
shandle.setFormatter(logging.Formatter(
    '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s'))
log.addHandler(shandle)

# Type alias for actor/learner transfer buffers: buffer name -> list of tensors.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags):
    """Build a new ``Env`` configured from the run flags.

    The flags are forwarded positionally, in this order: ``objective``,
    ``old_model``, ``lagacy_model``, ``lite_model``, ``unified_model``.
    """
    env_args = (
        flags.objective,
        flags.old_model,
        flags.lagacy_model,
        flags.lite_model,
        flags.unified_model,
    )
    return Env(*env_args)
def get_batch(b_queues, position, flags, lock):
    """Collect one training batch for ``position`` from its queue.

    Blocks until ``flags.batch_size`` items have been taken from
    ``b_queues[position]``, then stacks each field across the items along
    ``dim=1``, so the items form the second dimension of every output tensor.

    Args:
        b_queues: mapping from position name to a queue whose items are dicts
            of tensors (presumably the trajectories put by ``act_queue``).
        position: key selecting the queue in ``b_queues``.
        flags: must provide ``batch_size`` and ``old_model``; when
            ``old_model`` is true the ``obs_x_no_action`` field is included.
        lock: unused; kept so existing callers keep working.

    Returns:
        dict mapping field name -> stacked tensor.
    """
    keys = ["done", "episode_return", "target", "obs_z", "obs_x_batch"]
    if flags.old_model:
        keys.append("obs_x_no_action")
    keys.append("obs_type")

    b_queue = b_queues[position]
    # queue.get() blocks, so this waits until a full batch is available.
    items = [b_queue.get() for _ in range(flags.batch_size)]
    return {key: torch.stack([item[key] for item in items], dim=1) for key in keys}
def create_optimizers(flags, learner_model):
    """Create one RAdam optimizer for each of the four positions.

    The positions are ``landlord``, ``landlord_up``, ``landlord_front`` and
    ``landlord_down``. Each optimizer covers
    ``learner_model.parameters(position)`` and uses ``flags.learning_rate``
    and ``flags.epsilon``.

    Returns:
        dict mapping position name -> optimizer.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    return {
        position: RAdam(
            learner_model.parameters(position),
            lr=flags.learning_rate,
            eps=flags.epsilon)
        for position in positions
    }
def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
    """Serve inference requests from actor threads; never returns.

    Each pass polls every queue in ``infer_queues`` without blocking. For
    each pending task it runs ``model.forward`` under ``torch.no_grad()`` and
    puts ``{'values': ...}`` on that queue's ``'output'`` side. If a whole
    pass finds no work, it sleeps for 10 ms.

    Args:
        i: inference worker index, used only for logging.
        device: ``"cpu"`` or a CUDA device index; non-CPU values become
            ``"cuda:<device>"`` when moving torch models.
        infer_queues: iterable of dicts with ``'input'`` and ``'output'``
            queues.
        model: wrapper providing ``forward``. Without ONNX it also provides
            ``model`` (unified) or ``models[position]``; with ONNX it provides
            ``set_onnx_model``.
        flags: run configuration (``enable_onnx``, ``unified_model``, ...).
        onnx_frame: shared object whose ``.value`` is polled each pass. When
            it changes, ``model.set_onnx_model(device)`` is called again.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    if not flags.enable_onnx:
        torch_device = torch.device(device if device == "cpu" else "cuda:" + str(device))
        if flags.unified_model:
            model.model.to(torch_device)
        else:
            for pos in positions:
                model.models[pos].to(torch_device)
    last_onnx_frame = -1
    log.info('Infer %i started.', i)
    while True:
        if flags.enable_onnx:
            # Read the shared value once so the comparison and the stored
            # value cannot disagree.
            current_frame = onnx_frame.value
            if current_frame != last_onnx_frame:
                last_onnx_frame = current_frame
                model.set_onnx_model(device)
        all_empty = True
        for infer_queue in infer_queues:
            # Only the non-blocking get may legitimately raise queue.Empty;
            # errors from forward()/put() must not be mistaken for "no task".
            try:
                task = infer_queue['input'].get_nowait()
            except queue.Empty:
                continue
            with torch.no_grad():
                result = model.forward(task['position'], task['z_batch'], task['x_batch'],
                                       device=device, return_value=True, flags=flags)
            infer_queue['output'].put({
                'values': result['values']
            })
            all_empty = False
        if all_empty:
            time.sleep(0.01)
def act_queue(i, infer_queue, batch_queues, flags):
    """Actor loop: play games forever and emit fixed-length trajectory chunks.

    Runs one ``Environment`` on the CPU. When a decision has more than one
    legal action, it sends the observation to ``infer_queue['input']`` and
    plays the argmax of the returned ``values``; a single legal action is
    played directly. Per-position step data is buffered. Whenever a position
    has more than ``flags.unroll_length`` (T) counted steps, the first T are
    stacked into a dict of tensors and put on ``batch_queues[position]``.

    Args:
        i: actor index, used only in log messages.
        infer_queue: dict with ``'input'`` / ``'output'`` queues shared with an
            inference worker.
        batch_queues: mapping position name -> queue receiving trajectory
            chunks.
        flags: run configuration (``unroll_length``, ``old_model``,
            ``lite_model``, ...).

    Raises:
        Any exception other than KeyboardInterrupt, after logging it and
        printing the traceback. KeyboardInterrupt ends the loop quietly.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    try:
        T = flags.unroll_length
        log.info('Actor %i started.', i)
        env = create_env(flags)
        device = 'cpu'
        env = Environment(env, device)
        # Per-position buffers. Entries are appended per step and the first T
        # are sliced off together when a chunk is flushed below.
        done_buf = {p: [] for p in positions}
        episode_return_buf = {p: [] for p in positions}
        target_buf = {p: [] for p in positions}
        obs_z_buf = {p: [] for p in positions}
        # Steps counted per position that have not been flushed yet.
        size = {p: 0 for p in positions}
        type_buf = {p: [] for p in positions}
        obs_x_batch_buf = {p: [] for p in positions}
        if flags.old_model:
            # Old-model layout keeps the action encoding and the action-free
            # features in separate buffers.
            obs_action_buf = {p: [] for p in positions}
            obs_x_no_action = {p: [] for p in positions}
        # Integer tag stored per step in type_buf (emitted as 'obs_type').
        position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
        position, obs, env_output = env.initial(flags=flags)
        while True:
            # Inner loop plays until env_output['done'], then the finished
            # game's returns are written into the buffers.
            while True:
                if len(obs['legal_actions']) > 1:
                    infer_queue['input'].put({
                        'position': position,
                        'z_batch': obs['z_batch'],
                        'x_batch': obs['x_batch']
                    })
                    result = infer_queue['output'].get()
                    # Index of the best-valued legal action. The trailing [0]
                    # presumably means values is (num_legal_actions, 1) —
                    # TODO confirm.
                    action = np.argmax(result['values'], axis=0)[0]
                    _action_idx = int(action)
                    action = obs['legal_actions'][_action_idx]
                else:
                    action = obs['legal_actions'][0]
                if flags.old_model:
                    obs_action_buf[position].append(_cards2tensor(action, flags.lite_model))
                    obs_x_no_action[position].append(env_output['obs_x_no_action'])
                    obs_z_buf[position].append(env_output['obs_z'])
                else:
                    # The chosen action's encoding is prepended as an extra row
                    # of obs_z; x_batch carries only the action-free features.
                    obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float())
                    x_batch = env_output['obs_x_no_action'].float()
                    obs_x_batch_buf[position].append(x_batch)
                type_buf[position].append(position_index[position])
                position, obs, env_output = env.step(action, flags=flags)
                # NOTE(review): env.step has already advanced ``position`` to
                # the next mover (presumably the first mover of a fresh game
                # once ``done``), so this credits the next player rather than
                # the one that just acted. Per-game counts only match the
                # buffered steps if every game is opened by the same
                # position — confirm against Environment.step.
                size[position] += 1
                if env_output['done']:
                    for p in positions:
                        # Steps this position took in the finished game that
                        # have no target yet.
                        diff = size[p] - len(target_buf[p])
                        if diff > 0:
                            # Only the last step is marked done.
                            done_buf[p].extend([False for _ in range(diff-1)])
                            done_buf[p].append(True)
                            # 'landlord' keeps the reported value for its key;
                            # other positions store the negated value.
                            episode_return = env_output['episode_return']["play"][p] if p == 'landlord' else -env_output['episode_return']["play"][p]
                            # Return only on the final step; every step gets it as target.
                            episode_return_buf[p].extend([0.0 for _ in range(diff-1)])
                            episode_return_buf[p].append(episode_return)
                            target_buf[p].extend([episode_return for _ in range(diff)])
                    break
            # Flush every complete T-step chunk; the remainder stays buffered.
            # NOTE(review): torch.tensor() on items that are already tensors
            # (e.g. obs_z_buf entries) copies them and PyTorch warns to prefer
            # clone().detach().
            for p in positions:
                while size[p] > T:
                    if flags.old_model:
                        # In this layout 'obs_x_batch' carries the action encodings.
                        batch_queues[p].put({
                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_action_buf[p][:T]]),
                            "obs_x_no_action": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_no_action[p][:T]]),
                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
                        })
                        obs_action_buf[p] = obs_action_buf[p][T:]
                        obs_x_no_action[p] = obs_x_no_action[p][T:]
                    else:
                        batch_queues[p].put({
                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
                        })
                        obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
                    done_buf[p] = done_buf[p][T:]
                    episode_return_buf[p] = episode_return_buf[p][T:]
                    target_buf[p] = target_buf[p][T:]
                    obs_z_buf[p] = obs_z_buf[p][T:]
                    type_buf[p] = type_buf[p][T:]
                    size[p] -= T
    except KeyboardInterrupt:
        pass
    except Exception as e:
        log.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
def act(i, infer_queues, batch_queues, flags):
    """Run one ``act_queue`` thread per inference queue and wait for them.

    Thread ``x`` runs ``act_queue(x, infer_queues[x], batch_queues, flags)``
    and is named ``act_queue-<i>-<x>``. The threads are daemon threads, so
    they do not keep the interpreter alive at shutdown.
    """
    threads = []
    for x in range(len(infer_queues)):
        # Pass daemon=True to the constructor; Thread.setDaemon() is
        # deprecated since Python 3.10.
        thread = threading.Thread(
            target=act_queue, name='act_queue-%d-%d' % (i, x),
            args=(x, infer_queues[x], batch_queues, flags),
            daemon=True)
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
def _cards2tensor(list_cards, compress_form=False):
    """Encode a list of card codes as a flat ``torch.int8`` tensor.

    A (rows x 14) int8 matrix is filled column by column:
    - codes below 20 go to column ``Card2Column[code]`` and hold the count
      code for how many times that code appears;
    - column 13 holds the joker bitmask built from codes 20 and 30.
    The matrix is flattened in Fortran (column-major) order and the unused
    tail of the joker column is dropped.
    See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf

    Args:
        list_cards: sequence of card codes; repeats are counted.
        compress_form: use the 5-row compressed tables (69 values) instead of
            the 8-row tables (108 values).

    Returns:
        1-D torch.int8 tensor of length 69 (compressed) or 108.
    """
    if compress_form:
        rows, trim = 5, 1
        count_table, joker_table = NumOnes2ArrayCompressed, NumOnesJoker2ArrayCompressed
    else:
        rows, trim = 8, 4
        count_table, joker_table = NumOnes2Array, NumOnesJoker2Array
    out_len = rows * 14 - trim

    if len(list_cards) == 0:
        return torch.zeros(out_len, dtype=torch.int8)

    grid = np.zeros([rows, 14], dtype=np.int8)
    joker_mask = 0
    for code, copies in Counter(list_cards).items():
        if code < 20:
            grid[:, Card2Column[code]] = count_table[copies]
        elif code == 20:
            joker_mask |= 0b11 if copies == 2 else 0b01
        elif code == 30:
            joker_mask |= 0b1100 if copies == 2 else 0b0100
    grid[:, 13] = joker_table[joker_mask]
    return torch.from_numpy(grid.flatten('F')[:out_len])