修复BUG

This commit is contained in:
zhiyang7 2021-12-27 14:58:47 +08:00
parent 9bdcc85cc7
commit 724f029591
1 changed files with 2 additions and 2 deletions

View File

@ -62,9 +62,9 @@ class DeepAgent:
def act(self, infoset, with_confidence = False):
if self.use_onnx and self.onnx_model is None:
if torch.cuda.is_available():
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
else:
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CUDAExecutionProvider'])
else:
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
if not with_confidence and len(infoset.legal_actions) == 1:
return infoset.legal_actions[0]