Skip to content

Commit 58dba0a

Browse files
Merge pull request #2471 from AI-Hypercomputer:nicogrande/llama4-multi-image
PiperOrigin-RevId: 823084497
2 parents 6b1ef88 + acd9c90 commit 58dba0a

File tree

6 files changed

+308
-149
lines changed

6 files changed

+308
-149
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
base_config: "base.yml"
16+
17+
use_sft: True
18+
use_multimodal: True
19+
# For vision tasks, the prompt contains the image(s), so we train only on completion tokens
20+
sft_train_on_completion_only: True
21+
packing: False # packing is not supported yet
22+
freeze_vision_encoder_params: True
23+
learning_rate: 2.e-5
24+
25+
# -------------- HF pipeline --------------
26+
dataset_type: hf
27+
hf_path: 'NTT-hil-insight/SlideVQA'
28+
train_split: 'train'
29+
hf_eval_split: 'val'
30+
train_data_columns: ['question', 'answer'] # the first column is the prompt; the second column is the completion
31+
eval_data_columns: ['question', 'answer'] # the first column is the prompt; the second column is the completion
32+
train_image_column: ['page_1', 'page_2', 'page_3', 'page_4', 'page_5', 'page_6', 'page_7', 'page_8', 'page_9', 'page_10', 'page_11', 'page_12', 'page_13', 'page_14', 'page_15', 'page_16', 'page_17', 'page_18', 'page_19', 'page_20'] # list of image columns
33+
eval_image_column: ['page_1', 'page_2', 'page_3', 'page_4', 'page_5', 'page_6', 'page_7', 'page_8', 'page_9', 'page_10', 'page_11', 'page_12', 'page_13', 'page_14', 'page_15', 'page_16', 'page_17', 'page_18', 'page_19', 'page_20'] # list of image columns

src/MaxText/decode.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import os
1818
from typing import Sequence
19-
import numpy as np
2019
import jax
2120
import jax.numpy as jnp
2221

@@ -103,10 +102,9 @@ def main(argv: Sequence[str]) -> None:
103102
if config.use_multimodal:
104103
image_path = config.image_path.split(",")
105104
images = [multimodal_utils.load_image_from_path(p) for p in image_path]
106-
processor_outputs = [multimodal_utils.pre_process_image(img, model_name=config.model_name) for img in images]
107-
image_offsets = sum(
108-
multimodal_utils.get_image_offsets(config.model_name, processor_output=po) for po in processor_outputs
109-
)
105+
processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
106+
image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)
107+
110108
prefill_length -= image_offsets
111109
text = multimodal_utils.reformat_prompt(
112110
text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
@@ -150,10 +148,8 @@ def main(argv: Sequence[str]) -> None:
150148
prefill_result, first_token = engine.prefill(
151149
params=params,
152150
padded_tokens=tokens,
153-
images=np.stack([po.pixel_values for po in processor_outputs]) if config.use_multimodal else None,
154-
image_masks=np.stack([po.pixel_mask for po in processor_outputs])
155-
if config.use_multimodal and "llama4" in config.model_name
156-
else None,
151+
images=processor_outputs.pixel_values if config.use_multimodal else None,
152+
image_masks=processor_outputs.pixel_mask if config.use_multimodal and "llama4" in config.model_name else None,
157153
true_length=true_length,
158154
rng=rng_prefill,
159155
slot=i,

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ def vision_sft_preprocessing_pipeline(
4747
if config.enable_data_shuffling:
4848
dataset = dataset.shuffle(seed=config.data_shuffle_seed)
4949

50+
# If multiple image columns are provided, merge them into a single 'images' column.
51+
if isinstance(image_column, list):
52+
dataset = dataset.map(
53+
_input_pipeline_utils.merge_image_columns,
54+
fn_kwargs={
55+
"image_columns": image_column,
56+
"max_num_images_per_example": config.max_num_images_per_example,
57+
},
58+
remove_columns=image_column, # Drop the original image columns
59+
)
60+
5061
dataset = dataset.select_columns(text_columns + [image_column])
5162
if image_column != "images":
5263
dataset = dataset.rename_column(image_column, "images")
@@ -125,7 +136,9 @@ def vision_sft_preprocessing_pipeline(
125136
max_num_images_per_example=config.max_num_images_per_example,
126137
)
127138
)
139+
operations.append(_input_pipeline_utils.ExtractImagesAndMasks())
128140
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=True))
141+
operations.append(_input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name))
129142
operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
130143
dummy_index_sampler = grain.IndexSampler(
131144
num_records=len(dataset),

0 commit comments

Comments
 (0)