diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index 4a9916d..e926b34 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -516,7 +516,7 @@ def forward_logic(self_model, position, z, x, return_value=False, flags=None): for i in range(partition_count): model = self_model.models[position] model_out = model.forward(sub_z[i], sub_x[i])['values'] - values[j:j+len(sub_z[i])] = model_out + values[j:j+len(sub_z[i])] = model_out.cpu().detach().numpy() j += len(sub_z[i]) else: model = self_model.models[position]