import os
import queue
import threading
import typing
import logging
import traceback
import time
from collections import Counter

import numpy as np
import torch
from torch import multiprocessing as mp

from douzero.radam.radam import RAdam
from douzero.env import Env
from .env_utils import Environment

Card2Column = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7,
               11: 8, 12: 9, 13: 10, 14: 11, 17: 12}

# Per-rank count encoding (one column of the card matrix). Two decks are in
# play, so a rank can appear up to eight times.
NumOnes2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
                 1: np.array([1, 0, 0, 0, 0, 0, 0, 0]),
                 2: np.array([1, 1, 0, 0, 0, 0, 0, 0]),
                 3: np.array([1, 1, 1, 0, 0, 0, 0, 0]),
                 4: np.array([1, 1, 1, 1, 0, 0, 0, 0]),
                 5: np.array([1, 1, 1, 1, 1, 0, 0, 0]),
                 6: np.array([1, 1, 1, 1, 1, 1, 0, 0]),
                 7: np.array([1, 1, 1, 1, 1, 1, 1, 0]),
                 8: np.array([1, 1, 1, 1, 1, 1, 1, 1])}

# Joker column, keyed by the joker_cnt bitmask built in _cards2tensor.
NumOnesJoker2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
                      1: np.array([1, 0, 0, 0, 0, 0, 0, 0]),
                      3: np.array([1, 1, 0, 0, 0, 0, 0, 0]),
                      4: np.array([0, 0, 1, 0, 0, 0, 0, 0]),
                      5: np.array([1, 0, 1, 0, 0, 0, 0, 0]),
                      7: np.array([1, 1, 1, 0, 0, 0, 0, 0]),
                      12: np.array([0, 0, 1, 1, 0, 0, 0, 0]),
                      13: np.array([1, 0, 1, 1, 0, 0, 0, 0]),
                      15: np.array([1, 1, 1, 1, 0, 0, 0, 0])}

# Compressed (lite) encoding: five rows per column instead of eight. Counts
# above four set the fifth slot and reuse the first four slots for (count - 4).
NumOnes2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
                           1: np.array([1, 0, 0, 0, 0]),
                           2: np.array([1, 1, 0, 0, 0]),
                           3: np.array([1, 1, 1, 0, 0]),
                           4: np.array([1, 1, 1, 1, 0]),
                           5: np.array([1, 0, 0, 0, 1]),
                           6: np.array([1, 1, 0, 0, 1]),
                           7: np.array([1, 1, 1, 0, 1]),
                           8: np.array([1, 1, 1, 1, 1])}

NumOnesJoker2ArrayCompressed = {0: np.array([0, 0, 0, 0, 0]),
                                1: np.array([1, 0, 0, 0, 0]),
                                3: np.array([1, 1, 0, 0, 0]),
                                4: np.array([0, 0, 1, 0, 0]),
                                5: np.array([1, 0, 1, 0, 0]),
                                7: np.array([1, 1, 1, 0, 0]),
                                12: np.array([0, 0, 1, 1, 0]),
                                13: np.array([1, 0, 1, 1, 0]),
                                15: np.array([1, 1, 1, 1, 0])}

shandle = logging.StreamHandler()
shandle.setFormatter(
    logging.Formatter(
        '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
        '%(message)s'))
log = logging.getLogger('doudzero')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

# Buffers are used to transfer data between actor processes
# and learner processes. They are shared tensors on the GPU.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]


def create_env(flags):
    return Env(flags.objective, flags.old_model, flags.lagacy_model,
               flags.lite_model, flags.unified_model)


def get_batch(b_queues, position, flags, lock):
    """
    Sample a batch for the given position by pulling `batch_size` rollout
    segments from that position's queue and stacking them along the batch
    dimension.
    """
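    # Each queue item is one unroll segment of length T = flags.unroll_length,
    # so stacking batch_size of them along dim=1 yields tensors of shape
    # [T, batch_size, ...] for the learner.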
""" b_queue = b_queues[position] buffer = [] while len(buffer) < flags.batch_size: buffer.append(b_queue.get()) if flags.old_model: batch = { key: torch.stack([m[key] for m in buffer], dim=1) for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"] } else: batch = { key: torch.stack([m[key] for m in buffer], dim=1) for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_type"] } del buffer return batch def create_optimizers(flags, learner_model): """ Create three optimizers for the three positions """ positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] optimizers = {} if flags.unified_model: position = 'uni' optimizer = RAdam( learner_model.parameters(position), lr=flags.learning_rate, eps=flags.epsilon) optimizers[position] = optimizer else: for position in positions: optimizer = RAdam( learner_model.parameters(position), lr=flags.learning_rate, eps=flags.epsilon) optimizers[position] = optimizer return optimizers def infer_logic(i, device, infer_queues, model, flags, onnx_frame): positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] device = device if device == "cpu" else ("cuda:" + str(device)) if not flags.enable_onnx: if flags.unified_model: model.model.to(torch.device(device)) else: for pos in positions: model.models[pos].to(torch.device(device)) last_onnx_frame = -1 log.info('Infer %i started.', i) while True: # print("posi", position) if flags.enable_onnx and onnx_frame.value != last_onnx_frame: last_onnx_frame = onnx_frame.value model.set_onnx_model(device) all_empty = True for infer_queue in infer_queues: try: task = infer_queue['input'].get_nowait() with torch.no_grad(): result = model.forward(task['position'], task['z_batch'], task['x_batch'], device=device, return_value=True, flags=flags) infer_queue['output'].put({ 'values': result['values'] }) all_empty = False except queue.Empty: pass if all_empty: time.sleep(0.01) def act_queue(i, infer_queue, batch_queues, flags): positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] try: T = flags.unroll_length log.info('Actor %i started.', i) env = create_env(flags) device = 'cpu' env = Environment(env, device) done_buf = {p: [] for p in positions} episode_return_buf = {p: [] for p in positions} target_buf = {p: [] for p in positions} obs_z_buf = {p: [] for p in positions} size = {p: 0 for p in positions} type_buf = {p: [] for p in positions} obs_x_batch_buf = {p: [] for p in positions} if flags.old_model: obs_action_buf = {p: [] for p in positions} obs_x_no_action = {p: [] for p in positions} position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34} position, obs, env_output = env.initial(flags=flags) while True: while True: if len(obs['legal_actions']) > 1: infer_queue['input'].put({ 'position': position, 'z_batch': obs['z_batch'], 'x_batch': obs['x_batch'] }) result = infer_queue['output'].get() action = np.argmax(result['values'], axis=0)[0] _action_idx = int(action) action = obs['legal_actions'][_action_idx] else: action = obs['legal_actions'][0] if flags.old_model: obs_action_buf[position].append(_cards2tensor(action, flags.lite_model)) obs_x_no_action[position].append(env_output['obs_x_no_action']) obs_z_buf[position].append(env_output['obs_z']) else: obs_z_buf[position].append(torch.vstack((_cards2tensor(action, flags.lite_model).unsqueeze(0), env_output['obs_z'])).float()) # x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), 
                    x_batch = env_output['obs_x_no_action'].float()
                    obs_x_batch_buf[position].append(x_batch)
                type_buf[position].append(position_index[position])
                position, obs, env_output = env.step(action, flags=flags)
                size[position] += 1
                if env_output['done']:
                    for p in positions:
                        diff = size[p] - len(target_buf[p])
                        # print(p, diff)
                        if diff > 0:
                            done_buf[p].extend([False for _ in range(diff - 1)])
                            done_buf[p].append(True)
                            episode_return = (env_output['episode_return']["play"][p]
                                              if p == 'landlord'
                                              else -env_output['episode_return']["play"][p])
                            episode_return_buf[p].extend([0.0 for _ in range(diff - 1)])
                            episode_return_buf[p].append(episode_return)
                            target_buf[p].extend([episode_return for _ in range(diff)])
                    break

            for p in positions:
                while size[p] > T:
                    # print(p, "epr", torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),)
                    if flags.old_model:
                        batch_queues[p].put({
                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_action_buf[p][:T]]),
                            "obs_x_no_action": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_no_action[p][:T]]),
                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
                        })
                        obs_action_buf[p] = obs_action_buf[p][T:]
                        obs_x_no_action[p] = obs_x_no_action[p][T:]
                    else:
                        batch_queues[p].put({
                            "done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
                            "episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
                            "target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
                            "obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
                            "obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
                            "obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
                        })
                        obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
                    done_buf[p] = done_buf[p][T:]
                    episode_return_buf[p] = episode_return_buf[p][T:]
                    target_buf[p] = target_buf[p][T:]
                    obs_z_buf[p] = obs_z_buf[p][T:]
                    type_buf[p] = type_buf[p][T:]
                    size[p] -= T

    except KeyboardInterrupt:
        pass
    except Exception as e:
        log.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e


def act(i, infer_queues, batch_queues, flags):
    threads = []
    for x in range(len(infer_queues)):
        thread = threading.Thread(
            target=act_queue, name='act_queue-%d-%d' % (i, x),
            args=(x, infer_queues[x], batch_queues, flags))
        thread.daemon = True
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()


def _cards2tensor(list_cards, compress_form=False):
    """
    Convert a list of card integers to the tensor representation.
    See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
    """
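    # Columns 0-12 hold per-rank counts; column 13 encodes the jokers through
    # the joker_cnt bitmask (bits 0-1 count card id 20, bits 2-3 count card
    # id 30), which indexes NumOnesJoker2Array / NumOnesJoker2ArrayCompressed.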
    if compress_form:
        if len(list_cards) == 0:
            return torch.zeros(69, dtype=torch.int8)
        matrix = np.zeros([5, 14], dtype=np.int8)
        counter = Counter(list_cards)
        joker_cnt = 0
        for card, num_times in counter.items():
            if card < 20:
                matrix[:, Card2Column[card]] = NumOnes2ArrayCompressed[num_times]
            elif card == 20:
                if num_times == 2:
                    joker_cnt |= 0b11
                else:
                    joker_cnt |= 0b01
            elif card == 30:
                if num_times == 2:
                    joker_cnt |= 0b1100
                else:
                    joker_cnt |= 0b0100
        matrix[:, 13] = NumOnesJoker2ArrayCompressed[joker_cnt]
        matrix = matrix.flatten('F')[:-1]
        matrix = torch.from_numpy(matrix)
        return matrix
    else:
        if len(list_cards) == 0:
            return torch.zeros(108, dtype=torch.int8)
        matrix = np.zeros([8, 14], dtype=np.int8)
        counter = Counter(list_cards)
        joker_cnt = 0
        for card, num_times in counter.items():
            if card < 20:
                matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
            elif card == 20:
                if num_times == 2:
                    joker_cnt |= 0b11
                else:
                    joker_cnt |= 0b01
            elif card == 30:
                if num_times == 2:
                    joker_cnt |= 0b1100
                else:
                    joker_cnt |= 0b0100
        matrix[:, 13] = NumOnesJoker2Array[joker_cnt]
        matrix = matrix.flatten('F')[:-4]
        matrix = torch.from_numpy(matrix)
        return matrix
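
# Illustrative example (not part of the training pipeline): in the full
# (non-compressed) form,
#
#     _cards2tensor([3, 3, 20])
#
# fills column 0 with NumOnes2Array[2] = [1, 1, 0, 0, 0, 0, 0, 0], column 13
# with NumOnesJoker2Array[0b01] = [1, 0, 0, 0, 0, 0, 0, 0], flattens the 8x14
# matrix column by column, and drops the last four entries, yielding a
# length-108 int8 tensor.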