diff --git a/inference/models/owlv2/owlv2.py b/inference/models/owlv2/owlv2.py
index 407d109d99..c4c8c4d8f8 100644
--- a/inference/models/owlv2/owlv2.py
+++ b/inference/models/owlv2/owlv2.py
@@ -235,16 +235,24 @@ def filter_tensors_by_objectness(
     logit_shift: torch.Tensor,
     logit_scale: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    # Group the squeeze ops together so every batch dim is dropped up front
     objectness = objectness.squeeze(0)
-    objectness, objectness_indices = torch.topk(objectness, MAX_DETECTIONS, dim=0)
     boxes = boxes.squeeze(0)
     image_class_embeds = image_class_embeds.squeeze(0)
-    logit_shift = logit_shift.squeeze(0).squeeze(1)
-    logit_scale = logit_scale.squeeze(0).squeeze(1)
-    boxes = boxes[objectness_indices]
-    image_class_embeds = image_class_embeds[objectness_indices]
-    logit_shift = logit_shift[objectness_indices]
-    logit_scale = logit_scale[objectness_indices]
+    # squeeze() with no argument drops every size-1 dim: [1, N, 1] -> [N]
+    logit_shift = logit_shift.squeeze()
+    logit_scale = logit_scale.squeeze()
+
+    # topk returns the kept scores and their indices in a single call
+    objectness, objectness_indices = torch.topk(objectness, MAX_DETECTIONS, dim=0)
+
+    # index_select gathers the same rows as boxes[objectness_indices]
+    # without the bookkeeping of general advanced indexing
+    boxes = boxes.index_select(0, objectness_indices)
+    image_class_embeds = image_class_embeds.index_select(0, objectness_indices)
+    logit_shift = logit_shift.index_select(0, objectness_indices)
+    logit_scale = logit_scale.index_select(0, objectness_indices)
+    return objectness, boxes, image_class_embeds, logit_shift, logit_scale
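
For review, a minimal standalone sketch of the pattern this hunk adopts. The tensor shapes, the `MAX_DETECTIONS` value, and the function name below are illustrative assumptions, not the actual owlv2.py code; the script also checks that `index_select` picks the same rows as the original advanced-indexing form:

```python
# Standalone sketch, not the real owlv2.py code: shapes, MAX_DETECTIONS,
# and the function name are illustrative assumptions.
from typing import Tuple

import torch

MAX_DETECTIONS = 100  # assumed value; the real constant lives in owlv2.py


def filter_by_objectness(
    objectness: torch.Tensor,  # [1, N]
    boxes: torch.Tensor,  # [1, N, 4]
    image_class_embeds: torch.Tensor,  # [1, N, D]
    logit_shift: torch.Tensor,  # [1, N, 1]
    logit_scale: torch.Tensor,  # [1, N, 1]
) -> Tuple[torch.Tensor, ...]:
    objectness = objectness.squeeze(0)
    boxes = boxes.squeeze(0)
    image_class_embeds = image_class_embeds.squeeze(0)
    # no-arg squeeze: [1, N, 1] -> [N]
    logit_shift = logit_shift.squeeze()
    logit_scale = logit_scale.squeeze()

    # one topk call yields both the kept scores and the rows to gather
    objectness, keep = torch.topk(objectness, MAX_DETECTIONS, dim=0)

    boxes = boxes.index_select(0, keep)
    image_class_embeds = image_class_embeds.index_select(0, keep)
    logit_shift = logit_shift.index_select(0, keep)
    logit_scale = logit_scale.index_select(0, keep)
    return objectness, boxes, image_class_embeds, logit_shift, logit_scale


if __name__ == "__main__":
    n, d = 1000, 512
    objectness = torch.rand(1, n)
    boxes = torch.rand(1, n, 4)

    out = filter_by_objectness(
        objectness,
        boxes,
        torch.rand(1, n, d),
        torch.rand(1, n, 1),
        torch.rand(1, n, 1),
    )
    print([tuple(t.shape) for t in out])
    # [(100,), (100, 4), (100, 512), (100,), (100,)]

    # index_select gathers the same rows as the original advanced indexing
    _, keep = torch.topk(objectness.squeeze(0), MAX_DETECTIONS, dim=0)
    assert torch.equal(boxes.squeeze(0).index_select(0, keep), boxes.squeeze(0)[keep])
```

One caveat worth flagging in review: unlike `squeeze(0).squeeze(1)`, the no-arg `squeeze()` would also collapse the detection dimension when N == 1, so the one-call form relies on N always being greater than one at this point in the pipeline.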