import os
import threading
import time
import timeit
import pprint
from collections import deque
import numpy as np

import torch
from torch import multiprocessing as mp
from torch import nn

import douzero.dmc.models
import douzero.env.env
from .file_writer import FileWriter
from .models import Model, OldModel
from .utils import get_batch, log, create_env, create_optimizers, act, infer_logic
import psutil
import shutil
import requests

# The four seats handled throughout this module, in a fixed order.
POSITIONS = ('landlord', 'landlord_up', 'landlord_front', 'landlord_down')

# Value of batch["obs_type"] that `learn` matches against for each position.
# NOTE(review): presumably these codes are assigned by the environment /
# actor code -- confirm against the producer of "obs_type".
_POSITION_OBS_TYPE = {
    'landlord': 31,
    'landlord_up': 32,
    'landlord_front': 33,
    'landlord_down': 34,
}

# Rolling window (last 100 entries) of per-batch mean episode returns,
# kept separately for every position and shared by all learner threads.
mean_episode_return_buf = {p: deque(maxlen=100) for p in POSITIONS}


def _training_device_name(flags):
    """Return the torch device string selected by ``flags.training_device``.

    ``"cpu"`` is returned unchanged; any other value is treated as a CUDA
    device index and rendered as ``"cuda:<index>"``.
    """
    if flags.training_device != 'cpu':
        return 'cuda:' + str(flags.training_device)
    return 'cpu'


def compute_loss(logits, targets):
    """Return the mean squared error between ``logits`` and ``targets``.

    The trailing dimension of ``logits`` is squeezed before the comparison.
    """
    loss = ((logits.squeeze(-1) - targets) ** 2).mean()
    return loss


