修复BUG
This commit is contained in:
parent
9bdcc85cc7
commit
724f029591
|
@ -62,9 +62,9 @@ class DeepAgent:
|
||||||
def act(self, infoset, with_confidence = False):
|
def act(self, infoset, with_confidence = False):
|
||||||
if self.use_onnx and self.onnx_model is None:
|
if self.use_onnx and self.onnx_model is None:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
|
|
||||||
else:
|
|
||||||
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CUDAExecutionProvider'])
|
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CUDAExecutionProvider'])
|
||||||
|
else:
|
||||||
|
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
|
||||||
|
|
||||||
if not with_confidence and len(infoset.legal_actions) == 1:
|
if not with_confidence and len(infoset.legal_actions) == 1:
|
||||||
return infoset.legal_actions[0]
|
return infoset.legal_actions[0]
|
||||||
|
|
Loading…
Reference in New Issue