拆分批次infer逻辑
This commit is contained in:
parent
9c990675ac
commit
32718bbe12
|
@ -491,6 +491,51 @@ model_dict_new_lite['landlord_up'] = GeneralModelLite
|
|||
model_dict_new_lite['landlord_front'] = GeneralModelLite
|
||||
model_dict_new_lite['landlord_down'] = GeneralModelLite
|
||||
|
||||
def forward_logic(self_model, position, z, x, return_value=False, flags=None):
    """Run one inference pass for `position`, chunking large batches.

    Inputs with >= 80 legal actions are split into roughly 40-row chunks
    (``np.array_split``) and evaluated chunk by chunk to bound peak memory,
    then reassembled into one ``(legal_count, 1)`` value array.

    Args:
        self_model: owner exposing ``models`` and (for the ONNX path)
            ``onnx_models`` dicts keyed by position.
        position: position key, e.g. ``'landlord'``.
        z, x: batched network inputs; axis 0 length == number of legal actions.
        return_value: if True, return the raw value predictions.
        flags: options namespace; ``enable_onnx`` selects the ONNX runtime
            path, ``exp_epsilon`` enables epsilon-greedy exploration.
            May be None, which selects the torch path and greedy action.

    Returns:
        ``dict(values=...)`` when ``return_value`` is True,
        otherwise ``dict(action=...)`` with the chosen action index.
    """
    legal_count = len(z)
    batched = legal_count >= 80
    if batched:
        # ~40 rows per chunk; array_split tolerates uneven division.
        partition_count = legal_count // 40
        sub_z = np.array_split(z, partition_count)
        sub_x = np.array_split(x, partition_count)

    # NOTE(review): the original dereferenced flags.enable_onnx even though
    # flags defaults to None (and is None-checked below); guard added so the
    # default call takes the torch path instead of raising AttributeError.
    if flags is not None and flags.enable_onnx:
        model = self_model.onnx_models[position]
        if batched:
            values = np.empty((legal_count, 1))  # filled chunk by chunk below
            j = 0
            for i in range(partition_count):
                onnx_out = model.run(None, {'z_batch': sub_z[i], 'x_batch': sub_x[i]})
                values[j:j + len(sub_z[i])] = onnx_out[0]
                j += len(sub_z[i])
        else:
            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
            values = onnx_out[0]
    else:
        # Hoisted out of the chunk loop: the model lookup is loop-invariant.
        model = self_model.models[position]
        if batched:
            values = np.empty((legal_count, 1))
            j = 0
            for i in range(partition_count):
                model_out = model.forward(sub_z[i], sub_x[i])['values']
                if torch.is_tensor(model_out):
                    # Move to host before writing into the numpy buffer;
                    # direct assignment fails for CUDA / grad-tracking tensors.
                    model_out = model_out.cpu().detach().numpy()
                values[j:j + len(sub_z[i])] = model_out
                j += len(sub_z[i])
        else:
            values = model.forward(z, x)['values']

    if return_value:
        return dict(values=values.cpu().detach().numpy() if torch.is_tensor(values) else values)

    if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
        # Epsilon-greedy exploration: uniformly random legal action.
        if torch.is_tensor(values):
            action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
        else:
            action = np.random.randint(0, values.shape[0], (1,))[0]
    else:
        # Greedy: index of the highest predicted value.
        if torch.is_tensor(values):
            action = torch.argmax(values, dim=0)[0].cpu().detach().numpy()
        else:
            action = np.argmax(values, axis=0)[0]
    return dict(action=action)
|
||||
|
||||
class OldModel:
|
||||
"""
|
||||
The wrapper for the three models. We also wrap several
|
||||
|
@ -531,27 +576,7 @@ class OldModel:
|
|||
self.models[position].get_onnx_params(self.device)
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None):
    """Run inference for `position` by delegating to the module-level
    forward_logic helper, which handles batch splitting and the
    ONNX-vs-torch dispatch.

    Returns forward_logic's dict: values when return_value is True,
    otherwise the selected action index.
    """
    return forward_logic(self, position, z, x, return_value, flags)
|
||||
|
||||
def share_memory(self):
|
||||
if self.models['landlord'] is not None:
|
||||
|
@ -620,27 +645,7 @@ class Model:
|
|||
self.models[position].get_onnx_params(self.device)
|
||||
|
||||
def forward(self, position, z, x, return_value=False, flags=None, debug=False):
    """Run inference for `position` via the shared forward_logic helper
    (batch splitting plus ONNX-vs-torch dispatch).

    NOTE(review): `debug` is accepted but not forwarded to forward_logic —
    presumably kept for caller signature compatibility; confirm with callers.

    Returns forward_logic's dict: values when return_value is True,
    otherwise the selected action index.
    """
    return forward_logic(self, position, z, x, return_value, flags)
|
||||
|
||||
def share_memory(self):
|
||||
if self.models['landlord'] is not None:
|
||||
|
|
Loading…
Reference in New Issue