修复BUG
This commit is contained in:
parent
fb22ab2649
commit
9c990675ac
|
@ -531,13 +531,13 @@ class OldModel:
|
||||||
self.models[position].get_onnx_params(self.device)
|
self.models[position].get_onnx_params(self.device)
|
||||||
|
|
||||||
def forward(self, position, z, x, return_value=False, flags=None):
|
def forward(self, position, z, x, return_value=False, flags=None):
|
||||||
model = self.onnx_models[position]
|
if flags.enable_onnx:
|
||||||
if model is None:
|
model = self.onnx_models[position]
|
||||||
|
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||||
|
values = onnx_out[0]
|
||||||
|
else:
|
||||||
model = self.models[position]
|
model = self.models[position]
|
||||||
values = model.forward(z, x)['values']
|
values = model.forward(z, x)['values']
|
||||||
else:
|
|
||||||
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
|
||||||
values = torch.tensor(onnx_out[0])
|
|
||||||
if return_value:
|
if return_value:
|
||||||
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
|
return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
|
||||||
else:
|
else:
|
||||||
|
@ -548,9 +548,9 @@ class OldModel:
|
||||||
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
if flags.enable_onnx:
|
if flags.enable_onnx:
|
||||||
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
|
action = np.argmax(values, axis=0)[0]
|
||||||
else:
|
else:
|
||||||
action = torch.argmax(values,dim=0)[0]
|
action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
|
||||||
return dict(action = action)
|
return dict(action = action)
|
||||||
|
|
||||||
def share_memory(self):
|
def share_memory(self):
|
||||||
|
@ -623,7 +623,7 @@ class Model:
|
||||||
if flags.enable_onnx:
|
if flags.enable_onnx:
|
||||||
model = self.onnx_models[position]
|
model = self.onnx_models[position]
|
||||||
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
|
||||||
values = torch.tensor(onnx_out[0])
|
values = onnx_out[0]
|
||||||
else:
|
else:
|
||||||
model = self.models[position]
|
model = self.models[position]
|
||||||
values = model.forward(z, x)['values']
|
values = model.forward(z, x)['values']
|
||||||
|
@ -637,9 +637,9 @@ class Model:
|
||||||
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
|
||||||
else:
|
else:
|
||||||
if flags.enable_onnx:
|
if flags.enable_onnx:
|
||||||
action = np.argmax(values, axis=0)[0].cpu().detach().numpy()
|
action = np.argmax(values, axis=0)[0]
|
||||||
else:
|
else:
|
||||||
action = torch.argmax(values,dim=0)[0]
|
action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
|
||||||
return dict(action = action)
|
return dict(action = action)
|
||||||
|
|
||||||
def share_memory(self):
|
def share_memory(self):
|
||||||
|
|
Loading…
Reference in New Issue