def learn(position, actor_model, model, batch, optimizer, flags, lock):
    """Performs a learning (optimization) step.

    Args:
        position: One of ``POSITIONS``; selects the return buffer and the
            actor sub-model that receives the updated weights.
        actor_model: Container of actor models; refreshed from ``model``
            after the step unless ``flags.enable_onnx`` is set.
        model: The learner network for ``position``; called as
            ``model(obs_z, obs_x)`` and expected to return a mapping with a
            ``'values'`` entry.
        batch: Mapping of tensors. The first two dimensions of the
            observation/target tensors are flattened together
            (presumably unroll length x batch size -- TODO confirm).
        optimizer: Optimizer that owns ``model``'s parameters.
        flags: Run configuration (``training_device``, ``old_model``,
            ``max_grad_norm``, ``enable_onnx``).
        lock: Lock held for the whole forward/backward/update sequence.

    Returns:
        dict with ``'mean_episode_return_<position>'`` and
        ``'loss_<position>'`` as Python floats.
    """
    print("Learn", position)
    device = torch.device(_training_device_name(flags))

    if flags.old_model:
        # The old model takes the state and action features concatenated
        # along dim 2.
        obs_x_no_action = batch['obs_x_no_action'].to(device)
        obs_action = batch['obs_x_batch'].to(device)
        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
        obs_x = torch.flatten(obs_x, 0, 1)
    else:
        obs_x = batch["obs_x_batch"]
        obs_x = torch.flatten(obs_x, 0, 1).to(device)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)

    # Only finished steps whose obs_type matches this position contribute.
    finished = batch['done'] & (batch["obs_type"] == _POSITION_OBS_TYPE[position])
    episode_returns = batch['episode_return'][finished]
    if len(episode_returns) > 0:
        mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))

    with lock:
        learner_outputs = model(obs_z, obs_x)
        loss = compute_loss(learner_outputs['values'], target)
        # NOTE(review): torch.stack raises on an empty list, so this line
        # fails if no episode has been recorded for `position` yet.
        stats = {
            'mean_episode_return_' + position: torch.mean(
                torch.stack(list(mean_episode_return_buf[position]))).item(),
            'loss_' + position: loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        if not flags.enable_onnx:
            # Without ONNX export the actors read the weights directly.
            actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats


def train(flags):
    """
    This is the main function for training. It will first
    initialize everything, such as buffers, optimizers, etc.
    Then it will start subprocesses as actors. Then, it will call
    the learning function with multiple threads.

    Args:
        flags: Parsed run configuration. It is read throughout and mutated
            in one place: ``enable_upload`` is cleared before the final
            checkpoint on ``KeyboardInterrupt``.

    Raises:
        AssertionError: if a non-CPU device is requested but CUDA is not
            available.
    """
    if flags.training_device != 'cpu' or flags.infer_devices != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError(
                "CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. "
                "Otherwise, please train with CPU with "
                "`python3 train.py --infer_devices cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    # Initialize actor models
    if flags.old_model:
        actor_model = OldModel(device="cpu", flags=flags, lite_model=flags.lite_model)
    else:
        actor_model = Model(device="cpu", flags=flags, lite_model=flags.lite_model)
    actor_model.eval()

    # Initialize queues
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {p: ctx.SimpleQueue() for p in POSITIONS}
    # Shared double written by sync_onnx_model and handed to the inference
    # processes; starts at -1.
    onnx_frame = ctx.Value('d', -1)

    # Learner model for training
    if flags.old_model:
        learner_model = OldModel(device=flags.training_device, lite_model=flags.lite_model)
    else:
        learner_model = Model(device=flags.training_device, lite_model=flags.lite_model)

    # Stat Keys
    stat_keys = []
    for p in POSITIONS:
        stat_keys.append('mean_episode_return_' + p)
        stat_keys.append('loss_' + p)
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {p: 0 for p in POSITIONS}
    # One lock per position, shared by learn() and the ONNX export so a
    # model is never exported in the middle of an optimizer step.
    position_locks = {p: threading.Lock() for p in POSITIONS}

    def sync_onnx_model(frames):
        """Export every learner model to ONNX (if enabled) and publish ``frames``."""
        p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
        os.makedirs(p_path, exist_ok=True)
        for position in POSITIONS:
            if flags.enable_onnx:
                model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
                onnx_params = learner_model.get_model(position).get_onnx_params(
                    torch.device(_training_device_name(flags)))
                with position_locks[position]:
                    torch.onnx.export(
                        learner_model.get_model(position),
                        onnx_params['args'],
                        model_path,
                        export_params=True,
                        opset_version=10,
                        do_constant_folding=True,
                        input_names=onnx_params['input_names'],
                        output_names=onnx_params['output_names'],
                        dynamic_axes=onnx_params['dynamic_axes']
                    )
        onnx_frame.value = frames

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=_training_device_name(flags)
        )
        for k in POSITIONS:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            if not flags.enable_onnx:
                actor_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
        stats = checkpoint_states["stats"]

        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        sync_onnx_model(frames)
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    num_actors = flags.num_actors
    infer_queues = [
        {'input': ctx.Queue(maxsize=100), 'output': ctx.Queue(maxsize=100)}
        for _ in range(flags.num_actors_thread * num_actors)
    ]

    def start_infer(index, device):
        """Start one daemon inference process and return its handle."""
        infer = mp.Process(
            target=infer_logic,
            args=(index, device, infer_queues, actor_model, flags, onnx_frame))
        infer.daemon = True
        infer.start()
        return infer

    def start_actor(index):
        """Start one daemon actor process and return its handle."""
        # NOTE(review): the hard-coded slice width of 4 only lines up with
        # the number of queues created above when
        # flags.num_actors_thread == 4 -- confirm.
        actor = mp.Process(
            target=act,
            args=(index, infer_queues[index * 4: (index + 1) * 4], batch_queues, flags))
        actor.daemon = True
        actor.start()
        return actor

    infer_processes = []
    for device in flags.infer_devices.split(','):
        for i in range(flags.num_infer if device != 'cpu' else 1):
            infer_processes.append({
                'device': device,
                'i': i,
                'infer': start_infer(i, device)
            })

    # Starting actor processes
    for i in range(num_actors):
        actor_processes.append({
            'i': i,
            'actor': start_actor(i)
        })

    # Keep the learner at normal priority and push the workers below it.
    # These psutil priority-class constants exist on Windows only.
    parent = psutil.Process()
    parent.nice(psutil.NORMAL_PRIORITY_CLASS)
    for child in parent.children():
        child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)

    # Single lock guarding the shared counters/stats across all learner threads.
    stats_lock = threading.Lock()

    def batch_and_learn(i, position, local_lock, position_lock, lock=stats_lock):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, actor_model, learner_model.get_model(position), batch,
                           optimizers[position], flags, position_lock)
            with lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    threads = []
    # Per-position lock passed to get_batch for every thread of that position.
    batch_locks = {p: threading.Lock() for p in POSITIONS}

    for i in range(flags.num_threads):
        for position in POSITIONS:
            thread = threading.Thread(
                target=batch_and_learn,
                name='batch-and-learn-%d' % i,
                args=(i, position, batch_locks[position], position_locks[position]),
                daemon=True)
            thread.start()
            threads.append(thread)

    def checkpoint(frames):
        """Write the full training state and per-position weight snapshots."""
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        # Write to a temporary name first; it is moved over the real
        # checkpoint only after everything below has succeeded.
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            "stats": stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath + '.new')

        # Save the weights for evaluation purpose
        for position in POSITIONS:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid,
                              "general_" + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
            if flags.enable_upload:
                model_type = 'lite_' if flags.lite_model else ''
                model_type += 'vanilla' if flags.old_model else 'resnet'
                # Close the handle before deleting the file; an open handle
                # makes os.remove fail on Windows.
                with open(model_weights_dir, 'rb') as model_file:
                    requests.post(flags.upload_url, data={
                        'type': model_type,
                        'position': position,
                        'frame': frames
                    }, files={'model_file': ('model.ckpt', model_file)})
                os.remove(model_weights_dir)
        shutil.move(checkpointpath + '.new', checkpointpath)

    # Sliding window of the most recent 240 fps samples.
    fps_log = deque(maxlen=240)
    timer = timeit.default_timer
    try:
        # Back-date the last checkpoint so the first loop iteration saves one.
        last_checkpoint_time = timer() - flags.save_interval * 60
        last_onnx_sync_time = timer()
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = {k: position_frames[k] for k in position_frames}
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()

            if timer() - last_onnx_sync_time > flags.onnx_sync_interval:
                sync_onnx_model(frames)
                last_onnx_sync_time = timer()

            end_time = timer()
            elapsed = end_time - start_time

            fps = (frames - start_frames) / elapsed
            fps_log.append(fps)
            fps_avg = np.mean(fps_log)

            position_fps = {k: (position_frames[k] - position_start_frames[k]) / elapsed
                            for k in position_frames}
            log.info('After %i (L:%i U:%i F:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f F:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_front'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_front'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

            # Restart any worker process that has died.
            for proc in actor_processes:
                if not proc['actor'].is_alive():
                    proc['actor'] = start_actor(proc['i'])
            for proc in infer_processes:
                if not proc['infer'].is_alive():
                    proc['infer'] = start_infer(proc['i'], proc['device'])

    except KeyboardInterrupt:
        # Save locally on interrupt, but skip the upload step.
        flags.enable_upload = False
        checkpoint(frames)
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()