3 changes: 3 additions & 0 deletions src/MaxText/configs/base.yml
@@ -875,6 +875,9 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
 image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
 image_placeholder: "<|image|>"
+video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
+audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
+use_audio_in_video: False
 posemb_type_for_vit: "learn"
 # max_num_images_per_example only applies for training when your image column is a list of images.
 # -1 means no limit, and will pad to the max possible number of images determined by sequence length.
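For readers of this diff: the new video_path and audio_path follow the same comma-separated convention as the existing image_path. A small, hedged sketch (the helper name split_media_paths is assumed, not part of this PR) of how such values are typically split before loading:

def split_media_paths(raw: str) -> list[str]:
  """Split a comma-separated config value such as image_path, video_path, or audio_path."""
  return [p.strip() for p in raw.split(",") if p.strip()]

assert split_media_paths("/path/video1.mp4,/path/video2.mp4") == ["/path/video1.mp4", "/path/video2.mp4"]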
12 changes: 11 additions & 1 deletion src/MaxText/configs/types.py
@@ -1180,6 +1180,9 @@ class MultimodalGeneral(BaseModel):
       -1,
       description="Maximum number of images per example for training with image lists. -1 means no limit.",
   )
+  video_path: PathStr = Field("", description="Path to a video for decoding.")
+  audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
+  use_audio_in_video: bool = Field(False, description="Extract and use audio from video files.")


 class VisionTower(BaseModel):
@@ -1850,7 +1853,14 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
       raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
     if self.use_multimodal:
-      valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e")
+      valid_mm_models = (
+          "gemma3-4b",
+          "gemma3-12b",
+          "gemma3-27b",
+          "llama4-17b-16e",
+          "llama4-17b-128e",
+          "qwen3-omni-30b-a3b",
+      )
       if self.model_name not in valid_mm_models and self.model_name != "default":
         raise ValueError(f"Multimodal is only supported for {valid_mm_models}, not {self.model_name}")
     if self.use_sft:
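To see the effect of this hunk in isolation, here is a hedged, standalone sketch of the check (VALID_MM_MODELS and validate_mm_model are assumed names; the real logic lives inside the config validation above):

VALID_MM_MODELS = (
    "gemma3-4b",
    "gemma3-12b",
    "gemma3-27b",
    "llama4-17b-16e",
    "llama4-17b-128e",
    "qwen3-omni-30b-a3b",
)

def validate_mm_model(model_name: str) -> None:
  """Raise if a multimodal run is configured with an unsupported model."""
  if model_name not in VALID_MM_MODELS and model_name != "default":
    raise ValueError(f"Multimodal is only supported for {VALID_MM_MODELS}, not {model_name}")

validate_mm_model("qwen3-omni-30b-a3b")  # accepted after this change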
10 changes: 6 additions & 4 deletions src/MaxText/decode.py
@@ -28,6 +28,7 @@
 from MaxText import pyconfig
 from MaxText import profiler
 from MaxText import multimodal_utils
+from MaxText.multimodal import preprocessor
 # Placeholder: internal

 # Number of text sequences to process in a single batch.
@@ -100,14 +101,15 @@ def main(argv: Sequence[str]) -> None:
   prefill_length = config.max_prefill_predict_length
   processor_outputs = multimodal_utils.PreprocessorOutput()
   if config.use_multimodal:
-    image_path = config.image_path.split(",")
-    images = [multimodal_utils.load_image_from_path(p) for p in image_path]
-    processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
+    processor_outputs = preprocessor.preprocess_mm_data(config)
     image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)

     prefill_length -= image_offsets
     text = multimodal_utils.reformat_prompt(
-        text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
+        text,
+        image_placeholder=config.image_placeholder,
+        model_name=config.model_name,
+        num_images=processor_outputs.num_images,
     )

   metadata = engine.get_tokenizer()
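The prefill accounting above deserves a note: placeholder tokens emitted by the multimodal preprocessor consume part of the prefill window, so the budget left for text shrinks by get_image_offsets(...). A hedged sketch with made-up numbers:

max_prefill_predict_length = 1024  # hypothetical config value
image_offsets = 256                # hypothetical count of multimodal placeholder tokens
prefill_length = max_prefill_predict_length - image_offsets
assert prefill_length == 768       # tokens left for the text prompt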
13 changes: 13 additions & 0 deletions src/MaxText/multimodal/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
47 changes: 47 additions & 0 deletions src/MaxText/multimodal/preprocessor.py
@@ -0,0 +1,47 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multimodal data preprocessor router."""
+
+from MaxText import multimodal_utils  # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py

Review comment:
🟡 The TODO comment on line 18 indicates a future refactoring to deprecate MaxText/multimodal_utils.py. It would be beneficial to have a clear plan or follow-up issue for this refactoring, to ensure multimodal_utils.py is eventually removed and its relevant functions are moved to MaxText/multimodal/utils.py, keeping the codebase clean and organized.

+def preprocess_mm_data(config):
+  """Preprocesses multimodal data based on the provided configuration.
+
+  Routes to the appropriate preprocessing function based on the model name.
+
+  Args:
+    config: A `pyconfig.Config` object containing configuration parameters.
+
+  Returns:
+    A `PreprocessorOutput` object containing the processed multimodal data.
+  """
+  processor_outputs = multimodal_utils.PreprocessorOutput()
+
+  if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
+    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
+    processor_outputs = multimodal_utils.pre_process_gemma3_image(images)
Review comment (Collaborator):
maybe rename the functions to preprocess_mm_data_gemma3?

maybe it would be better to use a factory pattern here (a sketch of that pattern follows this file's diff)

+  elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
+    processor_outputs = multimodal_utils.pre_process_llama4_image(images)

Review comment:
🟡 The PreprocessorOutput is imported from MaxText.multimodal_utils. Since a new MaxText.multimodal.utils has been introduced, it would be more consistent to use PreprocessorOutput from MaxText.multimodal.utils instead of the old multimodal_utils. This also aligns with the TODO to deprecate multimodal_utils.py.

Suggested change:
from MaxText.multimodal import utils as mm_utils
from MaxText.multimodal.qwen3_omni_processor import Qwen3OmniPreprocessorOutput  # To resolve a potential circular dependency
# TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py

def preprocess_mm_data(config):
  """Preprocesses multimodal data based on the provided configuration.

  Routes to the appropriate preprocessing function based on the model name.

  Args:
    config: A `pyconfig.Config` object containing configuration parameters.

  Returns:
    A `PreprocessorOutput` object containing the processed multimodal data.
  """
  processor_outputs = mm_utils.PreprocessorOutput()  # Using the new utils
+  elif config.model_name in ["qwen3-omni-30b-a3b"]:
+    from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni  # pylint: disable=import-outside-toplevel
+
+    processor_outputs = preprocess_mm_data_qwen3_omni(config)
Review comment (Collaborator):
why does it accept a config?

+  else:
+    raise ValueError(f"Model {config.model_name} not supported for multimodal preprocessing.")
+
+  return processor_outputs
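To make the reviewer's factory-pattern suggestion concrete, one possible shape is a registry keyed by model name — a hedged sketch, not this PR's implementation; register_mm_preprocessor and preprocess_mm_data_gemma3 are assumed names:

from typing import Callable, Dict

_MM_PREPROCESSORS: Dict[str, Callable] = {}

def register_mm_preprocessor(*model_names: str):
  """Decorator that maps one or more model names to a preprocessing function."""
  def wrap(fn: Callable) -> Callable:
    for name in model_names:
      _MM_PREPROCESSORS[name] = fn
    return fn
  return wrap

@register_mm_preprocessor("gemma3-4b", "gemma3-12b", "gemma3-27b")
def preprocess_mm_data_gemma3(config):
  ...  # would load config.image_path images and run the Gemma3 image pipeline

def preprocess_mm_data(config):
  """Dispatch to the preprocessor registered for config.model_name."""
  try:
    return _MM_PREPROCESSORS[config.model_name](config)
  except KeyError as e:
    raise ValueError(f"Model {config.model_name} not supported for multimodal preprocessing.") from e

With a registry like this, adding a new modality-capable model becomes a one-decorator change instead of another elif branch.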