修复BUG
This commit is contained in:
parent
9bdcc85cc7
commit
724f029591
|
@ -62,9 +62,9 @@ class DeepAgent:
|
|||
def act(self, infoset, with_confidence = False):
|
||||
if self.use_onnx and self.onnx_model is None:
|
||||
if torch.cuda.is_available():
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
|
||||
else:
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CUDAExecutionProvider'])
|
||||
else:
|
||||
self.onnx_model = onnxruntime.InferenceSession(get_example(self.onnx_model_path), providers=['CPUExecutionProvider'])
|
||||
|
||||
if not with_confidence and len(infoset.legal_actions) == 1:
|
||||
return infoset.legal_actions[0]
|
||||
|
|
Loading…
Reference in New Issue