import os
import pprint
import shutil
import threading
import time
import timeit
from collections import deque

import numpy as np
import psutil
import torch
from torch import multiprocessing as mp
from torch import nn

import douzero.dmc.models
import douzero.env.env
from .file_writer import FileWriter
from .models import Model, OldModel
from .utils import get_batch, log, create_env, create_optimizers, act

# Player positions, in the order used by every per-position table in this module.
_POSITIONS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')

# Rolling window (last 100 entries) of per-batch mean episode returns, per position.
mean_episode_return_buf = {p: deque(maxlen=100) for p in _POSITIONS}


def _training_device_name(flags):
    """Return the learner's torch device string: 'cpu' or 'cuda:<training_device>'."""
    if flags.training_device != "cpu":
        # str() so that an integer device id works as well as a string one.
        return 'cuda:' + str(flags.training_device)
    return 'cpu'


def compute_loss(logits, targets):
    """Return the mean squared error between `logits` (last dim squeezed) and `targets`."""
    loss = ((logits.squeeze(-1) - targets) ** 2).mean()
    return loss


def learn(position, actor_models, model, batch, optimizer, flags, lock):
    """Performs a learning (optimization) step.

    Args:
        position: One of the position names; selects the return buffer and
            the `obs_type` code used to filter finished episodes.
        actor_models: Mapping of actor device -> actor model; each receives a
            copy of the updated weights unless `flags.enable_onnx` is set.
        model: The learner network for `position`.
        batch: Mapping of tensors produced by `get_batch`.
        optimizer: Optimizer bound to `model`'s parameters.
        flags: Run configuration (`training_device`, `old_model`,
            `max_grad_norm`, `enable_onnx` are read here).
        lock: Lock held for the forward/backward pass and the weight copy.

    Returns:
        Dict with `mean_episode_return_<position>` and `loss_<position>`.
    """
    # `obs_type` code that marks batch entries belonging to each position.
    position_index = {"landlord": 31, "landlord_up": 32, 'landlord_front': 33, "landlord_down": 34}
    print("Learn", position)
    device = torch.device(_training_device_name(flags))

    if flags.old_model:
        obs_x_no_action = batch['obs_x_no_action'].to(device)
        obs_action = batch['obs_x_batch'].to(device)
        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
        obs_x = torch.flatten(obs_x, 0, 1)
    else:
        obs_x = batch["obs_x_batch"]
        obs_x = torch.flatten(obs_x, 0, 1).to(device)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)

    # Returns of the episodes that finished in this batch for this position.
    episode_returns = batch['episode_return'][batch['done'] & (batch["obs_type"] == position_index[position])]
    return_buf = mean_episode_return_buf[position]
    if len(episode_returns) > 0:
        return_buf.append(torch.mean(episode_returns).to(device))

    with lock:
        learner_outputs = model(obs_z, obs_x)
        loss = compute_loss(learner_outputs['values'], target)
        if return_buf:
            mean_return = torch.mean(torch.stack(list(return_buf))).item()
        else:
            # torch.stack rejects an empty list; until the first episode for
            # this position finishes, report 0.0 (the initial stats value).
            mean_return = 0.0
        stats = {
            'mean_episode_return_' + position: mean_return,
            'loss_' + position: loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        if not flags.enable_onnx:
            # Push the fresh weights to every actor model.
            for actor_model in actor_models.values():
                actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats


def train(flags):
    """
    This is the main function for training. It will
    first initialize everything, such as buffers,
    optimizers, etc. Then it will start subprocesses
    as actors. Then, it will call the learning function
    with multiple threads.
    """
    if not flags.actor_device_cpu or flags.training_device != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    if flags.actor_device_cpu:
        device_iterator = ['cpu']
    else:
        device_iterator = range(flags.num_actor_devices)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if flags.num_actor_devices > len(flags.gpu_devices.split(',')):
            raise AssertionError('The number of actor devices can not exceed the number of available devices')

    # Initialize actor models
    models = {}
    for device in device_iterator:
        if flags.old_model:
            model = OldModel(device="cpu", flags=flags, lite_model=flags.lite_model)
        else:
            model = Model(device="cpu", flags=flags, lite_model=flags.lite_model)
        model.share_memory()
        model.eval()
        models[device] = model

    # Initialize queues
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {p: ctx.SimpleQueue() for p in _POSITIONS}
    # Frame count of the last ONNX export, shared with the actor processes.
    onnx_frame = ctx.Value('d', -1)

    # Learner model for training
    if flags.old_model:
        learner_model = OldModel(device=flags.training_device, lite_model=flags.lite_model)
    else:
        learner_model = Model(device=flags.training_device, lite_model=flags.lite_model)

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Stat Keys
    stat_keys = [
        'mean_episode_return_landlord',
        'loss_landlord',
        'mean_episode_return_landlord_up',
        'loss_landlord_up',
        'mean_episode_return_landlord_front',
        'loss_landlord_front',
        'mean_episode_return_landlord_down',
        'loss_landlord_down'
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {p: 0 for p in _POSITIONS}
    # One lock per position, guarding the learner model during optimization
    # steps and ONNX export.
    position_locks = {p: threading.Lock() for p in _POSITIONS}

    def sync_onnx_model(frames):
        """Export each learner model to ONNX (if enabled) and publish `frames`."""
        p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
        os.makedirs(p_path, exist_ok=True)
        if flags.enable_onnx:
            export_device = torch.device(_training_device_name(flags))
            for position in _POSITIONS:
                model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
                onnx_params = learner_model.get_model(position).get_onnx_params(export_device)
                with position_locks[position]:
                    torch.onnx.export(
                        learner_model.get_model(position),
                        onnx_params['args'],
                        model_path,
                        export_params=True,
                        opset_version=10,
                        do_constant_folding=True,
                        input_names=onnx_params['input_names'],
                        output_names=onnx_params['output_names'],
                        dynamic_axes=onnx_params['dynamic_axes']
                    )
        onnx_frame.value = frames

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=_training_device_name(flags)
        )
        for k in _POSITIONS:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            if not flags.enable_onnx:
                for device in device_iterator:
                    models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
        stats = checkpoint_states["stats"]
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        sync_onnx_model(frames)
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Starting actor processes
    for device in device_iterator:
        if device == 'cpu':
            num_actors = flags.num_actors_cpu
        else:
            num_actors = flags.num_actors
        for i in range(num_actors):
            actor = mp.Process(
                target=act,
                args=(i, device, batch_queues, models[device], flags, onnx_frame))
            actor.daemon = True
            actor.start()
            # Remember how the actor was launched so it can be restarted.
            actor_processes.append({
                'device': device,
                'i': i,
                'actor': actor
            })

    parent = psutil.Process()
    # psutil only defines the *_PRIORITY_CLASS constants on Windows; skip the
    # priority tweak elsewhere instead of failing with AttributeError.
    if hasattr(psutil, 'NORMAL_PRIORITY_CLASS'):
        parent.nice(psutil.NORMAL_PRIORITY_CLASS)
        for child in parent.children():
            child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)

    # Single lock shared by all learner threads for the stats/frame counters.
    stats_lock = threading.Lock()

    def batch_and_learn(i, position, local_lock, position_lock, lock=stats_lock):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, models, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)
            with lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    threads = []
    locks = {}
    for device in device_iterator:
        locks[device] = {p: threading.Lock() for p in _POSITIONS}

    # NOTE(review): `device` below is the value left over from the loop above
    # (the last actor device), so all learner threads of a position share that
    # one lock. This assumes the spawn loop is not nested per device - confirm.
    for i in range(flags.num_threads):
        for position in _POSITIONS:
            thread = threading.Thread(
                target=batch_and_learn,
                name='batch-and-learn-%d' % i,
                args=(i, position, locks[device][position], position_locks[position]),
                daemon=True)
            thread.start()
            threads.append(thread)

    def checkpoint(frames):
        """Save the resumable training state plus per-position weight snapshots."""
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        # Write to a temporary name first; it is moved into place at the end.
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath + '.new')

        # Save the weights for evaluation purpose
        for position in _POSITIONS:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
        shutil.move(checkpointpath + '.new', checkpointpath)

    # Keeps the most recent 240 fps samples for the running average.
    fps_log = deque(maxlen=240)
    timer = timeit.default_timer
    try:
        # Back-dated so the first loop iteration writes a checkpoint.
        last_checkpoint_time = timer() - flags.save_interval * 60
        last_onnx_sync_time = timer()
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = {k: position_frames[k] for k in position_frames}
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()
            if timer() - last_onnx_sync_time > flags.onnx_sync_interval:
                sync_onnx_model(frames)
                last_onnx_sync_time = timer()

            end_time = timer()
            elapsed = end_time - start_time
            fps = (frames - start_frames) / elapsed
            fps_log.append(fps)
            fps_avg = np.mean(fps_log)
            position_fps = {k: (position_frames[k] - position_start_frames[k]) / elapsed
                            for k in position_frames}
            log.info('After %i (L:%i U:%i F:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f F:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_front'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_front'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

            # Restart any actor that has died, with the model of ITS device.
            for proc in actor_processes:
                if not proc['actor'].is_alive():
                    actor = mp.Process(
                        target=act,
                        args=(proc['i'], proc['device'], batch_queues,
                              models[proc['device']], flags, onnx_frame))
                    actor.daemon = True
                    actor.start()
                    proc['actor'] = actor
    except KeyboardInterrupt:
        checkpoint(frames)
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()