Commit 8335b49

Modify image embedding return for Llava compatibility (#147)
* Modify image embedding return for Llava compatibility

  For `Gemma3` we don't need to unsqueeze the image embedding. For `Llava` we got a list of 2D tensors and here I'm assuming it only contains 1 tensor.
  https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L214

* Address comments
1 parent 08b55f1 commit 8335b49

1 file changed, +3 -1 lines changed
optimum/exporters/executorch/integrations.py

Lines changed: 3 additions & 1 deletion
@@ -77,7 +77,9 @@ def forward(
         input_features: torch.FloatTensor,
     ):
         image_embeds = self.model.get_image_features(input_features)
-        return image_embeds.unsqueeze(0)
+        if isinstance(image_embeds, list):
+            image_embeds = torch.stack(image_embeds)
+        return image_embeds
 
 
 class AudioExportableModule(torch.nn.Module):
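
The effect of the new return path on tensor shapes can be illustrated with a small standalone sketch. The helper name `normalize_image_embeds` and the shapes below are hypothetical and chosen only for illustration; they are not taken from the repository or from any particular model configuration. The sketch assumes, as the commit message does, that the Llava-style result is a list of 2D tensors and that the Gemma3-style result is already a single tensor that needs no extra dimension.

```python
import torch


def normalize_image_embeds(image_embeds):
    """Hypothetical helper mirroring the changed return path in this commit."""
    # Llava-style output (assumed): a list of 2D tensors.
    if isinstance(image_embeds, list):
        # torch.stack adds a new leading dimension; every tensor in the
        # list must have the same shape for this call to succeed.
        image_embeds = torch.stack(image_embeds)
    # Gemma3-style output (assumed): already a single tensor, returned as-is.
    return image_embeds


# Illustrative shapes only; real models have their own patch counts and hidden sizes.
llava_like = [torch.zeros(4, 8)]     # list holding one (patches, hidden) tensor
gemma3_like = torch.zeros(1, 4, 8)   # tensor that already has a leading dimension

stacked = normalize_image_embeds(llava_like)
passthrough = normalize_image_embeds(gemma3_like)

print(stacked.shape)      # torch.Size([1, 4, 8])
print(passthrough.shape)  # torch.Size([1, 4, 8])
```

With a single-element list, `torch.stack` produces the same leading dimension that `unsqueeze(0)` used to add. If the list held several tensors of different shapes, `torch.stack` would raise an error, which is why the single-tensor assumption stated in the commit message matters.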
