From 8a6010381f20d31dd89219d7cc26f5c845d004fb Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Wed, 22 Dec 2021 09:24:06 +0800 Subject: [PATCH] =?UTF-8?q?export=E4=B8=8Elearn=E6=B7=BB=E5=8A=A0=E4=BA=92?= =?UTF-8?q?=E6=96=A5=E9=94=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/dmc.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index b2920a0..858f622 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -129,6 +129,7 @@ def train(flags): ] frames, stats = 0, {k: 0 for k in stat_keys} position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} + position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()} def sync_onnx_model(frames): p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid) @@ -139,17 +140,18 @@ def train(flags): model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position) onnx_params = learner_model.get_model(position)\ .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device)) - torch.onnx.export( - learner_model.get_model(position), - onnx_params['args'], - model_path, - export_params=True, - opset_version=10, - do_constant_folding=True, - input_names=onnx_params['input_names'], - output_names=onnx_params['output_names'], - dynamic_axes=onnx_params['dynamic_axes'] - ) + with position_locks[position]: + torch.onnx.export( + learner_model.get_model(position), + onnx_params['args'], + model_path, + export_params=True, + opset_version=10, + do_constant_folding=True, + input_names=onnx_params['input_names'], + output_names=onnx_params['output_names'], + dynamic_axes=onnx_params['dynamic_axes'] + ) onnx_frame.value = frames # Load models if any @@ -214,7 +216,6 @@ def train(flags): locks = {} for device in device_iterator: locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()} - position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()} for i in range(flags.num_threads): for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: