Split the batched infer logic into chunks

zhiyang7 2021-12-23 11:37:32 +08:00
parent 9c990675ac
commit 32718bbe12
1 changed file with 47 additions and 42 deletions


@@ -491,6 +491,51 @@ model_dict_new_lite['landlord_up'] = GeneralModelLite
 model_dict_new_lite['landlord_front'] = GeneralModelLite
 model_dict_new_lite['landlord_down'] = GeneralModelLite
 
+def forward_logic(self_model, position, z, x, return_value=False, flags=None):
+    legal_count = len(z)
+    if legal_count >= 80:
+        partition_count = int(legal_count / 40)
+        sub_z = np.array_split(z, partition_count)
+        sub_x = np.array_split(x, partition_count)
+    if flags.enable_onnx:
+        model = self_model.onnx_models[position]
+        if legal_count >= 80:
+            values = np.ndarray((legal_count, 1))
+            j = 0
+            for i in range(partition_count):
+                onnx_out = model.run(None, {'z_batch': sub_z[i], 'x_batch': sub_x[i]})
+                values[j:j+len(sub_z[i])] = onnx_out[0]
+                j += len(sub_z[i])
+        else:
+            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
+            values = onnx_out[0]
+    else:
+        if legal_count >= 80:
+            values = np.ndarray((legal_count, 1))
+            j = 0
+            for i in range(partition_count):
+                model = self_model.models[position]
+                model_out = model.forward(sub_z[i], sub_x[i])['values']
+                values[j:j+len(sub_z[i])] = model_out
+                j += len(sub_z[i])
+        else:
+            model = self_model.models[position]
+            values = model.forward(z, x)['values']
+    if return_value:
+        return dict(values = values.cpu().detach().numpy() if torch.is_tensor(values) else values)
+    else:
+        if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
+            if torch.is_tensor(values):
+                action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
+            else:
+                action = np.random.randint(0, values.shape[0], (1,))[0]
+        else:
+            if torch.is_tensor(values):
+                action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
+            else:
+                action = np.argmax(values, axis=0)[0]
+        return dict(action = action)
+
 class OldModel:
     """
     The wrapper for the three models. We also wrap several
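The split path added above can be exercised standalone. A minimal sketch of the chunk-and-stitch pattern, with illustrative shapes and a stand-in network (fake_model is not part of the repo):

    import numpy as np

    def fake_model(z_batch, x_batch):
        # Stand-in network: one score per candidate action.
        return np.sum(x_batch, axis=1, keepdims=True)

    legal_count = 100                    # >= 80 triggers the split path
    z = np.zeros((legal_count, 5, 162))  # shapes are illustrative only
    x = np.ones((legal_count, 30))

    partition_count = int(legal_count / 40)     # 100 -> 2 chunks
    sub_z = np.array_split(z, partition_count)  # ~50 rows each
    sub_x = np.array_split(x, partition_count)

    values = np.ndarray((legal_count, 1))
    j = 0
    for i in range(partition_count):
        values[j:j + len(sub_z[i])] = fake_model(sub_z[i], sub_x[i])
        j += len(sub_z[i])

    assert values.shape == (legal_count, 1)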
@@ -531,27 +576,7 @@ class OldModel:
         self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None):
-        if flags.enable_onnx:
-            model = self.onnx_models[position]
-            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
-            values = onnx_out[0]
-        else:
-            model = self.models[position]
-            values = model.forward(z, x)['values']
-        if return_value:
-            return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
-        else:
-            if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
-                if flags.enable_onnx:
-                    action = np.random.randint(0, values.shape[0], (1,))[0]
-                else:
-                    action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
-            else:
-                if flags.enable_onnx:
-                    action = np.argmax(values, axis=0)[0]
-                else:
-                    action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
-            return dict(action = action)
+        return forward_logic(self, position, z, x, return_value, flags)
 
     def share_memory(self):
         if self.models['landlord'] is not None:
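The body deleted above is the same epsilon-greedy pick that now lives in forward_logic. A NumPy-only rendering of that selection step, with made-up scores:

    import numpy as np

    values = np.array([[0.1], [0.9], [0.4]])  # one score per legal action
    exp_epsilon = 0.01
    if exp_epsilon > 0 and np.random.rand() < exp_epsilon:
        action = np.random.randint(0, values.shape[0], (1,))[0]  # explore
    else:
        action = np.argmax(values, axis=0)[0]                    # exploit -> 1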
@@ -620,27 +645,7 @@ class Model:
         self.models[position].get_onnx_params(self.device)
 
     def forward(self, position, z, x, return_value=False, flags=None, debug=False):
-        if flags.enable_onnx:
-            model = self.onnx_models[position]
-            onnx_out = model.run(None, {'z_batch': z, 'x_batch': x})
-            values = onnx_out[0]
-        else:
-            model = self.models[position]
-            values = model.forward(z, x)['values']
-        if return_value:
-            return dict(values = values if flags.enable_onnx else values.cpu().detach().numpy())
-        else:
-            if flags is not None and flags.exp_epsilon > 0 and np.random.rand() < flags.exp_epsilon:
-                if flags.enable_onnx:
-                    action = np.random.randint(0, values.shape[0], (1,))[0]
-                else:
-                    action = torch.randint(values.shape[0], (1,))[0].cpu().detach().numpy()
-            else:
-                if flags.enable_onnx:
-                    action = np.argmax(values, axis=0)[0]
-                else:
-                    action = torch.argmax(values,dim=0)[0].cpu().detach().numpy()
-            return dict(action = action)
+        return forward_logic(self, position, z, x, return_value, flags)
 
     def share_memory(self):
         if self.models['landlord'] is not None:
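With both wrappers now delegating to forward_logic, the split path can be smoke-tested through the shared helper directly. A hedged sketch: _StubWrapper and its stub network are invented here to mimic the interface forward_logic expects, and SimpleNamespace stands in for the real flags object:

    from types import SimpleNamespace
    import numpy as np

    class _StubWrapper:
        def __init__(self):
            stub = SimpleNamespace(
                forward=lambda z, x: {'values': np.sum(x, axis=1, keepdims=True)})
            self.models = {'landlord': stub}

    flags = SimpleNamespace(enable_onnx=False, exp_epsilon=0.0)
    z = np.zeros((90, 5, 162))  # 90 legal actions -> exercises the split path
    x = np.ones((90, 30))
    out = forward_logic(_StubWrapper(), 'landlord', z, x,
                        return_value=True, flags=flags)
    assert out['values'].shape == (90, 1)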