Skip to content

Commit 1168768

Browse files
[Optimization] Early return for _apply_matches and _iter_placeholders (#29668)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 8e7a891 commit 1168768

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

vllm/multimodal/processing.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,18 +727,35 @@ def _find_matches(
727727
return mode, matches_to_apply
728728

729729

730+
def _all_items_found(
731+
mm_item_counts: dict[str, int],
732+
mm_found_counts: dict[str, int],
733+
) -> bool:
734+
return all(
735+
item_idx >= mm_item_counts[modality]
736+
for modality, item_idx in mm_found_counts.items()
737+
)
738+
739+
730740
def _apply_matches(
731741
prompt: _S,
732742
mm_prompt_updates: "MultiModalPromptUpdates",
733743
tokenizer: AnyTokenizer,
734744
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
735745
prompt_len = len(prompt)
746+
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
736747

737748
out_seqs = list[str | list[int]]()
738749
out_result: MultiModalPromptUpdatesApplyResult = {
739750
m: [None] * len(items) for m, items in mm_prompt_updates.items()
740751
}
741752

753+
mm_found_counts = {
754+
m: sum(r is not None for r in res) for m, res in out_result.items()
755+
}
756+
if _all_items_found(mm_item_counts, mm_found_counts):
757+
return [prompt], out_result
758+
742759
start_idx = prev_end_idx = 0
743760
while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt
744761
found = False
@@ -776,6 +793,12 @@ def _apply_matches(
776793
# Exclude overlapping matches
777794
start_idx = prev_end_idx = match.end_idx
778795

796+
mm_found_counts = {
797+
m: sum(r is not None for r in res) for m, res in out_result.items()
798+
}
799+
if _all_items_found(mm_item_counts, mm_found_counts):
800+
break
801+
779802
if not found:
780803
start_idx += 1
781804

@@ -832,12 +855,15 @@ def _iter_placeholders(
832855
833856
Note that empty matches are ignored.
834857
"""
835-
prompt_len = len(prompt)
836858
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
859+
item_idx_by_modality = {modality: 0 for modality in mm_prompt_updates}
837860

838-
item_idx_by_modality = defaultdict[str, int](lambda: 0)
861+
if _all_items_found(mm_item_counts, item_idx_by_modality):
862+
return
839863

864+
prompt_len = len(prompt)
840865
start_idx = 0
866+
841867
while start_idx < prompt_len:
842868
found = False
843869

@@ -875,6 +901,9 @@ def _iter_placeholders(
875901
break
876902

877903
if found:
904+
if _all_items_found(mm_item_counts, item_idx_by_modality):
905+
return
906+
878907
break # Go back to the outer while loop
879908

880909
if not found:

0 commit comments

Comments
 (0)