修复BUG

This commit is contained in:
zhiyang7 2021-12-23 10:45:57 +08:00
parent fb22ab2649
commit 9c990675ac
1 changed files with 10 additions and 10 deletions

View File

@ -531,13 +531,13 @@ class OldModel:
self.models[position].get_onnx_params(self.device)
def forward(self, position, z, x, return_value=False, flags=None):
model = self.onnx_models[position]
if model is None:
if flags.enable_onnx:
model = self.onnx_models[position]
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = onnx_out[0]
else:
model = self.models[position]
values = model.forward(z, x)['values']
else:
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = torch.tensor(onnx_out[0])
if return_value:
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
else:
@ -548,9 +548,9 @@ class OldModel:
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
else:
if flags.enable_onnx:
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
action = np.argmax(values, axis=0)[0]
else:
action = torch.argmax(values,dim=0)[0]
action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
return dict(action = action)
def share_memory(self):
@ -623,7 +623,7 @@ class Model:
if flags.enable_onnx:
model = self.onnx_models[position]
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = torch.tensor(onnx_out[0])
values = onnx_out[0]
else:
model = self.models[position]
values = model.forward(z, x)['values']
@ -637,9 +637,9 @@ class Model:
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
else:
if flags.enable_onnx:
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
action = np.argmax(values, axis=0)[0]
else:
action = torch.argmax(values,dim=0)[0]
action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
return dict(action = action)
def share_memory(self):