Douzero_Resnet/douzero/dmc/utils.py

312 lines
13 KiB
Python
Raw Normal View History

2021-09-07 16:38:34 +08:00
import os
2022-01-01 23:21:14 +08:00
import queue
import threading
2021-09-07 16:38:34 +08:00
import typing
import logging
import traceback
import numpy as np
from collections import Counter
import time
from douzero.radam.radam import RAdam
import torch
from torch import multiprocessing as mp
from .env_utils import Environment
from douzero.env import Env
# Column index used by _cards2tensor for every card value below 20:
# values 3..14 occupy columns 0..11 and value 17 takes column 12.
Card2Column = {**{value: value - 3 for value in range(3, 15)}, 17: 12}
2021-12-05 12:03:30 +08:00
# Encoding of a card count n (0..8) as a length-8 vector whose first n
# entries are 1 and whose remaining entries are 0.
NumOnes2Array = {
    count: np.array([1] * count + [0] * (8 - count))
    for count in range(9)
}
2021-09-07 16:38:34 +08:00
2021-12-21 14:58:28 +08:00
# Joker column encoding keyed by the 4-bit mask built in _cards2tensor:
# bit k of the key is stored in element k of a length-8 vector, so elements
# 4..7 are always 0. Only the masks listed below have an entry.
NumOnesJoker2Array = {
    mask: np.array([(mask >> bit) & 1 for bit in range(8)])
    for mask in (0, 1, 3, 4, 5, 7, 12, 13, 15)
}
2021-12-22 21:19:10 +08:00
def _compressed_count_vector(count):
    """Return the length-5 encoding of a card count in the range 0..8.

    The last element is 1 when ``count`` exceeds 4; the first four elements
    hold ``count`` (or ``count - 4`` when the last element is set) leading ones.
    """
    overflow = int(count > 4)
    leading_ones = count - 4 * overflow
    return np.array([1] * leading_ones + [0] * (4 - leading_ones) + [overflow])


# Five-row counterpart of NumOnes2Array, used when the compressed form is requested.
NumOnes2ArrayCompressed = {
    count: _compressed_count_vector(count) for count in range(9)
}
# Five-row counterpart of NumOnesJoker2Array: bit k of the mask key is stored
# in element k, and the fifth element is always 0. Only the masks listed
# below have an entry.
NumOnesJoker2ArrayCompressed = {
    mask: np.array([(mask >> bit) & 1 for bit in range(5)])
    for mask in (0, 1, 3, 4, 5, 7, 12, 13, 15)
}
2021-09-07 16:38:34 +08:00
# Module logger: records go to a dedicated stream handler and are not
# propagated to the root logger.
_LOG_FORMAT = (
    '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
    '%(message)s'
)
shandle = logging.StreamHandler()
shandle.setFormatter(logging.Formatter(_LOG_FORMAT))
log = logging.getLogger('doudzero')
log.setLevel(logging.INFO)
log.addHandler(shandle)
log.propagate = False
# Type alias for named tensor buffers: maps a field name to a list of tensors.
# The original note describes these as shared GPU tensors used to transfer
# data between actor processes and learner processes.
# NOTE(review): no function in the reviewed portion of this file references
# this alias, and the actor code here builds its rollouts on CPU -- confirm
# that the "shared tensors in GPU" description is still accurate.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
def create_env(flags):
    """Build an ``Env`` configured from ``flags``.

    Reads ``flags.objective``, ``flags.old_model``, ``flags.lagacy_model`` and
    ``flags.lite_model`` and passes them positionally, in that order, to
    ``Env``.
    """
    env_args = (
        flags.objective,
        flags.old_model,
        flags.lagacy_model,  # attribute name as spelled on the flags object
        flags.lite_model,
    )
    return Env(*env_args)
2021-09-07 16:38:34 +08:00
def get_batch(b_queues, position, flags, lock):
    """Assemble one training batch for ``position``.

    Takes rollouts from ``b_queues[position]`` (blocking on each ``get``)
    until ``flags.batch_size`` of them have been collected, then stacks every
    field along a new dimension 1.

    Args:
        b_queues: Mapping from position name to a queue of rollout dicts.
        position: Key selecting which queue to read from.
        flags: Object providing ``batch_size`` and ``old_model``.
        lock: Unused; kept so existing callers keep working.

    Returns:
        A dict mapping each field name to the stacked tensor. The
        ``"obs_x_no_action"`` field is included only when
        ``flags.old_model`` is truthy.
    """
    b_queue = b_queues[position]
    rollouts = []
    while len(rollouts) < flags.batch_size:
        rollouts.append(b_queue.get())

    # Both model variants share these fields; the old model carries one more.
    keys = ["done", "episode_return", "target", "obs_z", "obs_x_batch"]
    if flags.old_model:
        keys.append("obs_x_no_action")
    keys.append("obs_type")

    return {
        key: torch.stack([rollout[key] for rollout in rollouts], dim=1)
        for key in keys
    }
