Skip to content

Commit acd9c90

Browse files
committed
Adding multi-image support for Llama4 SFT and decode.
Adding support for the SlideVQA dataset. Linting fixes; debugging multi-image SFT; more linting fixes; even more linting; pyink linting. Adding extra padding for the global tile. Adding an extra tile for the dummy image shape. Extending padding to 20 tiles. Moving the column merge into the input pipeline utils. Linting multimodal_utils.py. Fixing pytype issues.
1 parent 5b01873 commit acd9c90

File tree

6 files changed

+308
-149
lines changed

6 files changed

+308
-149
lines changed
Lines changed: 33 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023-2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
base_config: "base.yml"
16+
17+
use_sft: True
18+
use_multimodal: True
19+
# For vision, the prompt contains image, we only train on completion tokens
20+
sft_train_on_completion_only: True
21+
packing: False # packing is not supported yet
22+
freeze_vision_encoder_params: True
23+
learning_rate: 2.e-5
24+
25+
# -------------- HF pipeline --------------
26+
dataset_type: hf
27+
hf_path: 'NTT-hil-insight/SlideVQA'
28+
train_split: 'train'
29+
hf_eval_split: 'val'
30+
train_data_columns: ['question', 'answer'] # the first column is prompt, second column is completion
31+
eval_data_columns: ['question', 'answer'] # the first column is prompt, second column is completion
32+
train_image_column: ['page_1', 'page_2', 'page_3', 'page_4', 'page_5', 'page_6', 'page_7', 'page_8', 'page_9', 'page_10', 'page_11', 'page_12', 'page_13', 'page_14', 'page_15', 'page_16', 'page_17', 'page_18', 'page_19', 'page_20'] # list of image columns
33+
eval_image_column: ['page_1', 'page_2', 'page_3', 'page_4', 'page_5', 'page_6', 'page_7', 'page_8', 'page_9', 'page_10', 'page_11', 'page_12', 'page_13', 'page_14', 'page_15', 'page_16', 'page_17', 'page_18', 'page_19', 'page_20'] # list of image columns

src/MaxText/decode.py

Lines changed: 5 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -16,7 +16,6 @@
1616

1717
import os
1818
from typing import Sequence
19-
import numpy as np
2019
import jax
2120
import jax.numpy as jnp
2221

@@ -103,10 +102,9 @@ def main(argv: Sequence[str]) -> None:
103102
if config.use_multimodal:
104103
image_path = config.image_path.split(",")
105104
images = [multimodal_utils.load_image_from_path(p) for p in image_path]
106-
processor_outputs = [multimodal_utils.pre_process_image(img, model_name=config.model_name) for img in images]
107-
image_offsets = sum(
108-
multimodal_utils.get_image_offsets(config.model_name, processor_output=po) for po in processor_outputs
109-
)
105+
processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
106+
image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)
107+
110108
prefill_length -= image_offsets
111109
text = multimodal_utils.reformat_prompt(
112110
text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
@@ -150,10 +148,8 @@ def main(argv: Sequence[str]) -> None:
150148
prefill_result, first_token = engine.prefill(
151149
params=params,
152150
padded_tokens=tokens,
153-
images=np.stack([po.pixel_values for po in processor_outputs]) if config.use_multimodal else None,
154-
image_masks=np.stack([po.pixel_mask for po in processor_outputs])
155-
if config.use_multimodal and "llama4" in config.model_name
156-
else None,
151+
images=processor_outputs.pixel_values if config.use_multimodal else None,
152+
image_masks=processor_outputs.pixel_mask if config.use_multimodal and "llama4" in config.model_name else None,
157153
true_length=true_length,
158154
rng=rng_prefill,
159155
slot=i,

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 13 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -47,6 +47,17 @@ def vision_sft_preprocessing_pipeline(
4747
if config.enable_data_shuffling:
4848
dataset = dataset.shuffle(seed=config.data_shuffle_seed)
4949

50+
# If multiple image columns are provided, merge them into a single 'images' column.
51+
if isinstance(image_column, list):
52+
dataset = dataset.map(
53+
_input_pipeline_utils.merge_image_columns,
54+
fn_kwargs={
55+
"image_columns": image_column,
56+
"max_num_images_per_example": config.max_num_images_per_example,
57+
},
58+
remove_columns=image_column, # Drop the original image columns
59+
)
60+
5061
dataset = dataset.select_columns(text_columns + [image_column])
5162
if image_column != "images":
5263
dataset = dataset.rename_column(image_column, "images")
@@ -125,7 +136,9 @@ def vision_sft_preprocessing_pipeline(
125136
max_num_images_per_example=config.max_num_images_per_example,
126137
)
127138
)
139+
operations.append(_input_pipeline_utils.ExtractImagesAndMasks())
128140
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=True))
141+
operations.append(_input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name))
129142
operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
130143
dummy_index_sampler = grain.IndexSampler(
131144
num_records=len(dataset),

0 commit comments

Comments (0)