Merge branch 'main' of ssh://git@git.zaneyork.cn:2222/douzero/Douzero_Resnet.git into main
commit afa295e42b
@@ -13,15 +13,17 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp']
 # Training settings
 parser.add_argument('--onnx_sync_interval', default=120, type=int,
                     help='Time interval (in seconds) at which to sync the onnx model')
-parser.add_argument('--actor_device_cpu', action='store_true',
-                    help='Use CPU as actor device')
 parser.add_argument('--gpu_devices', default='0', type=str,
                     help='Which GPUs to be used for training')
+parser.add_argument('--infer_devices', default='0', type=str,
+                    help='Which devices to be used for inference')
+parser.add_argument('--num_infer', default=2, type=int,
+                    help='The number of processes used for inference')
 parser.add_argument('--num_actor_devices', default=1, type=int,
                     help='The number of devices used for simulation')
-parser.add_argument('--num_actors', default=2, type=int,
+parser.add_argument('--num_actors', default=3, type=int,
                     help='The number of actors for each simulation device')
-parser.add_argument('--num_actors_cpu', default=1, type=int,
+parser.add_argument('--num_actors_thread', default=4, type=int,
                     help='The number of actors for each simulation device')
 parser.add_argument('--training_device', default='0', type=str,
                     help='The index of the GPU used for training models. `cpu` means using cpu')
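Net effect of this first hunk: the `--actor_device_cpu` switch is gone, and actor-side inference is instead configured with `--infer_devices` and `--num_infer`. With the defaults above, a GPU run would look something like `python3 train.py --gpu_devices 0 --infer_devices 0`, and the CPU-only invocation becomes `python3 train.py --infer_devices cpu --training_device cpu`, matching the updated error message further down.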
@@ -14,7 +14,7 @@ import douzero.dmc.models
 import douzero.env.env
 from .file_writer import FileWriter
 from .models import Model, OldModel
-from .utils import get_batch, log, create_env, create_optimizers, act
+from .utils import get_batch, log, create_env, create_optimizers, act, infer_logic
 import psutil
 import shutil
 import requests
@@ -25,7 +25,7 @@ def compute_loss(logits, targets):
     loss = ((logits.squeeze(-1) - targets)**2).mean()
     return loss


-def learn(position, actor_models, model, batch, optimizer, flags, lock):
+def learn(position, actor_model, model, batch, optimizer, flags, lock):
     """Performs a learning (optimization) step."""
     position_index = {"landlord": 31, "landlord_up": 32, 'landlord_front': 33, "landlord_down": 34}
     print("Learn", position)
@@ -60,7 +60,6 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
         optimizer.step()

         if not flags.enable_onnx:
-            for actor_model in actor_models.values():
             actor_model.get_model(position).load_state_dict(model.state_dict())
         return stats

@@ -71,9 +70,9 @@ def train(flags):
     Then it will start subprocesses as actors. Then, it will call
     learning function with multiple threads.
     """
-    if not flags.actor_device_cpu or flags.training_device != 'cpu':
+    if flags.training_device != 'cpu' or flags.infer_devices != 'cpu':
         if not torch.cuda.is_available():
-            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
+            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --infer_devices cpu --training_device cpu`")
     plogger = FileWriter(
         xpid=flags.xpid,
         xp_args=flags.__dict__,
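The rewritten guard deserves a close read: CUDA is now required unless both `training_device` and `infer_devices` are 'cpu'. A throwaway truth table (my own illustration, not project code) confirms the De Morgan reading:

    # needs_cuda is False only when both settings are 'cpu'
    for training, infer in [('cpu', 'cpu'), ('cpu', '0'), ('0', 'cpu'), ('0', '0')]:
        needs_cuda = training != 'cpu' or infer != 'cpu'
        print(training, infer, needs_cuda)  # only ('cpu', 'cpu') -> False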
@@ -85,22 +84,12 @@ def train(flags):
     T = flags.unroll_length
     B = flags.batch_size

-    if flags.actor_device_cpu:
-        device_iterator = ['cpu']
-    else:
-        device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
-        assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
-
     # Initialize actor models
-    models = {}
-    for device in device_iterator:
-        if flags.old_model:
-            model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
-        else:
-            model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
-        model.share_memory()
-        model.eval()
-        models[device] = model
+    if flags.old_model:
+        actor_model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
+    else:
+        actor_model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
+    actor_model.eval()

     # Initialize queues
     actor_processes = []
@@ -114,9 +103,6 @@ def train(flags):
     else:
         learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)

-    # Create optimizers
-    optimizers = create_optimizers(flags, learner_model)
-
     # Stat Keys
     stat_keys = [
         'mean_episode_return_landlord',
@@ -155,6 +141,9 @@ def train(flags):
         )
         onnx_frame.value = frames

+    # Create optimizers
+    optimizers = create_optimizers(flags, learner_model)
+
     # Load models if any
     if flags.load_model and os.path.exists(checkpointpath):
         checkpoint_states = torch.load(
@@ -164,8 +153,7 @@ def train(flags):
             learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
             optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
             if not flags.enable_onnx:
-                for device in device_iterator:
-                    models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+                actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
         stats = checkpoint_states["stats"]

         frames = checkpoint_states["frames"]
@@ -173,20 +161,36 @@ def train(flags):
             sync_onnx_model(frames)
         log.info(f"Resuming preempted job, current stats:\n{stats}")

-    # Starting actor processes
-    for device in device_iterator:
-        if device == 'cpu':
-            num_actors = flags.num_actors_cpu
-        else:
-            num_actors = flags.num_actors
+    infer_queues = []
+    num_actors = flags.num_actors
+    for j in range(flags.num_actors_thread):
+        for i in range(num_actors):
+            infer_queues.append({
+                'input': ctx.Queue(maxsize=100), 'output': ctx.Queue(maxsize=100)
+            })
+
+    infer_processes = []
+    for device in flags.infer_devices.split(','):
+        for i in range(flags.num_infer if device != 'cpu' else 1):
+            infer = mp.Process(
+                target=infer_logic,
+                args=(i, device, infer_queues, actor_model, flags, onnx_frame))
+            infer.daemon = True
+            infer.start()
+            infer_processes.append({
+                'device': device,
+                'i': i,
+                'infer': infer
+            })
+
+    # Starting actor processes
     for i in range(num_actors):
         actor = mp.Process(
             target=act,
-            args=(i, device, batch_queues, models[device], flags, onnx_frame))
+            args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
         actor.daemon = True
         actor.start()
         actor_processes.append({
-            'device': device,
             'i': i,
             'actor': actor
         })
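One subtlety in the hunk above: the setup loop creates `num_actors_thread * num_actors` queue pairs, yet each actor receives the hard-coded slice `infer_queues[i * 4: (i + 1) * 4]`. The 4 presumably mirrors the default of `--num_actors_thread`; if that flag changes, the slices no longer partition the queue list. A quick sanity check (illustrative only, not project code):

    # With num_actors=3, num_actors_thread=4: 12 pairs, and the slices
    # [0:4], [4:8], [8:12] cover them exactly. With num_actors_thread=2
    # only 6 pairs exist, so actor 1 gets a short slice, actor 2 an empty one.
    num_actors = 3
    for num_actors_thread in (4, 2):
        total = num_actors_thread * num_actors
        slices = [list(range(total))[i * 4:(i + 1) * 4] for i in range(num_actors)]
        print(num_actors_thread, slices)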
@@ -201,7 +205,7 @@ def train(flags):
         nonlocal frames, position_frames, stats
         while frames < flags.total_frames:
             batch = get_batch(batch_queues, position, flags, local_lock)
-            _stats = learn(position, models, learner_model.get_model(position), batch,
+            _stats = learn(position, actor_model, learner_model.get_model(position), batch,
                            optimizers[position], flags, position_lock)
             with lock:
                 for k in _stats:
@@ -215,13 +219,12 @@ def train(flags):

     threads = []
     locks = {}
-    for device in device_iterator:
-        locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
+    locks['cpu'] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

     for i in range(flags.num_threads):
         for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
             thread = threading.Thread(
-                target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i, position, locks[device][position], position_locks[position]))
+                target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i, position, locks['cpu'][position], position_locks[position]))
             thread.setDaemon(True)
             thread.start()
             threads.append(thread)
@@ -305,13 +308,23 @@ def train(flags):
                 pprint.pformat(stats))
             for proc in actor_processes:
                 if not proc['actor'].is_alive():
+                    i = proc['i']
                     actor = mp.Process(
                         target=act,
-                        args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame))
+                        args=(i, infer_queues[i * 4: (i + 1) * 4], batch_queues, flags))
                     actor.daemon = True
                     actor.start()
                     proc['actor'] = actor

+            for proc in infer_processes:
+                if not proc['infer'].is_alive():
+                    infer = mp.Process(
+                        target=infer_logic,
+                        args=(proc['i'], proc['device'], infer_queues, actor_model, flags, onnx_frame))
+                    infer.daemon = True
+                    infer.start()
+                    proc['infer'] = infer
+
     except KeyboardInterrupt:
         flags.enable_upload = False
         checkpoint(frames)
@@ -6,21 +6,14 @@ the environment, we do it automatically.
 import numpy as np
 import torch

-def _format_observation(obs, device, flags):
+def _format_observation(obs):
     """
     A utility function to process observations and
     move them to CUDA.
     """
     position = obs['position']
-    if flags.enable_onnx:
-        x_batch = obs['x_batch']
-        z_batch = obs['z_batch']
-    else:
-        if not device == "cpu":
-            device = 'cuda:' + str(device)
-        device = torch.device(device)
-        x_batch = torch.from_numpy(obs['x_batch']).to(device)
-        z_batch = torch.from_numpy(obs['z_batch']).to(device)
+    x_batch = obs['x_batch']
+    z_batch = obs['z_batch']
     x_no_action = torch.from_numpy(obs['x_no_action'])
     z = torch.from_numpy(obs['z'])
     obs = {'x_batch': x_batch,
@@ -37,9 +30,9 @@ class Environment:
         self.device = device
         self.episode_return = None

-    def initial(self, model, device, flags=None):
-        obs = self.env.reset(model, device, flags=flags)
-        initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device, flags)
+    def initial(self, flags=None):
+        obs = self.env.reset(flags=flags)
+        initial_position, initial_obs, x_no_action, z = _format_observation(obs)
         self.episode_return = torch.zeros(1, 1)
         initial_done = torch.ones(1, 1, dtype=torch.bool)
         return initial_position, initial_obs, dict(
@@ -49,16 +42,16 @@ class Environment:
             obs_z=z,
         )

-    def step(self, action, model, device, flags=None):
+    def step(self, action, flags=None):
         obs, reward, done, _ = self.env.step(action)

         self.episode_return = reward
         episode_return = self.episode_return
         if done:
-            obs = self.env.reset(model, device, flags=flags)
+            obs = self.env.reset(flags=flags)
             self.episode_return = torch.zeros(1, 1)

-        position, obs, x_no_action, z = _format_observation(obs, self.device, flags)
+        position, obs, x_no_action, z = _format_observation(obs)
         # reward = torch.tensor(reward).view(1, 1)
         done = torch.tensor(done).view(1, 1)

@@ -493,8 +493,11 @@ model_dict_new_lite['landlord_up'] = GeneralModelLite
 model_dict_new_lite['landlord_front'] = GeneralModelLite
 model_dict_new_lite['landlord_down'] = GeneralModelLite

-def forward_logic(self_model, position, z, x, return_value=False, flags=None):
+def forward_logic(self_model, position, z, x, device='cpu', return_value=False, flags=None):
     legal_count = len(z)
+    if not flags.enable_onnx:
+        z = torch.tensor(z, device=device)
+        x = torch.tensor(x, device=device)
     if legal_count >= 80:
         partition_count = int(legal_count / 40)
         sub_z = np.array_split(z, partition_count)
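For context, the partitioning logic right after the new tensor conversion caps each forward pass at roughly 40 rows. A toy run (not project code) shows the arithmetic:

    import numpy as np

    # 100 legal actions -> int(100 / 40) = 2 partitions; np.array_split
    # divides the rows as evenly as possible.
    z = np.zeros((100, 40, 108))
    partition_count = int(len(z) / 40)
    sub_z = np.array_split(z, partition_count)
    print(partition_count, [s.shape[0] for s in sub_z])  # 2 [50, 50]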
@@ -577,8 +580,8 @@ class OldModel:
     def get_onnx_params(self, position):
         self.models[position].get_onnx_params(self.device)

-    def forward(self, position, z, x, return_value=False, flags=None):
-        return forward_logic(self, position, z, x, return_value, flags)
+    def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
+        return forward_logic(self, position, z, x, device, return_value, flags)

     def share_memory(self):
         if self.models['landlord'] is not None:
@@ -646,8 +649,8 @@ class Model:
     def get_onnx_params(self, position):
         self.models[position].get_onnx_params(self.device)

-    def forward(self, position, z, x, return_value=False, flags=None, debug=False):
-        return forward_logic(self, position, z, x, return_value, flags)
+    def forward(self, position, z, x, device='cpu', return_value=False, flags=None):
+        return forward_logic(self, position, z, x, device, return_value, flags)

     def share_memory(self):
         if self.models['landlord'] is not None:
@@ -1,4 +1,6 @@
 import os
+import queue
+import threading
 import typing
 import logging
 import traceback
|
@ -111,16 +113,42 @@ def create_optimizers(flags, learner_model):
|
||||||
return optimizers
|
return optimizers
|
||||||
|
|
||||||
|
|
||||||
def act(i, device, batch_queues, model, flags, onnx_frame):
|
def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
|
||||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
if not flags.enable_onnx:
|
if not flags.enable_onnx:
|
||||||
for pos in positions:
|
for pos in positions:
|
||||||
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
|
model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
|
||||||
|
last_onnx_frame = -1
|
||||||
|
log.info('Infer %i started.', i)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# print("posi", position)
|
||||||
|
if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
|
||||||
|
last_onnx_frame = onnx_frame.value
|
||||||
|
model.set_onnx_model(device)
|
||||||
|
all_empty = True
|
||||||
|
for infer_queue in infer_queues:
|
||||||
|
try:
|
||||||
|
task = infer_queue['input'].get_nowait()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model.forward(task['position'], task['z_batch'], task['x_batch'], device=device, return_value=True, flags=flags)
|
||||||
|
infer_queue['output'].put({
|
||||||
|
'values': result['values']
|
||||||
|
})
|
||||||
|
all_empty = False
|
||||||
|
except queue.Empty:
|
||||||
|
pass
|
||||||
|
if all_empty:
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def act_queue(i, infer_queue, batch_queues, flags):
|
||||||
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
try:
|
try:
|
||||||
T = flags.unroll_length
|
T = flags.unroll_length
|
||||||
log.info('Device %s Actor %i started.', str(device), i)
|
log.info('Actor %i started.', i)
|
||||||
|
|
||||||
env = create_env(flags)
|
env = create_env(flags)
|
||||||
|
device = 'cpu'
|
||||||
env = Environment(env, device)
|
env = Environment(env, device)
|
||||||
|
|
||||||
done_buf = {p: [] for p in positions}
|
done_buf = {p: [] for p in positions}
|
||||||
|
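The `infer_logic`/`act_queue` pair implements a request/response inference server over multiprocessing queues: actors enqueue observations, the server polls all input queues, runs the model, and posts values back. A stripped-down, self-contained sketch of the same pattern (dummy model and names of my own choosing):

    import multiprocessing as mp
    import queue
    import time

    def infer_server(infer_queues):
        # Poll every request queue; answer each request with a dummy value.
        while True:
            all_empty = True
            for q in infer_queues:
                try:
                    task = q['input'].get_nowait()
                    q['output'].put({'values': [task['x'] * 2]})  # stand-in for model.forward
                    all_empty = False
                except queue.Empty:
                    pass
            if all_empty:
                time.sleep(0.01)  # avoid busy-spinning when idle

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        qs = [{'input': ctx.Queue(), 'output': ctx.Queue()} for _ in range(2)]
        server = ctx.Process(target=infer_server, args=(qs,), daemon=True)
        server.start()
        qs[0]['input'].put({'x': 21})
        print(qs[0]['output'].get())  # {'values': [42]}

Because each act_queue thread owns exactly one input/output pair and blocks on `output.get()`, responses cannot be misrouted between requesters.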
@@ -136,19 +164,18 @@ def act(i, device, batch_queues, model, flags, onnx_frame):

         position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}

-        position, obs, env_output = env.initial(model, device, flags=flags)
-        last_onnx_frame = -1
+        position, obs, env_output = env.initial(flags=flags)
         while True:
-            # print("posi", position)
-            if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
-                last_onnx_frame = onnx_frame.value
-                model.set_onnx_model(device)
-
             while True:
                 if len(obs['legal_actions']) > 1:
-                    with torch.no_grad():
-                        agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
-                    _action_idx = int(agent_output['action'])
+                    infer_queue['input'].put({
+                        'position': position,
+                        'z_batch': obs['z_batch'],
+                        'x_batch': obs['x_batch']
+                    })
+                    result = infer_queue['output'].get()
+                    action = np.argmax(result['values'], axis=0)[0]
+                    _action_idx = int(action)
                     action = obs['legal_actions'][_action_idx]
                 else:
                     action = obs['legal_actions'][0]
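Note the action decode on the actor side: `np.argmax(result['values'], axis=0)[0]` assumes the values arrive as a column of shape (num_legal_actions, 1). A one-line check (illustrative):

    import numpy as np

    values = np.array([[0.1], [0.7], [0.3]])  # one score per legal action
    print(np.argmax(values, axis=0)[0])       # 1 -> index of the best action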
@@ -162,7 +189,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
                 x_batch = env_output['obs_x_no_action'].float()
                 obs_x_batch_buf[position].append(x_batch)
                 type_buf[position].append(position_index[position])
-                position, obs, env_output = env.step(action, model, device, flags=flags)
+                position, obs, env_output = env.step(action, flags=flags)
                 size[position] += 1
                 if env_output['done']:
                     for p in positions:
@@ -216,6 +243,18 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
         print()
         raise e

+def act(i, infer_queues, batch_queues, flags):
+    threads = []
+    for x in range(len(infer_queues)):
+        thread = threading.Thread(
+            target=act_queue, name='act_queue-%d-%d' % (i, x),
+            args=(x, infer_queues[x], batch_queues, flags))
+        thread.setDaemon(True)
+        thread.start()
+        threads.append(thread)
+    for thread in threads:
+        thread.join()
+
 def _cards2tensor(list_cards, compress_form = False):
     """
     Convert a list of integers to the tensor
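Taken together with `infer_logic` above, the topology under the new defaults is: 3 actor processes x 4 act_queue threads = 12 concurrent games, each owning one input/output queue pair, all served by 2 inference processes per GPU listed in `--infer_devices` (one per CPU entry). The final `thread.join()` keeps each actor process alive for as long as its game threads run.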
@@ -91,7 +91,7 @@ class Env:
         self.total_round = 0
         self.infoset = None

-    def reset(self, model, device, flags=None):
+    def reset(self, flags=None):
         """
         Every time reset is called, the environment
         will be re-initialized with a new deck of cards.
@@ -100,21 +100,6 @@ class Env:
         self._env.reset()

         # Randomly shuffle the deck
-        if model is None:
-            _deck = deck.copy()
-            np.random.shuffle(_deck)
-            card_play_data = {'landlord': _deck[:33],
-                              'landlord_up': _deck[33:58],
-                              'landlord_front': _deck[58:83],
-                              'landlord_down': _deck[83:108],
-                              # 'three_landlord_cards': _deck[17:20],
-                              }
-            for key in card_play_data:
-                card_play_data[key].sort()
-            self._env.card_play_init(card_play_data)
-            self.infoset = self._game_infoset
-            return get_obs(self.infoset, self.use_general, self.lite_model)
-        else:
         self.total_round += 1
         _deck = deck.copy()
         np.random.shuffle(_deck)