export与learn添加互斥锁
This commit is contained in:
parent
a4d2a93afc
commit
8a6010381f
|
@ -129,6 +129,7 @@ def train(flags):
|
|||
]
|
||||
frames, stats = 0, {k: 0 for k in stat_keys}
|
||||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||
|
||||
def sync_onnx_model(frames):
|
||||
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
||||
|
@ -139,17 +140,18 @@ def train(flags):
|
|||
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
||||
onnx_params = learner_model.get_model(position)\
|
||||
.get_onnx_params(torch.device('cpu') if flags.training_device == 'cpu' else torch.device('cuda:' + flags.training_device))
|
||||
torch.onnx.export(
|
||||
learner_model.get_model(position),
|
||||
onnx_params['args'],
|
||||
model_path,
|
||||
export_params=True,
|
||||
opset_version=10,
|
||||
do_constant_folding=True,
|
||||
input_names=onnx_params['input_names'],
|
||||
output_names=onnx_params['output_names'],
|
||||
dynamic_axes=onnx_params['dynamic_axes']
|
||||
)
|
||||
with position_locks[position]:
|
||||
torch.onnx.export(
|
||||
learner_model.get_model(position),
|
||||
onnx_params['args'],
|
||||
model_path,
|
||||
export_params=True,
|
||||
opset_version=10,
|
||||
do_constant_folding=True,
|
||||
input_names=onnx_params['input_names'],
|
||||
output_names=onnx_params['output_names'],
|
||||
dynamic_axes=onnx_params['dynamic_axes']
|
||||
)
|
||||
onnx_frame.value = frames
|
||||
|
||||
# Load models if any
|
||||
|
@ -214,7 +216,6 @@ def train(flags):
|
|||
locks = {}
|
||||
for device in device_iterator:
|
||||
locks[device] = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||
|
||||
for i in range(flags.num_threads):
|
||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||
|
|
Loading…
Reference in New Issue