修复BUG

This commit is contained in:
zhiyang7 2022-01-05 09:49:42 +08:00
parent b11ea3a8d9
commit 5276c4e3c6
2 changed files with 7 additions and 3 deletions

View File

@ -122,7 +122,7 @@ def train(flags):
position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0}
if flags.unified_model: if flags.unified_model:
lock = threading.Lock() lock = threading.Lock()
position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock} position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock, 'uni': lock}
else: else:
position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()} position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()}
@ -130,7 +130,8 @@ def train(flags):
p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid) p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid)
if not os.path.exists(p_path): if not os.path.exists(p_path):
os.makedirs(p_path) os.makedirs(p_path)
for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: positions = ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
for position in positions:
if flags.enable_onnx: if flags.enable_onnx:
model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position) model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position)
onnx_params = learner_model.get_model(position)\ onnx_params = learner_model.get_model(position)\

View File

@ -775,11 +775,14 @@ class UnifiedModel:
self.onnx_model = None self.onnx_model = None
def set_onnx_model(self, device='cpu'): def set_onnx_model(self, device='cpu'):
model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid)) model_path = os.path.abspath('%s/%s/model_uni.onnx' % (self.flags.onnx_model_path, self.flags.xpid))
positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']
if device == 'cpu': if device == 'cpu':
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider']) self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider'])
else: else:
self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider']) self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider'])
for position in positions:
self.onnx_models[position] = self.onnx_model
def get_onnx_params(self, position): def get_onnx_params(self, position):
self.model.get_onnx_params(self.device) self.model.get_onnx_params(self.device)