修复BUG

This commit is contained in:
zhiyang7 2021-12-23 10:45:57 +08:00
parent fb22ab2649
commit 9c990675ac
1 changed files with 10 additions and 10 deletions

View File

@ -531,13 +531,13 @@ class OldModel:
self.models[position].get_onnx_params(self.device) self.models[position].get_onnx_params(self.device)
def forward(self, position, z, x, return_value=False, flags=None): def forward(self, position, z, x, return_value=False, flags=None):
if flags.enable_onnx:
model = self.onnx_models[position] model = self.onnx_models[position]
if model is None: onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = onnx_out[0]
else:
model = self.models[position] model = self.models[position]
values = model.forward(z, x)['values'] values = model.forward(z, x)['values']
else:
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = torch.tensor(onnx_out[0])
if return_value: if return_value:
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy()) return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
else: else:
@ -548,9 +548,9 @@ class OldModel:
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy() action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
else: else:
if flags.enable_onnx: if flags.enable_onnx:
action = np.argmax(values, axis=0)[0].cpu().detach().numpy() action = np.argmax(values, axis=0)[0]
else: else:
action = torch.argmax(values,dim=0)[0] action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
return dict(action = action) return dict(action = action)
def share_memory(self): def share_memory(self):
@ -623,7 +623,7 @@ class Model:
if flags.enable_onnx: if flags.enable_onnx:
model = self.onnx_models[position] model = self.onnx_models[position]
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x}) onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
values = torch.tensor(onnx_out[0]) values = onnx_out[0]
else: else:
model = self.models[position] model = self.models[position]
values = model.forward(z, x)['values'] values = model.forward(z, x)['values']
@ -637,9 +637,9 @@ class Model:
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy() action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
else: else:
if flags.enable_onnx: if flags.enable_onnx:
action = np.argmax(values, axis=0)[0].cpu().detach().numpy() action = np.argmax(values, axis=0)[0]
else: else:
action = torch.argmax(values,dim=0)[0] action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
return dict(action = action) return dict(action = action)
def share_memory(self): def share_memory(self):