Commit cc08180

preprocessing libs
1 parent 32380ea commit cc08180

11 files changed: +1266 -6 lines changed


src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -875,6 +875,9 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
 image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
 image_placeholder: "<|image|>"
+video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
+audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
+use_audio_in_video: False
 posemb_type_for_vit: "learn"
 # max_num_images_per_example only applies for training when your image column is a list of images.
 # -1 means no limit, and will pad to the max possible number of images determined by sequence length.
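
The new video_path and audio_path keys follow the same comma-separated convention as the existing image_path. A minimal sketch of how such a value splits into individual paths, assuming that convention; the split_paths helper below is hypothetical and not part of this commit:

def split_paths(value: str) -> list[str]:
  # Split a comma-separated config value into individual paths, dropping empty entries.
  return [p.strip() for p in value.split(",") if p.strip()]

video_paths = split_paths("/path/video1.mp4,/path/video2.mp4")
assert video_paths == ["/path/video1.mp4", "/path/video2.mp4"]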

src/MaxText/configs/types.py

Lines changed: 11 additions & 1 deletion
@@ -1180,6 +1180,9 @@ class MultimodalGeneral(BaseModel):
       -1,
       description="Maximum number of images per example for training with image lists. -1 means no limit.",
   )
+  video_path: PathStr = Field("", description="Path to a video for decoding.")
+  audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
+  use_audio_in_video: bool = Field(False, description="Extract and use audio from video files.")


 class VisionTower(BaseModel):
@@ -1850,7 +1853,14 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
       raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
     if self.use_multimodal:
-      valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e")
+      valid_mm_models = (
+          "gemma3-4b",
+          "gemma3-12b",
+          "gemma3-27b",
+          "llama4-17b-16e",
+          "llama4-17b-128e",
+          "qwen3-omni-30b-a3b",
+      )
       if self.model_name not in valid_mm_models and self.model_name != "default":
         raise ValueError(f"Multimodal is only supported for {valid_mm_models}, not {self.model_name}")
     if self.use_sft:

src/MaxText/decode.py

Lines changed: 6 additions & 4 deletions
@@ -28,6 +28,7 @@
 from MaxText import pyconfig
 from MaxText import profiler
 from MaxText import multimodal_utils
+from MaxText.multimodal import preprocessor
 # Placeholder: internal

 # Number of text sequences to process in a single batch.
@@ -100,14 +101,15 @@ def main(argv: Sequence[str]) -> None:
   prefill_length = config.max_prefill_predict_length
   processor_outputs = multimodal_utils.PreprocessorOutput()
   if config.use_multimodal:
-    image_path = config.image_path.split(",")
-    images = [multimodal_utils.load_image_from_path(p) for p in image_path]
-    processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
+    processor_outputs = preprocessor.preprocess_mm_data(config)
     image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)

     prefill_length -= image_offsets
     text = multimodal_utils.reformat_prompt(
-        text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
+        text,
+        image_placeholder=config.image_placeholder,
+        model_name=config.model_name,
+        num_images=processor_outputs.num_images,
     )

   metadata = engine.get_tokenizer()
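
With this change, decode.py delegates multimodal preprocessing to the new router and reads the image count from the returned PreprocessorOutput instead of counting loaded images itself. A minimal sketch of the new flow, assuming a Gemma3 configuration; the SimpleNamespace stand-in and the example path are illustrative only, since real runs build the config through pyconfig:

from types import SimpleNamespace

from MaxText import multimodal_utils
from MaxText.multimodal import preprocessor

# Stand-in carrying only the fields the router and prompt reformatting read (illustrative).
config = SimpleNamespace(
    model_name="gemma3-4b",
    image_path="/path/image1.jpg",  # placeholder path; point at a real image to run
    image_placeholder="<|image|>",
)
outputs = preprocessor.preprocess_mm_data(config)
prompt = multimodal_utils.reformat_prompt(
    "Describe the image <|image|>",
    image_placeholder=config.image_placeholder,
    model_name=config.model_name,
    num_images=outputs.num_images,
)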

src/MaxText/layers/decoders.py

Lines changed: 2 additions & 0 deletions
@@ -580,6 +580,8 @@ def _apply_embedding(
           image_masks=image_masks,
       )
       # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
+    elif cfg.model_name in ["qwen3-omni-30b-a3b"]:
+      pass
     else:
       raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")

src/MaxText/multimodal/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

src/MaxText/multimodal/preprocessor.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multimodal data preprocessor router."""
+
+from MaxText import multimodal_utils  # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py
+
+
+def preprocess_mm_data(config):
+  """Preprocesses multimodal data based on the provided configuration.
+  Routes to the appropriate preprocessing function based on the model name.
+
+  Args:
+    config: A `pyconfig.Config` object containing configuration parameters.
+
+  Returns:
+    A `PreprocessorOutput` object containing the processed multimodal data.
+  """
+  processor_outputs = multimodal_utils.PreprocessorOutput()
+
+  if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
+
+    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
+    processor_outputs = multimodal_utils.pre_process_gemma3_image(images)
+  elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+
+    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
+    processor_outputs = multimodal_utils.pre_process_llama4_image(images)
+  elif config.model_name in ["qwen3-omni-30b-a3b"]:
+    from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni  # pylint: disable=import-outside-toplevel
+
+    processor_outputs = preprocess_mm_data_qwen3_omni(config)
+  else:
+    raise ValueError(f"Model {config.model_name} not supported for multimodal preprocessing.")
+
+  return processor_outputs
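
Model names outside the supported lists fall through to the final ValueError. A small sketch of that failure path, again with an illustrative stand-in config object rather than a real pyconfig instance:

from types import SimpleNamespace

from MaxText.multimodal import preprocessor

# Any model name the router does not recognize raises immediately, before any data is loaded.
cfg = SimpleNamespace(model_name="unsupported-model", image_path="")
try:
  preprocessor.preprocess_mm_data(cfg)
except ValueError as err:
  print(err)  # Model unsupported-model not supported for multimodal preprocessing.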
