Standalone infer process

ZaneYork 2022-01-01 17:39:36 +08:00
parent b40eed2acc
commit d8b5bb71e6
4 changed files with 98 additions and 74 deletions

View File

@@ -17,6 +17,8 @@ parser.add_argument('--actor_device_cpu', action='store_true',
                     help='Use CPU as actor device')
 parser.add_argument('--gpu_devices', default='0', type=str,
                     help='Which GPUs to be used for training')
+parser.add_argument('--num_infer', default=2, type=int,
+                    help='The number of process used for infer')
 parser.add_argument('--num_actor_devices', default=1, type=int,
                     help='The number of devices used for simulation')
 parser.add_argument('--num_actors', default=2, type=int,
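For reference, a self-contained sketch of how the new flag is consumed once parsed (only the two relevant arguments are re-declared here as stand-ins; the repository's real parser defines many more, and the defaults below are copied from this hunk):

    import argparse

    # Illustrative stand-in for the parser defined in this file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_infer', default=2, type=int,
                        help='The number of process used for infer')
    parser.add_argument('--num_actors', default=2, type=int,
                        help='Illustrative help text, not the original string')

    flags = parser.parse_args(['--num_infer', '4'])
    print(flags.num_infer, flags.num_actors)  # -> 4 2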

View File

@@ -14,7 +14,7 @@ import douzero.dmc.models
 import douzero.env.env
 from .file_writer import FileWriter
 from .models import Model, OldModel
-from .utils import get_batch, log, create_env, create_optimizers, act
+from .utils import get_batch, log, create_env, create_optimizers, act, infer_logic
 import psutil
 import shutil
 import requests
@@ -25,7 +25,7 @@ def compute_loss(logits, targets):
     loss = ((logits.squeeze(-1) - targets)**2).mean()
     return loss

-def learn(position, actor_models, model, batch, optimizer, flags, lock):
+def learn(position, actor_model, model, batch, optimizer, flags, lock):
     """Performs a learning (optimization) step."""
     position_index = {"landlord": 31, "landlord_up": 32, 'landlord_front': 33, "landlord_down": 34}
     print("Learn", position)
@@ -60,7 +60,6 @@ def learn(position, actor_models, model, batch, optimizer, flags, lock):
     optimizer.step()

     if not flags.enable_onnx:
-        for actor_model in actor_models.values():
-            actor_model.get_model(position).load_state_dict(model.state_dict())
+        actor_model.get_model(position).load_state_dict(model.state_dict())

     return stats
@@ -85,22 +84,12 @@ def train(flags):
     T = flags.unroll_length
     B = flags.batch_size

-    if flags.actor_device_cpu:
-        device_iterator = ['cpu']
-    else:
-        device_iterator = range(flags.num_actor_devices) #[0, 'cpu']
-        assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), 'The number of actor devices can not exceed the number of available devices'
-
     # Initialize actor models
-    models = {}
-    for device in device_iterator:
-        if flags.old_model:
-            model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
-        else:
-            model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
-        model.share_memory()
-        model.eval()
-        models[device] = model
+    if flags.old_model:
+        actor_model = OldModel(device="cpu", flags = flags, lite_model = flags.lite_model)
+    else:
+        actor_model = Model(device="cpu", flags = flags, lite_model = flags.lite_model)
+    actor_model.eval()

     # Initialize queues
     actor_processes = []
@@ -114,9 +103,6 @@ def train(flags):
     else:
         learner_model = Model(device=flags.training_device, lite_model = flags.lite_model)

-    # Create optimizers
-    optimizers = create_optimizers(flags, learner_model)
-
     # Stat Keys
     stat_keys = [
         'mean_episode_return_landlord',
@@ -155,6 +141,9 @@ def train(flags):
         )
         onnx_frame.value = frames

+    # Create optimizers
+    optimizers = create_optimizers(flags, learner_model)
+
     # Load models if any
     if flags.load_model and os.path.exists(checkpointpath):
         checkpoint_states = torch.load(
@@ -164,8 +153,7 @@
             learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
             optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
             if not flags.enable_onnx:
-                for device in device_iterator:
-                    models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
+                actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
         stats = checkpoint_states["stats"]
         frames = checkpoint_states["frames"]
@@ -173,20 +161,35 @@
             sync_onnx_model(frames)
         log.info(f"Resuming preempted job, current stats:\n{stats}")

-    # Starting actor processes
-    for device in device_iterator:
-        if device == 'cpu':
-            num_actors = flags.num_actors_cpu
-        else:
-            num_actors = flags.num_actors
-        for i in range(num_actors):
-            actor = mp.Process(
-                target=act,
-                args=(i, device, batch_queues, models[device], flags, onnx_frame))
-            actor.daemon = True
-            actor.start()
-            actor_processes.append({
-                'device': device,
-                'i': i,
-                'actor': actor
-            })
+    infer_queues = []
+    num_actors = flags.num_actors
+    for i in range(num_actors):
+        infer_queues.append({
+            'input': ctx.Queue(), 'output': ctx.Queue()
+        })
+    infer_processes = []
+    for device in ['0']:
+        for i in range(flags.num_infer):
+            infer = mp.Process(
+                target=infer_logic,
+                args=(i, device, infer_queues, actor_model, flags, onnx_frame))
+            infer.daemon = True
+            infer.start()
+            infer_processes.append({
+                'device': device,
+                'i': i,
+                'infer': infer
+            })
+
+    # Starting actor processes
+    for i in range(num_actors):
+        actor = mp.Process(
+            target=act,
+            args=(i, infer_queues[i]['input'], infer_queues[i]['output'], batch_queues, actor_model, flags))
+        actor.daemon = True
+        actor.start()
+        actor_processes.append({
+            'i': i,
+            'actor': actor
+        })
@@ -201,7 +204,7 @@ def train(flags):
         nonlocal frames, position_frames, stats
         while frames < flags.total_frames:
             batch = get_batch(batch_queues, position, flags, local_lock)
-            _stats = learn(position, models, learner_model.get_model(position), batch,
+            _stats = learn(position, actor_model, learner_model.get_model(position), batch,
                            optimizers[position], flags, position_lock)
             with lock:
                 for k in _stats:
@@ -215,13 +218,12 @@
     threads = []
     locks = {}
-    for device in device_iterator:
-        locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
+    locks['cpu'] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}

     for i in range(flags.num_threads):
         for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
             thread = threading.Thread(
-                target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,position,locks[device][position],position_locks[position]))
+                target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,position, locks['cpu'][position],position_locks[position]))
             thread.setDaemon(True)
             thread.start()
             threads.append(thread)
@@ -303,14 +305,14 @@ def train(flags):
                      position_fps['landlord_front'],
                      position_fps['landlord_down'],
                      pprint.pformat(stats))
-            for proc in actor_processes:
-                if not proc['actor'].is_alive():
-                    actor = mp.Process(
-                        target=act,
-                        args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame))
-                    actor.daemon = True
-                    actor.start()
-                    proc['actor'] = actor
+            # for proc in actor_processes:
+            #     if not proc['actor'].is_alive():
+            #         actor = mp.Process(
+            #             target=act,
+            #             args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame))
+            #         actor.daemon = True
+            #         actor.start()
+            #         proc['actor'] = actor
     except KeyboardInterrupt:
         flags.enable_upload = False
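Taken together, the hunks above change the process topology: train() now builds a single CPU-side actor_model, creates one input/output queue pair per actor, and launches flags.num_infer standalone inference processes that serve every actor's queue. A minimal, self-contained sketch of that wiring (the worker bodies are placeholders, not the real infer_logic/act; ctx mirrors the multiprocessing context already used in dmc.py):

    import multiprocessing as mp
    import queue
    import time

    def infer_worker(worker_id, queues):
        # Poll every actor's input queue; answer on the matching output queue.
        while True:
            idle = True
            for q in queues:
                try:
                    task = q['input'].get_nowait()
                except queue.Empty:
                    continue
                q['output'].put({'action': 0})  # placeholder "chosen action"
                idle = False
            if idle:
                time.sleep(0.01)

    def actor_worker(actor_id, q):
        q['input'].put({'position': 'landlord'})   # placeholder request
        print('actor', actor_id, 'got', q['output'].get())

    if __name__ == '__main__':
        ctx = mp.get_context('spawn')
        num_actors, num_infer = 2, 1                # stand-ins for the flags
        queues = [{'input': ctx.Queue(), 'output': ctx.Queue()} for _ in range(num_actors)]
        for i in range(num_infer):
            ctx.Process(target=infer_worker, args=(i, queues), daemon=True).start()
        actors = [ctx.Process(target=actor_worker, args=(i, queues[i])) for i in range(num_actors)]
        for p in actors:
            p.start()
        for p in actors:
            p.join()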

View File

@@ -6,21 +6,14 @@ the environment, we do it automatically.
 import numpy as np
 import torch

-def _format_observation(obs, device, flags):
+def _format_observation(obs):
     """
     A utility function to process observations and
     move them to CUDA.
     """
     position = obs['position']
-    if flags.enable_onnx:
-        x_batch = obs['x_batch']
-        z_batch = obs['z_batch']
-    else:
-        if not device == "cpu":
-            device = 'cuda:' + str(device)
-        device = torch.device(device)
-        x_batch = torch.from_numpy(obs['x_batch']).to(device)
-        z_batch = torch.from_numpy(obs['z_batch']).to(device)
+    x_batch = obs['x_batch']
+    z_batch = obs['z_batch']
     x_no_action = torch.from_numpy(obs['x_no_action'])
     z = torch.from_numpy(obs['z'])
     obs = {'x_batch': x_batch,
@@ -39,7 +32,7 @@ class Environment:
     def initial(self, model, device, flags=None):
         obs = self.env.reset(model, device, flags=flags)
-        initial_position, initial_obs, x_no_action, z = _format_observation(obs, self.device, flags)
+        initial_position, initial_obs, x_no_action, z = _format_observation(obs)
         self.episode_return = torch.zeros(1, 1)
         initial_done = torch.ones(1, 1, dtype=torch.bool)
         return initial_position, initial_obs, dict(
@@ -58,7 +51,7 @@ class Environment:
             obs = self.env.reset(model, device, flags=flags)
             self.episode_return = torch.zeros(1, 1)

-        position, obs, x_no_action, z = _format_observation(obs, self.device, flags)
+        position, obs, x_no_action, z = _format_observation(obs)
         # reward = torch.tensor(reward).view(1, 1)
         done = torch.tensor(done).view(1, 1)

View File

@@ -1,4 +1,5 @@
 import os
+import queue
 import typing
 import logging
 import traceback
@@ -111,16 +112,44 @@ def create_optimizers(flags, learner_model):
     return optimizers

-def act(i, device, batch_queues, model, flags, onnx_frame):
+def infer_logic(i, device, infer_queues, model, flags, onnx_frame):
     positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
     if not flags.enable_onnx:
         for pos in positions:
             model.models[pos].to(torch.device(device if device == "cpu" else ("cuda:"+str(device))))
+    last_onnx_frame = -1
+    log.info('Infer %i started.', i)
+    while True:
+        # print("posi", position)
+        if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
+            last_onnx_frame = onnx_frame.value
+            model.set_onnx_model(device)
+        all_empty = True
+        for infer_queue in infer_queues:
+            try:
+                task = infer_queue['input'].get_nowait()
+                with torch.no_grad():
+                    agent_output = model.forward(task['position'], task['z_batch'], task['x_batch'], flags=flags)
+                _action_idx = int(agent_output['action'])
+                infer_queue['output'].put({
+                    'action': _action_idx
+                })
+                all_empty = False
+            except queue.Empty:
+                pass
+        if all_empty:
+            time.sleep(0.01)
+
+def act(i, input_queue, output_queue, batch_queues, actor_model, flags):
+    positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
     try:
         T = flags.unroll_length
-        log.info('Device %s Actor %i started.', str(device), i)
+        log.info('Actor %i started.', i)
         env = create_env(flags)
+        device = 'cpu'
         env = Environment(env, device)

         done_buf = {p: [] for p in positions}
@@ -136,18 +165,16 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
         position_index = {"landlord": 31, "landlord_up": 32, "landlord_front": 33, "landlord_down": 34}

-        position, obs, env_output = env.initial(model, device, flags=flags)
-        last_onnx_frame = -1
+        position, obs, env_output = env.initial(actor_model, device, flags=flags)
         while True:
-            # print("posi", position)
-            if flags.enable_onnx and onnx_frame.value != last_onnx_frame:
-                last_onnx_frame = onnx_frame.value
-                model.set_onnx_model(device)
-
             while True:
                 if len(obs['legal_actions']) > 1:
-                    with torch.no_grad():
-                        agent_output = model.forward(position, obs['z_batch'], obs['x_batch'], flags=flags)
+                    input_queue.put({
+                        'position': position,
+                        'z_batch': obs['z_batch'],
+                        'x_batch': obs['x_batch']
+                    })
+                    agent_output = output_queue.get()
                     _action_idx = int(agent_output['action'])
                     action = obs['legal_actions'][_action_idx]
                 else:
@@ -162,7 +189,7 @@ def act(i, device, batch_queues, model, flags, onnx_frame):
                 x_batch = env_output['obs_x_no_action'].float()
                 obs_x_batch_buf[position].append(x_batch)
                 type_buf[position].append(position_index[position])
-                position, obs, env_output = env.step(action, model, device, flags=flags)
+                position, obs, env_output = env.step(action, actor_model, device, flags=flags)
                 size[position] += 1
                 if env_output['done']:
                     for p in positions:
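In short, the actor no longer runs the network itself: whenever more than one legal action is available it enqueues {'position', 'z_batch', 'x_batch'} on its input queue and blocks on output_queue.get() until one of the infer processes answers with {'action': index}. A hedged sketch of that actor-side handshake in isolation (choose_action is a hypothetical helper, not part of this commit):

    def choose_action(obs, position, input_queue, output_queue):
        # With a single legal action there is nothing to predict.
        if len(obs['legal_actions']) == 1:
            return obs['legal_actions'][0]
        # Delegate the forward pass to whichever infer process picks this up.
        input_queue.put({
            'position': position,
            'z_batch': obs['z_batch'],
            'x_batch': obs['x_batch'],
        })
        reply = output_queue.get()   # blocks until infer_logic replies
        return obs['legal_actions'][int(reply['action'])]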