export与learn添加互斥锁

This commit is contained in:
zhiyang7 2021-12-22 09:24:06 +08:00
parent a4d2a93afc
commit 8a6010381f
1 changed file with 13 additions and 12 deletions

View File

@@ -129,6 +129,7 @@ def train(flags):
] ]
frames, stats = 0, {k: 0 for k in stat_keys} frames, stats = 0, {k: 0 for k in stat_keys}
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
def sync_onnx_model(frames): def sync_onnx_model(frames):
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid) p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
@@ -139,17 +140,18 @@ def train(flags):
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position) model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
onnx_params = learner_model.get_model(position)\ onnx_params = learner_model.get_model(position)\
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device)) .get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
torch.onnx.export( with position_locks[position]:
learner_model.get_model(position), torch.onnx.export(
onnx_params['args'], learner_model.get_model(position),
model_path, onnx_params['args'],
export_params=True, model_path,
opset_version=10, export_params=True,
do_constant_folding=True, opset_version=10,
input_names=onnx_params['input_names'], do_constant_folding=True,
output_names=onnx_params['output_names'], input_names=onnx_params['input_names'],
dynamic_axes=onnx_params['dynamic_axes'] output_names=onnx_params['output_names'],
) dynamic_axes=onnx_params['dynamic_axes']
)
onnx_frame.value = frames onnx_frame.value = frames
# Load models if any # Load models if any
@@ -214,7 +216,6 @@ def train(flags):
locks = {} locks = {}
for device in device_iterator: for device in device_iterator:
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()} locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
for i in range(flags.num_threads): for i in range(flags.num_threads):
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: