修复BUG
This commit is contained in:
parent
fb22ab2649
commit
9c990675ac
|
@ -531,13 +531,13 @@ class OldModel:
|
|||
self.models[position].get_onnx_params(self.device)
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None):
|
||||
if flags.enable_onnx:
|
||||
model = self.onnx_models[position]
|
||||
if model is None:
|
||||
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||
values = onnx_out[0]
|
||||
else:
|
||||
model = self.models[position]
|
||||
values = model.forward(z, x)['values']
|
||||
else:
|
||||
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||
values = torch.tensor(onnx_out[0])
|
||||
if return_value:
|
||||
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
|
||||
else:
|
||||
|
@ -548,9 +548,9 @@ class OldModel:
|
|||
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
||||
else:
|
||||
if flags.enable_onnx:
|
||||
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
|
||||
action = np.argmax(values, axis=0)[0]
|
||||
else:
|
||||
action = torch.argmax(values,dim=0)[0]
|
||||
action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
|
||||
return dict(action = action)
|
||||
|
||||
def share_memory(self):
|
||||
|
@ -623,7 +623,7 @@ class Model:
|
|||
if flags.enable_onnx:
|
||||
model = self.onnx_models[position]
|
||||
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||
values = torch.tensor(onnx_out[0])
|
||||
values = onnx_out[0]
|
||||
else:
|
||||
model = self.models[position]
|
||||
values = model.forward(z, x)['values']
|
||||
|
@ -637,9 +637,9 @@ class Model:
|
|||
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
||||
else:
|
||||
if flags.enable_onnx:
|
||||
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
|
||||
action = np.argmax(values, axis=0)[0]
|
||||
else:
|
||||
action = torch.argmax(values,dim=0)[0]
|
||||
action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
|
||||
return dict(action = action)
|
||||
|
||||
def share_memory(self):
|
||||
|
|
Loading…
Reference in New Issue