Skip to content

Commit 4b7520f

Browse files
committed
Manual progress
1 parent 44d8d54 commit 4b7520f

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

optimum/exporters/executorch/integrations.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from transformers import (
2323
AutoConfig,
2424
AutoProcessor,
25+
AutoTokenizer,
2526
PreTrainedModel,
2627
StaticCache,
2728
T5ForConditionalGeneration,
@@ -34,7 +35,7 @@
3435

3536
from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
3637

37-
from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
38+
from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
3839

3940

4041
class VisionExportableModule(torch.nn.Module):
@@ -46,6 +47,7 @@ def prepare_export_inputs(self):
4647
# 1. Get export inputs
4748
model_id = self.model.config.name_or_path
4849
processor = AutoProcessor.from_pretrained(model_id)
50+
tokenizer = AutoTokenizer.from_pretrained(model_id)
4951
sample_conversation_with_image = [
5052
{
5153
"role": "user",
@@ -54,13 +56,18 @@ def prepare_export_inputs(self):
5456
],
5557
},
5658
]
57-
processed_inputs = processor.apply_chat_template(
59+
processed_inputs = process_conversation_inputs(
60+
processor,
61+
tokenizer,
5862
sample_conversation_with_image,
59-
add_generation_prompt=True,
60-
tokenize=True,
61-
return_dict=True,
62-
return_tensors="pt",
6363
)
64+
# processed_inputs = processor.apply_chat_template(
65+
# sample_conversation_with_image,
66+
# add_generation_prompt=True,
67+
# tokenize=True,
68+
# return_dict=True,
69+
# return_tensors="pt",
70+
# )
6471
if "pixel_values" not in processed_inputs:
6572
raise ValueError(
6673
f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'pixel_values' key: {processed_inputs}"

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,19 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
180180
"device": device,
181181
},
182182
)
183-
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
184-
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
183+
184+
# Most <Model>ForConditionalGeneration> will have the text_model and encoder models as attributes, however
185+
# some have `self.model = <Model>` (the base version not for conditional generation), and this `self.model`
186+
# contains the text_model and encoder model attributes.
187+
if hasattr(eager_model, "model"):
188+
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model.model)
189+
# Set these as top level attributes.
190+
setattr(eager_model, decoder_name, getattr(eager_model.model, decoder_name))
191+
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
192+
setattr(eager_model, encoder_name, getattr(eager_model.model, encoder_name))
193+
else:
194+
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
195+
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name
185196

186197
# Need to do this since apparently when nested modules (e.g. model.language_model) access the .property
187198
# config, it always comes from the generation_config.json file, not the `generation_config` override

optimum/exporters/executorch/utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,12 @@ def process_conversation_inputs(
139139
input_conversation: List[Dict[str, Any]],
140140
):
141141
"""
142-
Process input conversation for multimodal models.
143-
144-
This function handles the preprocessing of conversation inputs, with special handling for
145-
GraniteSpeechProcessor which requires extracting and processing audio content from conversations
146-
prior to feeding into the processor.
142+
Process an input conversation into tensor inputs for multimodal models.
147143
148144
Args:
149145
processor: The processor to use for input processing
150146
tokenizer: The tokenizer to use for text processing
151-
input_conversation: List of conversation messages, may contain audio content
147+
input_conversation: List of conversation messages
152148
153149
Returns:
154150
Processed inputs ready for model consumption
@@ -190,6 +186,34 @@ def process_conversation_inputs(
190186
# Generate text prompt and process with audio
191187
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
192188
inputs = processor(prompt, wav, return_tensors="pt")
189+
elif isinstance(processor, transformers.SmolVLMProcessor):
190+
from transformers.image_utils import load_image
191+
192+
conversation = copy.deepcopy(input_conversation)
193+
images = []
194+
195+
# Extract image URLs from conversation
196+
for message in conversation:
197+
if isinstance(message.get("content"), list):
198+
# Filter out image entries and collect URLs
199+
image_urls = [item["url"] for item in message["content"] if item.get("type") == "image"]
200+
images.extend([load_image(url) for url in image_urls])
201+
202+
# Remove image entries from content
203+
message["content"] = [item for item in message["content"] if item.get("type") != "image"]
204+
205+
# Apply chat template to get text prompt
206+
prompt = apply_chat_template_with_fallback(
207+
processor,
208+
conversation,
209+
add_generation_prompt=True,
210+
tokenize=True,
211+
return_dict=True,
212+
return_tensors="pt",
213+
)
214+
215+
# Process with text and images
216+
inputs = processor(text=prompt, images=images, return_tensors="pt")
193217
else:
194218
# Standard processing for other processors
195219
inputs = apply_chat_template_with_fallback(

0 commit comments

Comments
 (0)