From 5276c4e3c6a3a3bc85867f7aa2bb8acc33156ed1 Mon Sep 17 00:00:00 2001 From: zhiyang7 Date: Wed, 5 Jan 2022 09:49:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DBUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- douzero/dmc/dmc.py | 5 +++-- douzero/dmc/models.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/douzero/dmc/dmc.py b/douzero/dmc/dmc.py index 4a3c53d..efafb12 100644 --- a/douzero/dmc/dmc.py +++ b/douzero/dmc/dmc.py @@ -122,7 +122,7 @@ def train(flags): position_frames = {'landlord':0, 'landlord_up':0, 'landlord_front':0, 'landlord_down':0} if flags.unified_model: lock = threading.Lock() - position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock} + position_locks = {'landlord': lock, 'landlord_up': lock, 'landlord_front': lock, 'landlord_down': lock, 'uni': lock} else: position_locks = {'landlord': threading.Lock(), 'landlord_up': threading.Lock(), 'landlord_front': threading.Lock(), 'landlord_down': threading.Lock()} @@ -130,7 +130,8 @@ def train(flags): p_path = '%s/%s' % (flags.onnx_model_path, flags.xpid) if not os.path.exists(p_path): os.makedirs(p_path) - for position in ['landlord', 'landlord_up', 'landlord_front', 'landlord_down']: + positions = ['uni'] if flags.unified_model else ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] + for position in positions: if flags.enable_onnx: model_path = '%s/%s/model_%s.onnx' % (flags.onnx_model_path, flags.xpid, position) onnx_params = learner_model.get_model(position)\ diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index 901791c..a6e5f65 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -775,11 +775,14 @@ class UnifiedModel: self.onnx_model = None def set_onnx_model(self, device='cpu'): - model_path = os.path.abspath('%s/%s/model_%s.onnx' % (self.flags.onnx_model_path, self.flags.xpid)) + model_path = os.path.abspath('%s/%s/model_uni.onnx' % (self.flags.onnx_model_path, self.flags.xpid)) + positions = ['landlord', 'landlord_up', 'landlord_front', 'landlord_down'] if device == 'cpu': self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CPUExecutionProvider']) else: self.onnx_model = onnxruntime.InferenceSession(get_example(model_path), providers=['CUDAExecutionProvider']) + for position in positions: + self.onnx_models[position] = self.onnx_model def get_onnx_params(self, position): self.model.get_onnx_params(self.device)