def create_optimizers(flags, learner_model):
    """Create one RAdam optimizer for each of the four positions.

    Args:
        flags: Object providing ``learning_rate`` and ``epsilon``.
        learner_model: Object whose ``parameters(position)`` returns the
            parameters to optimize for that position.

    Returns:
        A dict mapping each position name ('landlord', 'landlord_up',
        'landlord_front', 'landlord_down') to its optimizer.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    return {
        position: RAdam(
            learner_model.parameters(position),
            lr=flags.learning_rate,
            eps=flags.epsilon)
        for position in positions
    }
2022-01-01 23:21:14 +08:00
def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
    """Serve inference requests from ``infer_queues`` forever.

    Each element of ``infer_queues`` is a mapping with an ``'input'`` and an
    ``'output'`` queue. Every pass polls each input queue without blocking;
    for each task found it runs ``model.forward`` under ``torch.no_grad()``
    and puts ``{'values': ...}`` on the matching output queue. When a full
    pass finds no task the loop sleeps for 10 ms.

    Args:
        i: Index used only in the start-up log message.
        device: ``"cpu"`` or a CUDA device index.
        infer_queues: Iterable of ``{'input': queue, 'output': queue}``.
        model: Object exposing ``models``, ``forward`` and ``set_onnx_model``.
        flags: Object providing ``enable_onnx``; also forwarded to
            ``model.forward``.
        onnx_frame: Object whose ``value`` attribute changes when the ONNX
            model should be reloaded.

    This function does not return.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    if not flags.enable_onnx:
        # The target device is the same for every position, so resolve it once.
        target_device = torch.device(
            device if device == "cpu" else ("cuda:"+str(device)))
        for pos in positions:
            model.models[pos].to(target_device)
    last_onnx_frame = -1
    log.info('Infer %i started.', i)
    while True:
        if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
            last_onnx_frame = onnx_frame.value
            model.set_onnx_model(device)
        all_empty = True
        for infer_queue in infer_queues:
            # Guard only the non-blocking get, so that a queue.Empty raised
            # from inside inference is not mistaken for "no task available".
            try:
                task = infer_queue['input'].get_nowait()
            except queue.Empty:
                continue
            with torch.no_grad():
                result = model.forward(task['position'], task['z_batch'], task['x_batch'], return_value=True, flags=flags)
            infer_queue['output'].put({
                'values': result['values']
            })
            all_empty = False
        if all_empty:
            time.sleep(0.01)
2022-01-02 22:33:48 +08:00
def _stack_cpu(items):
    """Convert each item to a CPU tensor and stack them along a new dim 0."""
    return torch.stack([torch.tensor(item, device="cpu") for item in items])


