@@ -727,18 +727,35 @@ def _find_matches(
727727 return mode , matches_to_apply
728728
729729
730+ def _all_items_found (
731+ mm_item_counts : dict [str , int ],
732+ mm_found_counts : dict [str , int ],
733+ ) -> bool :
734+ return all (
735+ item_idx >= mm_item_counts [modality ]
736+ for modality , item_idx in mm_found_counts .items ()
737+ )
738+
739+
730740def _apply_matches (
731741 prompt : _S ,
732742 mm_prompt_updates : "MultiModalPromptUpdates" ,
733743 tokenizer : AnyTokenizer ,
734744) -> tuple [list [_S ], "MultiModalPromptUpdatesApplyResult" ]:
735745 prompt_len = len (prompt )
746+ mm_item_counts = {m : len (items ) for m , items in mm_prompt_updates .items ()}
736747
737748 out_seqs = list [str | list [int ]]()
738749 out_result : MultiModalPromptUpdatesApplyResult = {
739750 m : [None ] * len (items ) for m , items in mm_prompt_updates .items ()
740751 }
741752
753+ mm_found_counts = {
754+ m : sum (r is not None for r in res ) for m , res in out_result .items ()
755+ }
756+ if _all_items_found (mm_item_counts , mm_found_counts ):
757+ return [prompt ], out_result
758+
742759 start_idx = prev_end_idx = 0
743760 while start_idx < max (prompt_len , 1 ): # Allow inserts into empty prompt
744761 found = False
@@ -776,6 +793,12 @@ def _apply_matches(
776793 # Exclude overlapping matches
777794 start_idx = prev_end_idx = match .end_idx
778795
796+ mm_found_counts = {
797+ m : sum (r is not None for r in res ) for m , res in out_result .items ()
798+ }
799+ if _all_items_found (mm_item_counts , mm_found_counts ):
800+ break
801+
779802 if not found :
780803 start_idx += 1
781804
@@ -832,12 +855,15 @@ def _iter_placeholders(
832855
833856 Note that empty matches are ignored.
834857 """
835- prompt_len = len (prompt )
836858 mm_item_counts = {m : len (items ) for m , items in mm_prompt_updates .items ()}
859+ item_idx_by_modality = {modality : 0 for modality in mm_prompt_updates }
837860
838- item_idx_by_modality = defaultdict [str , int ](lambda : 0 )
861+ if _all_items_found (mm_item_counts , item_idx_by_modality ):
862+ return
839863
864+ prompt_len = len (prompt )
840865 start_idx = 0
866+
841867 while start_idx < prompt_len :
842868 found = False
843869
@@ -875,6 +901,9 @@ def _iter_placeholders(
875901 break
876902
877903 if found :
904+ if _all_items_found (mm_item_counts , item_idx_by_modality ):
905+ return
906+
878907 break # Go back to the outer while loop
879908
880909 if not found :
0 commit comments