2021-09-07 16:38:34 +08:00
|
|
|
import os
|
|
|
|
import typing
|
|
|
|
import logging
|
|
|
|
import traceback
|
|
|
|
import numpy as np
|
|
|
|
from collections import Counter
|
|
|
|
import time
|
|
|
|
from douzero.radam.radam import RAdam
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import multiprocessing as mp
|
|
|
|
|
|
|
|
from .env_utils import Environment
|
|
|
|
from douzero.env import Env
|
|
|
|
|
|
|
|
Card2Column = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7,
|
|
|
|
11: 8, 12: 9, 13: 10, 14: 11, 17: 12}
|
|
|
|
|
2021-12-05 12:03:30 +08:00
|
|
|
NumOnes2Array = {0: np.array([0, 0, 0, 0, 0, 0, 0, 0]),
|
|
|
|
1: np.array([1, 0, 0, 0, 0, 0, 0, 0]),
|
|
|
|
2: np.array([1, 1, 0, 0, 0, 0, 0, 0]),
|
|
|
|
3: np.array([1, 1, 1, 0, 0, 0, 0, 0]),
|
|
|
|
4: np.array([1, 1, 1, 1, 0, 0, 0, 0]),
|
|
|
|
5: np.array([1, 1, 1, 1, 1, 0, 0, 0]),
|
|
|
|
6: np.array([1, 1, 1, 1, 1, 1, 0, 0]),
|
|
|
|
7: np.array([1, 1, 1, 1, 1, 1, 1, 0]),
|
|
|
|
8: np.array([1, 1, 1, 1, 1, 1, 1, 1])}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
|
|
shandle = logging.StreamHandler()
|
|
|
|
shandle.setFormatter(
|
|
|
|
logging.Formatter(
|
|
|
|
'[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
|
|
|
|
'%(message)s'))
|
|
|
|
log = logging.getLogger('doudzero')
|
|
|
|
log.propagate = False
|
|
|
|
log.addHandler(shandle)
|
|
|
|
log.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
# Buffers are used to transfer data between actor processes
|
|
|
|
# and learner processes. They are shared tensors in GPU
|
|
|
|
Buffers = typing.Dict[str, typing.List[torch.Tensor]]
|
|
|
|
|
|
|
|
def create_env(flags):
|
2021-12-10 16:12:50 +08:00
|
|
|
return Env(flags.objective, flags.old_model)
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
|
|
def get_batch(b_queues, position, flags, lock):
|
|
|
|
"""
|
|
|
|
This function will sample a batch from the buffers based
|
|
|
|
on the indices received from the full queue. It will also
|
|
|
|
free the indices by sending it to full_queue.
|
|
|
|
"""
|
|
|
|
b_queue = b_queues[position]
|
|
|
|
buffer = []
|
|
|
|
while len(buffer) < flags.batch_size:
|
|
|
|
buffer.append(b_queue.get())
|
|
|
|
batch = {
|
|
|
|
key: torch.stack([m[key] for m in buffer], dim=1)
|
2021-12-10 16:12:50 +08:00
|
|
|
for key in ["done", "episode_return", "target", "obs_z", "obs_x_batch", "obs_x_no_action", "obs_type"]
|
2021-09-07 16:38:34 +08:00
|
|
|
}
|
|
|
|
del buffer
|
|
|
|
return batch
|
|
|
|
|
|
|
|
def create_optimizers(flags, learner_model):
|
|
|
|
"""
|
|
|
|
Create three optimizers for the three positions
|
|
|
|
"""
|
2021-12-05 12:03:30 +08:00
|
|
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
|
2021-09-07 16:38:34 +08:00
|
|
|
optimizers = {}
|
|
|
|
for position in positions:
|
|
|
|
optimizer = RAdam(
|
|
|
|
learner_model.parameters(position),
|
|
|
|
lr=flags.learning_rate,
|
|
|
|
eps=flags.epsilon)
|
|
|
|
optimizers[position] = optimizer
|
|
|
|
return optimizers
|
|
|
|
|
|
|
|
|
|
|
|
def act(i, device, batch_queues, model, flags):
|
2021-12-05 12:03:30 +08:00
|
|
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']
|
2021-09-07 16:38:34 +08:00
|
|
|
for pos in positions:
|
|
|
|
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
|
|
|
|
try:
|
|
|
|
T = flags.unroll_length
|
|
|
|
log.info('Device %s Actor %i started.', str(device), i)
|
|
|
|
|
|
|
|
env = create_env(flags)
|
|
|
|
env = Environment(env, device)
|
|
|
|
|
|
|
|
done_buf = {p: [] for p in positions}
|
|
|
|
episode_return_buf = {p: [] for p in positions}
|
|
|
|
target_buf = {p: [] for p in positions}
|
|
|
|
obs_z_buf = {p: [] for p in positions}
|
|
|
|
size = {p: 0 for p in positions}
|
|
|
|
type_buf = {p: [] for p in positions}
|
|
|
|
obs_x_batch_buf = {p: [] for p in positions}
|
2021-12-10 16:12:50 +08:00
|
|
|
if flags.old_model:
|
|
|
|
obs_action_buf = {p: [] for p in positions}
|
|
|
|
obs_x_no_action = {p: [] for p in positions}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
2021-12-05 12:03:30 +08:00
|
|
|
position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}
|
2021-12-07 11:18:44 +08:00
|
|
|
bid_type_index = {"landlord": 41, "landlord_up": 42, "landlord_front": 43, "landlord_down": 44}
|
2021-12-05 12:03:30 +08:00
|
|
|
bid_type_map = {41: "landlord", 42: "landlord_up", 43: "landlord_front", 44: "landlord_down"}
|
2021-09-07 16:38:34 +08:00
|
|
|
|
|
|
|
position, obs, env_output = env.initial(model, device, flags=flags)
|
|
|
|
bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"]
|
|
|
|
multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"]
|
|
|
|
while True:
|
|
|
|
# print("posi", position)
|
|
|
|
for bid_obs in bid_obs_buffer:
|
|
|
|
obs_z_buf["bidding"].append(bid_obs['z_batch'])
|
|
|
|
obs_x_batch_buf["bidding"].append(bid_obs["x_batch"])
|
|
|
|
type_buf["bidding"].append(bid_type_index[bid_obs["position"]])
|
|
|
|
size["bidding"] += 1
|
|
|
|
for mul_obs in multiply_obs_buffer:
|
|
|
|
obs_z_buf[mul_obs["position"]].append(mul_obs['z_batch'])
|
|
|
|
obs_x_batch_buf[mul_obs["position"]].append(mul_obs["x_batch"])
|
|
|
|
type_buf[mul_obs["position"]].append(2)
|
|
|
|
size[mul_obs["position"]] += 1
|
|
|
|
while True:
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
|
|
|
|
_action_idx = int(agent_output['action'].cpu().detach().numpy())
|
|
|
|
action = obs['legal_actions'][_action_idx]
|
2021-12-10 16:12:50 +08:00
|
|
|
if flags.old_model and position != 'bidding':
|
|
|
|
obs_action_buf[position].append(_cards2tensor(action))
|
|
|
|
obs_x_no_action[position].append(env_output['obs_x_no_action'])
|
|
|
|
obs_z_buf[position].append(env_output['obs_z'])
|
|
|
|
else:
|
|
|
|
obs_z_buf[position].append(torch.vstack((_cards2tensor(action).unsqueeze(0), env_output['obs_z'])).float())
|
|
|
|
# x_batch = torch.cat((env_output['obs_x_no_action'], _cards2tensor(action)), dim=0).float()
|
|
|
|
x_batch = env_output['obs_x_no_action'].float()
|
|
|
|
obs_x_batch_buf[position].append(x_batch)
|
2021-09-07 16:38:34 +08:00
|
|
|
type_buf[position].append(position_index[position])
|
|
|
|
position, obs, env_output = env.step(action, model, device, flags=flags)
|
|
|
|
size[position] += 1
|
|
|
|
if env_output['done']:
|
|
|
|
bid_obs_buffer = env_output["begin_buf"]["bid_obs_buffer"]
|
|
|
|
multiply_obs_buffer = env_output["begin_buf"]["multiply_obs_buffer"]
|
|
|
|
for p in positions:
|
|
|
|
diff = size[p] - len(target_buf[p])
|
|
|
|
# print(p, diff)
|
|
|
|
if diff > 0:
|
|
|
|
done_buf[p].extend([False for _ in range(diff-1)])
|
|
|
|
done_buf[p].append(True)
|
|
|
|
if p != "bidding":
|
|
|
|
episode_return = env_output['episode_return']["play"][p] if p == 'landlord' else -env_output['episode_return']["play"][p]
|
|
|
|
episode_return_buf[p].extend([0.0 for _ in range(diff-1)])
|
|
|
|
episode_return_buf[p].append(episode_return)
|
|
|
|
target_buf[p].extend([episode_return for _ in range(diff)])
|
|
|
|
else:
|
|
|
|
offset = len(target_buf[p])
|
|
|
|
for index in range(diff):
|
|
|
|
pos = type_buf[p][index+offset]
|
|
|
|
if pos == 41:
|
|
|
|
episode_return = env_output['episode_return']["bid"]["landlord"]
|
|
|
|
else:
|
|
|
|
episode_return = -env_output['episode_return']["bid"][bid_type_map[pos]]
|
|
|
|
episode_return_buf[p].append(episode_return)
|
|
|
|
# print(p, episode_return)
|
|
|
|
target_buf[p].append(episode_return)
|
|
|
|
break
|
|
|
|
for p in positions:
|
2021-12-05 12:03:30 +08:00
|
|
|
while size[p] > T:
|
2021-09-07 16:38:34 +08:00
|
|
|
# print(p, "epr", torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),)
|
2021-12-10 16:12:50 +08:00
|
|
|
if flags.old_model and p != 'bidding':
|
|
|
|
batch_queues[p].put({
|
|
|
|
"done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
|
|
|
|
"episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
|
|
|
|
"target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
|
|
|
|
"obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
|
|
|
|
"obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_action_buf[p][:T]]),
|
|
|
|
"obs_x_no_action": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_no_action[p][:T]]),
|
|
|
|
"obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
|
|
|
|
})
|
|
|
|
obs_action_buf[p] = obs_action_buf[p][T:]
|
|
|
|
obs_x_no_action[p] = obs_x_no_action[p][T:]
|
|
|
|
else:
|
|
|
|
batch_queues[p].put({
|
|
|
|
"done": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in done_buf[p][:T]]),
|
|
|
|
"episode_return": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in episode_return_buf[p][:T]]),
|
|
|
|
"target": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in target_buf[p][:T]]),
|
|
|
|
"obs_z": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_z_buf[p][:T]]),
|
|
|
|
"obs_x_batch": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in obs_x_batch_buf[p][:T]]),
|
|
|
|
"obs_type": torch.stack([torch.tensor(ndarr, device="cpu") for ndarr in type_buf[p][:T]])
|
|
|
|
})
|
|
|
|
obs_x_batch_buf[p] = obs_x_batch_buf[p][T:]
|
2021-09-07 16:38:34 +08:00
|
|
|
done_buf[p] = done_buf[p][T:]
|
|
|
|
episode_return_buf[p] = episode_return_buf[p][T:]
|
|
|
|
target_buf[p] = target_buf[p][T:]
|
|
|
|
obs_z_buf[p] = obs_z_buf[p][T:]
|
|
|
|
type_buf[p] = type_buf[p][T:]
|
|
|
|
size[p] -= T
|
|
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
pass
|
|
|
|
except Exception as e:
|
|
|
|
log.error('Exception in worker process %i', i)
|
|
|
|
traceback.print_exc()
|
|
|
|
print()
|
|
|
|
raise e
|
|
|
|
|
|
|
|
def _cards2tensor(list_cards):
|
|
|
|
"""
|
|
|
|
Convert a list of integers to the tensor
|
|
|
|
representation
|
|
|
|
See Figure 2 in https://arxiv.org/pdf/2106.06135.pdf
|
|
|
|
"""
|
|
|
|
if len(list_cards) == 0:
|
2021-12-05 12:03:30 +08:00
|
|
|
return torch.zeros(108, dtype=torch.int8)
|
2021-09-07 16:38:34 +08:00
|
|
|
|
2021-12-05 12:03:30 +08:00
|
|
|
matrix = np.zeros([8, 13], dtype=np.int8)
|
|
|
|
jokers = np.zeros(4, dtype=np.int8)
|
2021-09-07 16:38:34 +08:00
|
|
|
counter = Counter(list_cards)
|
|
|
|
for card, num_times in counter.items():
|
|
|
|
if card < 20:
|
|
|
|
matrix[:, Card2Column[card]] = NumOnes2Array[num_times]
|
|
|
|
elif card == 20:
|
|
|
|
jokers[0] = 1
|
2021-12-05 12:03:30 +08:00
|
|
|
if num_times == 2:
|
|
|
|
jokers[1] = 1
|
2021-09-07 16:38:34 +08:00
|
|
|
elif card == 30:
|
2021-12-05 12:03:30 +08:00
|
|
|
jokers[2] = 1
|
|
|
|
if num_times == 2:
|
|
|
|
jokers[3] = 1
|
2021-09-07 16:38:34 +08:00
|
|
|
matrix = np.concatenate((matrix.flatten('F'), jokers))
|
|
|
|
matrix = torch.from_numpy(matrix)
|
|
|
|
return matrix
|