修复BUG
This commit is contained in:
parent
b11ea3a8d9
commit
5276c4e3c6
|
@ -122,7 +122,7 @@ def train(flags):
|
||||||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||||
if flags.unified_model:
|
if flags.unified_model:
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock}
|
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock, 'uni': lock}
|
||||||
else:
|
else:
|
||||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||||
|
|
||||||
|
@ -130,7 +130,8 @@ def train(flags):
|
||||||
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
||||||
if not os.path.exists(p_path):
|
if not os.path.exists(p_path):
|
||||||
os.makedirs(p_path)
|
os.makedirs(p_path)
|
||||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
positions = ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
|
for position in positions:
|
||||||
if flags.enable_onnx:
|
if flags.enable_onnx:
|
||||||
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
||||||
onnx_params = learner_model.get_model(position)\
|
onnx_params = learner_model.get_model(position)\
|
||||||
|
|
|
@ -775,11 +775,14 @@ class UnifiedModel:
|
||||||
self.onnx_model = None
|
self.onnx_model = None
|
||||||
|
|
||||||
def set_onnx_model(self, device='cpu'):
|
def set_onnx_model(self, device='cpu'):
|
||||||
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
|
model_path = os.path.abspath('%s/%s/model_uni.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
|
||||||
|
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||||
if device == 'cpu':
|
if device == 'cpu':
|
||||||
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
||||||
else:
|
else:
|
||||||
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
|
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
|
||||||
|
for position in positions:
|
||||||
|
self.onnx_models[position] = self.onnx_model
|
||||||
|
|
||||||
def get_onnx_params(self, position):
|
def get_onnx_params(self, position):
|
||||||
self.model.get_onnx_params(self.device)
|
self.model.get_onnx_params(self.device)
|
||||||
|
|
Loading…
Reference in New Issue