diff --git a/douzero/dmc/arguments.py b/douzero/dmc/arguments.py index 588c23f..496fe4c 100644 --- a/douzero/dmc/arguments.py +++ b/douzero/dmc/arguments.py @@ -11,6 +11,8 @@ parser.add_argument('--objective', default='adp', type=str, choices=['adp', 'wp' help='Use ADP or WP as reward (default: ADP)') # Training settings +parser.add_argument('--onnx_sync_interval', default=5, type=int, + help='Time interval (in minutes) at which to sync the onnx model') parser.add_argument('--actor_device_cpu', action='store_true', help='Use CPU as actor device') parser.add_argument('--gpu_devices', default='0', type=str, diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 781ef97..53ff12e 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -130,6 +130,25 @@ def train(flags): frames, stats = 0, {k: 0 for k in stat_keys} position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} + def sync_onnx_model(frames): + for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: + if flags.enable_onnx and position: + model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position) + onnx_params = learner_model.get_model(position)\ + .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device)) + torch.onnx.export( + learner_model.get_model(position), + onnx_params['args'], + model_path, + export_params=True, + opset_version=10, + do_constant_folding=True, + input_names=onnx_params['input_names'], + output_names=onnx_params['output_names'], + dynamic_axes=onnx_params['dynamic_axes'] + ) + onnx_frame.value = frames + # Load models if any if flags.load_model and os.path.exists(checkpointpath): checkpoint_states = torch.load( @@ -145,6 +164,7 @@ def train(flags): frames = checkpoint_states["frames"] position_frames = checkpoint_states["position_frames"] + sync_onnx_model(frames) log.info(f"Resuming preempted job, current stats:\n{stats}") # Starting actor processes @@ -217,29 +237,13 @@ def train(flags): model_weights_dir = os.path.expandvars(os.path.expanduser( '%s/%s/%s' % (flags.savedir, flags.xpid, "general_" + position + '_' + str(frames) + '.ckpt'))) torch.save(learner_model.get_model(position).state_dict(), model_weights_dir) - if flags.enable_onnx and position: - model_path = '%s/%s/model_%s.onnx' % (flags.savedir, flags.xpid, position) - onnx_params = learner_model.get_model(position)\ - .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device)) - torch.onnx.export( - learner_model.get_model(position), - onnx_params['args'], - model_path, - export_params=True, - opset_version=10, - do_constant_folding=True, - input_names=onnx_params['input_names'], - output_names=onnx_params['output_names'], - dynamic_axes=onnx_params['dynamic_axes'] - ) - onnx_frame.value = frames shutil.move(checkpointpath + '.new', checkpointpath) - fps_log = [] timer = timeit.default_timer try: last_checkpoint_time = timer() - flags.save_interval * 60 + last_onnx_sync_time = timer() while frames < flags.total_frames: start_frames = frames position_start_frames = {k: position_frames[k] for k in position_frames} @@ -249,6 +253,11 @@ def train(flags): if timer() - last_checkpoint_time > flags.save_interval * 60: checkpoint(frames) last_checkpoint_time = timer() + + if timer() - last_onnx_sync_time > flags.onnx_sync_interval * 60: + sync_onnx_model(frames) + last_onnx_sync_time = timer() + end_time = timer() fps = (frames - start_frames) / (end_time - start_time) diff --git a/requirements.txt b/requirements.txt index 2dfca1a..cab6fd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ gitdb2 rlcard psutil onnx -onnxruntime-gpu +onnxruntime-gpu==1.7