import os
import shutil
import threading
import time
import timeit
import pprint
from collections import deque

import numpy as np
import psutil
import torch
from torch import multiprocessing as mp
from torch import nn

import douzero.dmc.models
import douzero.env.env
from .file_writer import FileWriter
from .models import Model, OldModel
from .utils import get_batch, log, create_env, create_optimizers, act

mean_episode_return_buf = {p: deque(maxlen=100) for p in
                           ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']}

# Shared with every actor process; updated whenever a new ONNX model is exported.
onnx_frame = mp.Value('d', -1)


def compute_loss(logits, targets):
    """Mean-squared error between value predictions and sampled returns."""
    loss = ((logits.squeeze(-1) - targets)**2).mean()
    return loss


def compute_loss_for_bid(outputs, reward):
    # Not implemented yet; the bidding position currently reuses compute_loss.
    pass


def learn(position, actor_models, model, batch, optimizer, flags, lock):
    """Performs a learning (optimization) step."""
    position_index = {'landlord': 31, 'landlord_up': 32, 'landlord_front': 33, 'landlord_down': 34}
    print("Learn", position)
    if flags.training_device != "cpu":
        device = torch.device('cuda:' + str(flags.training_device))
    else:
        device = torch.device('cpu')
    if flags.old_model and position != 'bidding':
        obs_x_no_action = batch['obs_x_no_action'].to(device)
        obs_action = batch['obs_action'].to(device)
        obs_x = torch.cat((obs_x_no_action, obs_action), dim=2).float()
        obs_x = torch.flatten(obs_x, 0, 1)
    else:
        obs_x = batch['obs_x_batch']
        obs_x = torch.flatten(obs_x, 0, 1).to(device)
    obs_z = torch.flatten(batch['obs_z'].to(device), 0, 1).float()
    target = torch.flatten(batch['target'].to(device), 0, 1)
    if position != "bidding":
        episode_returns = batch['episode_return'][batch['done'] & (batch['obs_type'] == position_index[position])]
    else:
        episode_returns = batch['episode_return'][batch['done'] & ((batch['obs_type'] == 41)
                                                                   | (batch['obs_type'] == 42)
                                                                   | (batch['obs_type'] == 43)
                                                                   | (batch['obs_type'] == 44))]
    if len(episode_returns) > 0:
        mean_episode_return_buf[position].append(torch.mean(episode_returns).to(device))

    with lock:
        learner_outputs = model(obs_z, obs_x, return_value=True)
        # The bidding position currently trains with the same MSE loss as the
        # playing positions (compute_loss_for_bid is still a stub).
        loss = compute_loss(learner_outputs['values'], target)
        stats = {
            'mean_episode_return_' + position: torch.mean(
                torch.stack([_r for _r in mean_episode_return_buf[position]])).item(),
            'loss_' + position: loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        # Push the updated weights out to every actor model.
        for actor_model in actor_models.values():
            actor_model.get_model(position).load_state_dict(model.state_dict())
        return stats
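
# A quick illustration of compute_loss (not executed during training): the
# network emits one scalar value per sampled state, and the loss is a plain
# MSE against the return targets gathered by the actors, e.g.
#
#   logits = torch.tensor([[0.5], [1.0]])   # model output, shape (N, 1)
#   targets = torch.tensor([0.0, 1.0])      # sampled returns, shape (N,)
#   compute_loss(logits, targets)           # ((0.5 - 0.0)**2 + 0.0) / 2 = 0.125
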
def train(flags):
    """
    This is the main function for training. It first initializes everything,
    such as buffers and optimizers, then starts subprocesses as actors, and
    finally calls the learning function with multiple threads.
    """
    if not flags.actor_device_cpu or flags.training_device != 'cpu':
        if not torch.cuda.is_available():
            raise AssertionError("CUDA not available. If you have GPUs, please specify the ID after `--gpu_devices`. "
                                 "Otherwise, please train with CPU with `python3 train.py --actor_device_cpu --training_device cpu`")
    plogger = FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    if flags.actor_device_cpu:
        device_iterator = ['cpu']
    else:
        device_iterator = range(flags.num_actor_devices)
        assert flags.num_actor_devices <= len(flags.gpu_devices.split(',')), \
            'The number of actor devices cannot exceed the number of available devices'

    # Initialize actor models
    models = {}
    for device in device_iterator:
        if flags.old_model:
            model = OldModel(device="cpu")
        else:
            model = Model(device="cpu")
        model.share_memory()
        model.eval()
        models[device] = model

    # Initialize queues
    actor_processes = []
    ctx = mp.get_context('spawn')
    batch_queues = {"landlord": ctx.SimpleQueue(),
                    "landlord_up": ctx.SimpleQueue(),
                    "landlord_front": ctx.SimpleQueue(),
                    "landlord_down": ctx.SimpleQueue(),
                    "bidding": ctx.SimpleQueue()}

    # Learner model for training
    if flags.old_model:
        learner_model = OldModel(device=flags.training_device)
    else:
        learner_model = Model(device=flags.training_device)

    # Create optimizers
    optimizers = create_optimizers(flags, learner_model)

    # Stat Keys
    stat_keys = [
        'mean_episode_return_landlord',
        'loss_landlord',
        'mean_episode_return_landlord_up',
        'loss_landlord_up',
        'mean_episode_return_landlord_front',
        'loss_landlord_front',
        'mean_episode_return_landlord_down',
        'loss_landlord_down',
        'mean_episode_return_bidding',
        'loss_bidding',
    ]
    frames, stats = 0, {k: 0 for k in stat_keys}
    position_frames = {'landlord': 0, 'landlord_up': 0, 'landlord_front': 0,
                       'landlord_down': 0, 'bidding': 0}

    # Load models if any
    if flags.load_model and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath,
            map_location=("cuda:" + str(flags.training_device) if flags.training_device != "cpu" else "cpu")
        )
        for k in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
            learner_model.get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
            optimizers[k].load_state_dict(checkpoint_states["optimizer_state_dict"][k])
            for device in device_iterator:
                models[device].get_model(k).load_state_dict(checkpoint_states["model_state_dict"][k])
        stats = checkpoint_states["stats"]
        if 'mean_episode_return_bidding' not in stats:
            stats.update({"mean_episode_return_bidding": 0})
        if 'loss_bidding' not in stats:
            stats.update({"loss_bidding": 0})
        frames = checkpoint_states["frames"]
        position_frames = checkpoint_states["position_frames"]
        if "bidding" not in position_frames:
            position_frames.update({"bidding": 0})
        log.info(f"Resuming preempted job, current stats:\n{stats}")

    # Starting actor processes
    for device in device_iterator:
        if device == 'cpu':
            num_actors = flags.num_actors_cpu
        else:
            num_actors = flags.num_actors
        for i in range(num_actors):
            actor = mp.Process(
                target=act,
                args=(i, device, batch_queues, models[device], flags, onnx_frame))
            actor.daemon = True
            actor.start()
            actor_processes.append(actor)

    # Lower the process priority of the trainer and its actors. Note that
    # BELOW_NORMAL_PRIORITY_CLASS is a Windows-only psutil constant; on POSIX
    # systems a numeric niceness (e.g. parent.nice(10)) would be used instead.
    parent = psutil.Process()
    parent.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
    for child in parent.children():
        child.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS)
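
    # Data flow (for orientation): the `act` processes started above are the
    # producers, pushing finished rollouts into the per-position SimpleQueues
    # in `batch_queues`; the `batch_and_learn` threads defined below are the
    # consumers, each draining one position's queue via `get_batch` and
    # running an optimization step on the shared learner model.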
    def batch_and_learn(i, device, position, local_lock, position_lock, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal frames, position_frames, stats
        while frames < flags.total_frames:
            batch = get_batch(batch_queues, position, flags, local_lock)
            _stats = learn(position, models, learner_model.get_model(position),
                           batch, optimizers[position], flags, position_lock)

            with lock:
                for k in _stats:
                    stats[k] = _stats[k]
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B
                position_frames[position] += T * B

    threads = []
    locks = {}
    for device in device_iterator:
        locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(),
                         'landlord_front': threading.Lock(), 'landlord_down': threading.Lock(),
                         'bidding': threading.Lock()}
    position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(),
                      'landlord_front': threading.Lock(), 'landlord_down': threading.Lock(),
                      'bidding': threading.Lock()}

    for device in device_iterator:
        for i in range(flags.num_threads):
            for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
                thread = threading.Thread(
                    target=batch_and_learn, name='batch-and-learn-%d' % i,
                    args=(i, device, position, locks[device][position], position_locks[position]))
                thread.daemon = True
                thread.start()
                threads.append(thread)

    def checkpoint(frames):
        if flags.disable_checkpoint:
            return
        log.info('Saving checkpoint to %s', checkpointpath)
        _models = learner_model.get_models()
        torch.save({
            'model_state_dict': {k: _models[k].state_dict() for k in _models},
            'optimizer_state_dict': {k: optimizers[k].state_dict() for k in optimizers},
            'stats': stats,
            'flags': vars(flags),
            'frames': frames,
            'position_frames': position_frames
        }, checkpointpath + '.new')

        # Save the weights for evaluation purpose
        dummy_input = (
            torch.tensor(np.zeros([1, 40, 108]), dtype=torch.float32),
            torch.tensor(np.zeros((1, 80)), dtype=torch.float32)
        )
        for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down', 'bidding']:
            model_weights_dir = os.path.expandvars(os.path.expanduser(
                '%s/%s/%s' % (flags.savedir, flags.xpid, 'general_' + position + '_' + str(frames) + '.ckpt')))
            torch.save(learner_model.get_model(position).state_dict(), model_weights_dir)
            if position != 'bidding':
                model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position)
                torch.onnx.export(
                    learner_model.get_model(position),
                    dummy_input,
                    model_path,
                    input_names=['z_batch', 'x_batch'],
                    output_names=['values'],
                    dynamic_axes={
                        'z_batch': {0: 'legal_actions'},
                        'x_batch': {0: 'legal_actions'}
                    }
                )
        # Set the shared value in place; rebinding the module-level name
        # (onnx_frame = frames) would never reach the actor processes.
        onnx_frame.value = frames
        shutil.move(checkpointpath + '.new', checkpointpath)
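
    # Note: `onnx_frame` was handed to every actor process at startup, so
    # updating its shared value in checkpoint() is what signals the actors
    # that a fresh ONNX export is available on disk to reload.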
    fps_log = []
    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer() - flags.save_interval * 60
        while frames < flags.total_frames:
            start_frames = frames
            position_start_frames = {k: position_frames[k] for k in position_frames}
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()

            end_time = timer()
            fps = (frames - start_frames) / (end_time - start_time)
            fps_log.append(fps)
            if len(fps_log) > 24:
                fps_log = fps_log[1:]
            fps_avg = np.mean(fps_log)

            position_fps = {k: (position_frames[k] - position_start_frames[k]) / (end_time - start_time)
                            for k in position_frames}
            log.info('After %i (L:%i U:%i F:%i D:%i) frames: @ %.1f fps (avg@ %.1f fps) (L:%.1f U:%.1f F:%.1f D:%.1f) Stats:\n%s',
                     frames,
                     position_frames['landlord'],
                     position_frames['landlord_up'],
                     position_frames['landlord_front'],
                     position_frames['landlord_down'],
                     fps,
                     fps_avg,
                     position_fps['landlord'],
                     position_fps['landlord_up'],
                     position_fps['landlord_front'],
                     position_fps['landlord_down'],
                     pprint.pformat(stats))

    except KeyboardInterrupt:
        return
    else:
        for thread in threads:
            thread.join()
        log.info('Learning finished after %d frames.', frames)

    checkpoint(frames)
    plogger.close()
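
# Usage sketch (hypothetical invocation; the actual entry point is train.py,
# which builds `flags` from the repo's argument parser and calls train):
#
#   from douzero.dmc import parser, train
#
#   if __name__ == '__main__':
#       flags = parser.parse_args()
#       train(flags)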