修复BUG
This commit is contained in:
parent
b11ea3a8d9
commit
5276c4e3c6
|
@ -122,7 +122,7 @@ def train(flags):
|
|||
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
|
||||
if flags.unified_model:
|
||||
lock = threading.Lock()
|
||||
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock}
|
||||
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock, 'uni': lock}
|
||||
else:
|
||||
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
|
||||
|
||||
|
@ -130,7 +130,8 @@ def train(flags):
|
|||
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
|
||||
if not os.path.exists(p_path):
|
||||
os.makedirs(p_path)
|
||||
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']:
|
||||
positions = ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
for position in positions:
|
||||
if flags.enable_onnx:
|
||||
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
|
||||
onnx_params = learner_model.get_model(position)\
|
||||
|
|
|
@ -775,11 +775,14 @@ class UnifiedModel:
|
|||
self.onnx_model = None
|
||||
|
||||
def set_onnx_model(self, device='cpu'):
|
||||
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
|
||||
model_path = os.path.abspath('%s/%s/model_uni.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
|
||||
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
|
||||
if device == 'cpu':
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
|
||||
for position in positions:
|
||||
self.onnx_models[position] = self.onnx_model
|
||||
|
||||
def get_onnx_params(self, position):
|
||||
self.model.get_onnx_params(self.device)
|
||||
|
|
Loading…
Reference in New Issue