export与learn添加互斥锁
This commit is contained in:
parent
a4d2a93afc
commit
8a6010381f
|
@ -129,6 +129,7 @@ def train(flags):
|
||||||
]
|
]
|
||||||
frames, stats = 0, {k: 0 for k in stat_keys}
|
frames, stats = 0, {k: 0 for k in stat_keys}
|
||||||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||||
|
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||||
|
|
||||||
def sync_onnx_model(frames):
|
def sync_onnx_model(frames):
|
||||||
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
||||||
|
@ -139,17 +140,18 @@ def train(flags):
|
||||||
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
||||||
onnx_params = learner_model.get_model(position)\
|
onnx_params = learner_model.get_model(position)\
|
||||||
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
||||||
torch.onnx.export(
|
with position_locks[position]:
|
||||||
learner_model.get_model(position),
|
torch.onnx.export(
|
||||||
onnx_params['args'],
|
learner_model.get_model(position),
|
||||||
model_path,
|
onnx_params['args'],
|
||||||
export_params=True,
|
model_path,
|
||||||
opset_version=10,
|
export_params=True,
|
||||||
do_constant_folding=True,
|
opset_version=10,
|
||||||
input_names=onnx_params['input_names'],
|
do_constant_folding=True,
|
||||||
output_names=onnx_params['output_names'],
|
input_names=onnx_params['input_names'],
|
||||||
dynamic_axes=onnx_params['dynamic_axes']
|
output_names=onnx_params['output_names'],
|
||||||
)
|
dynamic_axes=onnx_params['dynamic_axes']
|
||||||
|
)
|
||||||
onnx_frame.value = frames
|
onnx_frame.value = frames
|
||||||
|
|
||||||
# Load models if any
|
# Load models if any
|
||||||
|
@ -214,7 +216,6 @@ def train(flags):
|
||||||
locks = {}
|
locks = {}
|
||||||
for device in device_iterator:
|
for device in device_iterator:
|
||||||
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
|
||||||
|
|
||||||
for i in range(flags.num_threads):
|
for i in range(flags.num_threads):
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||||
|
|
Loading…
Reference in New Issue