diff --git a/douzero/evaluation/deep_agent.py b/douzero/evaluation/deep_agent.py index f328f4f..df1fb4a 100644 --- a/douzero/evaluation/deep_agent.py +++ b/douzero/evaluation/deep_agent.py @@ -62,9 +62,9 @@ class DeepAgent: def act(self, infoset, with_confidence = False): if self.use_onnx and self.onnx_model is None: if torch.cuda.is_available(): - self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider']) - else: self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CUDAExecutionProvider']) + else: + self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider']) if not with_confidence and len(infoset.legal_actions) == 1: return infoset.legal_actions[0]