3737
3838from .utils import apply_chat_template_with_fallback , process_conversation_inputs , save_config_to_constant_methods
3939
40+ def _patch_idefics3_vision_embeddings_for_export (vision_model ):
41+ """
42+ Patch Idefics3VisionEmbeddings to make it export-friendly by removing data-dependent operations.
43+ This assumes batch_size=1 and a full attention mask (all 1s).
44+ """
45+ import types
46+
47+ def export_friendly_forward (self , pixel_values : torch .FloatTensor , patch_attention_mask : torch .BoolTensor ) -> torch .Tensor :
48+ batch_size , _ , max_im_h , max_im_w = pixel_values .shape
49+
50+ patch_embeds = self .patch_embedding (pixel_values )
51+ embeddings = patch_embeds .flatten (2 ).transpose (1 , 2 )
52+
53+ nb_patches_h = max_im_h // self .patch_size
54+ nb_patches_w = max_im_w // self .patch_size
55+ N = self .num_patches_per_side
56+
57+ # For export, we assume full attention mask and compute position IDs statically.
58+ # This avoids the data-dependent loop over batch dimension.
59+ h_indices = torch .arange (nb_patches_h , device = pixel_values .device , dtype = torch .long )
60+ w_indices = torch .arange (nb_patches_w , device = pixel_values .device , dtype = torch .long )
61+
62+ # This replaces bucketize(x, boundaries=[1/N, 2/N, ...], right=True) ≈ floor(x * N), which
63+ # we don't have a kernel for at the moment.
64+ bucket_coords_h = (h_indices * N ) // nb_patches_h
65+ bucket_coords_w = (w_indices * N ) // nb_patches_w
66+
67+ bucket_coords_h = torch .clamp (bucket_coords_h , max = N - 1 )
68+ bucket_coords_w = torch .clamp (bucket_coords_w , max = N - 1 )
69+
70+ pos_ids = (bucket_coords_h [:, None ] * N + bucket_coords_w [None , :]).reshape (- 1 )
71+ position_ids = pos_ids .unsqueeze (0 ).expand (batch_size , - 1 )
72+ embeddings = embeddings + self .position_embedding (position_ids )
73+ return embeddings
74+
75+ # Patch the forward method.
76+ vision_model .embeddings .forward = types .MethodType (export_friendly_forward , vision_model .embeddings )
77+
4078
4179class VisionExportableModule (torch .nn .Module ):
4280 def __init__ (self , model : torch .nn .Module ):
4381 super ().__init__ ()
4482 self .model = model
4583
84+ # Patch Idefics3 vision embeddings if needed
85+ if hasattr (model , 'model' ) and hasattr (model .model , 'vision_model' ):
86+ model_type = getattr (model .config , 'model_type' , '' )
87+ if 'idefics3' in model_type .lower ():
88+ _patch_idefics3_vision_embeddings_for_export (model .model .vision_model )
89+
4690 def prepare_export_inputs (self ):
4791 # 1. Get export inputs
4892 model_id = self .model .config .name_or_path
@@ -61,13 +105,6 @@ def prepare_export_inputs(self):
61105 tokenizer ,
62106 sample_conversation_with_image ,
63107 )
64- # processed_inputs = processor.apply_chat_template(
65- # sample_conversation_with_image,
66- # add_generation_prompt=True,
67- # tokenize=True,
68- # return_dict=True,
69- # return_tensors="pt",
70- # )
71108 if "pixel_values" not in processed_inputs :
72109 raise ValueError (
73110 f"Unable to obtain sample audio encoder inputs for export for { model_id } - the processor did not return formatted inputs with the 'pixel_values' key: { processed_inputs } "
@@ -83,7 +120,9 @@ def forward(
83120 self ,
84121 input_features : torch .FloatTensor ,
85122 ):
86- image_embeds = self .model .get_image_features (input_features )
123+ # Pass pixel_attention_mask=None to avoid data-dependent operations during export.
124+ # The model will create a mask full of 1s internally if None is passed.
125+ image_embeds = self .model .get_image_features (input_features , pixel_attention_mask = None )
87126 if isinstance (image_embeds , list ):
88127 image_embeds = torch .stack (image_embeds )
89128 return image_embeds
0 commit comments