def act_queue(i, infer_queue, batch_queues, flags):
    """Play games forever and push fixed-length rollouts to ``batch_queues``.

    The environment comes from ``create_env(flags)`` wrapped in
    ``Environment`` on the CPU. At every step with more than one legal
    action, the observation is sent to ``infer_queue['input']`` and the
    action with the highest returned value is played; otherwise the single
    legal action is played. Per-position buffers collect the step data, and
    whenever a position holds more than ``flags.unroll_length`` steps, a
    rollout of exactly that many steps is put on ``batch_queues[position]``.

    Args:
        i: Index used only in log messages.
        infer_queue: Mapping with ``'input'`` and ``'output'`` queues.
        batch_queues: Mapping from position name to an output queue.
        flags: Object providing ``unroll_length``, ``old_model`` and
            ``lite_model``; also passed to the environment.

    ``KeyboardInterrupt`` ends the loop silently; any other exception is
    logged, printed and re-raised.
    """
    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
    try:
        T = flags.unroll_length
        log.info('Actor %i started.', i)

        env = create_env(flags)
        device = 'cpu'
        env = Environment(env, device)

        done_buf = {p: [] for p in positions}
        episode_return_buf = {p: [] for p in positions}
        target_buf = {p: [] for p in positions}
        obs_z_buf = {p: [] for p in positions}
        size = {p: 0 for p in positions}
        type_buf = {p: [] for p in positions}
        obs_x_batch_buf = {p: [] for p in positions}
        if flags.old_model:
            obs_action_buf = {p: [] for p in positions}
            obs_x_no_action = {p: [] for p in positions}

        position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}

        position, obs, env_output = env.initial(flags=flags)
        while True:
            # Inner loop: play one game until the environment reports done.
            while True:
                if len(obs['legal_actions']) > 1:
                    infer_queue['input'].put({
                        'position': position,
                        'z_batch': obs['z_batch'],
                        'x_batch': obs['x_batch']
                    })
                    result = infer_queue['output'].get()
                    _action_idx = int(np.argmax(result['values'], axis=0)[0])
                    action = obs['legal_actions'][_action_idx]
                else:
                    action = obs['legal_actions'][0]

                if flags.old_model:
                    obs_action_buf[position].append(_cards2tensor(action, flags.lite_model))
                    obs_x_no_action[position].append(env_output['obs_x_no_action'])
                    obs_z_buf[position].append(env_output['obs_z'])
                else:
                    # The chosen action is stacked on top of obs_z as an extra row.
                    obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float())
                    x_batch = env_output['obs_x_no_action'].float()
                    obs_x_batch_buf[position].append(x_batch)
                type_buf[position].append(position_index[position])

                position, obs, env_output = env.step(action, flags=flags)
                # NOTE(review): the counter is bumped for the position returned
                # by env.step, not for the position that just acted. Kept as in
                # the original; confirm this is intended.
                size[position] += 1

                if env_output['done']:
                    # Back-fill done flags, returns and targets for the steps
                    # recorded since the last fill.
                    for p in positions:
                        diff = size[p] - len(target_buf[p])
                        if diff > 0:
                            done_buf[p].extend([False] * (diff - 1))
                            done_buf[p].append(True)

                            play_return = env_output['episode_return']["play"][p]
                            episode_return = play_return if p == 'landlord' else -play_return
                            episode_return_buf[p].extend([0.0] * (diff - 1))
                            episode_return_buf[p].append(episode_return)
                            target_buf[p].extend([episode_return] * diff)
                    break

            # Emit complete rollouts of T steps and keep the remainder buffered.
            for p in positions:
                while size[p] > T:
                    x_source = obs_action_buf if flags.old_model else obs_x_batch_buf
                    rollout = {
                        "done": _stack_cpu(done_buf[p][:T]),
                        "episode_return": _stack_cpu(episode_return_buf[p][:T]),
                        "target": _stack_cpu(target_buf[p][:T]),
                        "obs_z": _stack_cpu(obs_z_buf[p][:T]),
                        "obs_x_batch": _stack_cpu(x_source[p][:T]),
                    }
                    if flags.old_model:
                        rollout["obs_x_no_action"] = _stack_cpu(obs_x_no_action[p][:T])
                    rollout["obs_type"] = _stack_cpu(type_buf[p][:T])
                    batch_queues[p].put(rollout)

                    x_source[p] = x_source[p][T:]
                    if flags.old_model:
                        obs_x_no_action[p] = obs_x_no_action[p][T:]
                    done_buf[p] = done_buf[p][T:]
                    episode_return_buf[p] = episode_return_buf[p][T:]
                    target_buf[p] = target_buf[p][T:]
                    obs_z_buf[p] = obs_z_buf[p][T:]
                    type_buf[p] = type_buf[p][T:]
                    size[p] -= T

    except KeyboardInterrupt:
        pass
    except Exception:
        log.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise
2022-01-02 22:33:48 +08:00
def act(i, infer_queues, batch_queues, flags):
    """Run one ``act_queue`` worker thread per inference queue and wait for all.

    Args:
        i: Index of this actor; used only in the thread names.
        infer_queues: Sequence of inference queues, one per worker thread.
        batch_queues: Passed through unchanged to every worker.
        flags: Passed through unchanged to every worker.

    Each worker receives its own queue index (not ``i``) as first argument.
    Threads are created as daemon threads, and this call blocks until every
    one of them has finished.
    """
    threads = []
    for x, infer_queue in enumerate(infer_queues):
        # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
        thread = threading.Thread(
            target=act_queue, name='act_queue-%d-%d' % (i, x),
            args=(x, infer_queue, batch_queues, flags),
            daemon=True)
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
2021-12-22 21:19:10 +08:00
def _cards2tensor(list_cards, compress_form = False):
2021-09-07 16:38:34 +08:00
"""
Convert a list of integers to the tensor
representation
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
"""
2021-12-22 21:19:10 +08:00
if compress_form:
if len(list_cards) == 0:
return torch.zeros(69, dtype=torch.int8)
matrix = np.zeros([5, 14], dtype=np.int8)
counter = Counter(list_cards)
joker_cnt = 0
for card, num_times in counter.items():
if card < 20:
matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
elif card == 20:
if num_times == 2:
joker_cnt |= 0b11
else:
joker_cnt |= 0b01
elif card == 30:
if num_times == 2:
joker_cnt |= 0b1100
else:
joker_cnt |= 0b0100
matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
matrix = matrix.flatten('F')[:-1]
matrix = torch.from_numpy(matrix)
return matrix
else:
if len(list_cards) == 0:
return torch.zeros(108, dtype=torch.int8)
matrix = np.zeros([8, 14], dtype=np.int8)
counter = Counter(list_cards)
joker_cnt = 0
for card, num_times in counter.items():
if card < 20:
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
elif card == 20:
if num_times == 2:
joker_cnt |= 0b11
else:
joker_cnt |= 0b01
elif card == 30:
if num_times == 2:
joker_cnt |= 0b1100
else:
joker_cnt |= 0b0100
matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
matrix = matrix.flatten('F')[:-4]
matrix = torch.from_numpy(matrix)
return matrix