From 66b432f52ccceffaba0603ee05491fafe41ea09d Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Tue, 21 Dec 2021 12:49:53 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=9B=E7=A8=8B=E5=AE=88=E6=8A=A4=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E6=B7=BB=E5=8A=A0onnx=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E4=BD=8D=E7=BD=AE=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/arguments.py | 2 ++ douzero/dmc/dmc.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/douzero/dmc/arguments.py b/douzero/dmc/arguments.py index 82ba147..ccddcdf 100644 --- a/douzero/dmc/arguments.py +++ b/douzero/dmc/arguments.py @@ -37,6 +37,8 @@ parser.add_argument('--savedir', default='douzero_checkpoints', help='Root dir where experiment data will be saved') parser.add_argument('--enable_onnx', action='store_true', help='Use onnx model for train') +parser.add_argument('--onnx_model_path', default='douzero_checkpoints', + help='Root dir where onnx temp model will be saved') # Hyperparameters parser.add_argument('--total_frames', default=100000000000, type=int, diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 3b3ca33..a46be73 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -104,7 +104,7 @@ def train(flags): # Initialize queues actor_processes = [] ctx = mp.get_context('spawn') - batch_queues = {"landlord": ctx.SimpleQueue(), "landlord_up": ctx.SimpleQueue(), 'landlord_front': ctx.SimpleQueue(), "landlord_down": ctx.SimpleQueue()} + batch_queues = {"landlord": ctx.Queue(flags.unroll_length * 4), "landlord_up": ctx.Queue(flags.unroll_length * 4), 'landlord_front': ctx.Queue(flags.unroll_length * 4), "landlord_down": ctx.Queue(flags.unroll_length * 4)} onnx_frame = ctx.Value('d', -1) # Learner model for training @@ -131,6 +131,9 @@ def train(flags): position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} def sync_onnx_model(frames): + p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid) + if not os.path.exists(p_path): + os.makedirs(p_path) for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: if flags.enable_onnx: model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position) @@ -179,7 +182,11 @@ def train(flags): args=(i, device, batch_queues, models[device], flags, onnx_frame)) actor.daemon = True actor.start() - actor_processes.append(actor) + actor_processes.append({ + 'device': device, + 'i': i, + 'actor': actor + }) parent = psutil.Process() parent.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS) @@ -279,6 +286,14 @@ def train(flags): position_fps['landlord_front'], position_fps['landlord_down'], pprint.pformat(stats)) + for proc in actor_processes: + if not proc['actor'].is_alive(): + actor = mp.Process( + target=act, + args=(proc['i'], proc['device'], batch_queues, models[device], flags, onnx_frame)) + actor.daemon = True + actor.start() + proc['actor'] = actor except KeyboardInterrupt: checkpoint(frames)