
Commit 660a018
Parent: 0b24223

pre
pylint video image/video/audio pyink

File tree: 11 files changed (+1133, -6 lines)

src/MaxText/configs/base.yml (3 additions, 0 deletions)

@@ -873,6 +873,9 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
 image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
 image_placeholder: "<|image|>"
+video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
+audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
+use_audio_in_video: False
 posemb_type_for_vit: "learn"
 # max_num_images_per_example only applies for training when your image column is a list of images.
 # -1 means no limit, and will pad to the max possible number of images determined by sequence length.
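The new keys extend decoding inputs beyond images: video_path and audio_path follow the existing image_path convention (a comma-separated list of local file paths, empty by default), and use_audio_in_video presumably toggles whether the audio track of supplied videos is used as well. A tiny hypothetical helper (not part of MaxText) showing how such a value splits into individual paths:

```python
# Hypothetical helper, not in this commit: split a comma-separated media-path
# config value (image_path / video_path / audio_path) into individual paths,
# mirroring how decode.py previously split image_path.
def split_media_paths(path_config: str) -> list[str]:
  return [p.strip() for p in path_config.split(",") if p.strip()]


assert split_media_paths("/path/video1.mp4,/path/video2.mp4") == [
    "/path/video1.mp4",
    "/path/video2.mp4",
]
assert split_media_paths("") == []  # the empty default means no media
```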

src/MaxText/configs/models/qwen3-omni-30b-a3b.yml (1 addition, 1 deletion)

@@ -20,7 +20,7 @@ base_emb_dim: 2048
 base_mlp_dim: 768
 base_num_query_heads: 32
 base_num_kv_heads: 4
-base_num_decoder_layers: 48
+base_num_decoder_layers: 1
 head_dim: 128
 mlp_activations: ["silu", "linear"]
 vocab_size: 152064

src/MaxText/decode.py (6 additions, 4 deletions)

@@ -28,6 +28,7 @@
 from MaxText import pyconfig
 from MaxText import profiler
 from MaxText import multimodal_utils
+from MaxText.multimodal import preprocessor
 # Placeholder: internal

@@ -100,14 +101,15 @@ def main(argv: Sequence[str]) -> None:
   prefill_length = config.max_prefill_predict_length
   processor_outputs = multimodal_utils.PreprocessorOutput()
   if config.use_multimodal:
-    image_path = config.image_path.split(",")
-    images = [multimodal_utils.load_image_from_path(p) for p in image_path]
-    processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
+    processor_outputs = preprocessor.preprocess_mm_data(config)
     image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)

     prefill_length -= image_offsets
     text = multimodal_utils.reformat_prompt(
-        text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
+        text,
+        image_placeholder=config.image_placeholder,
+        model_name=config.model_name,
+        num_images=processor_outputs.num_images,
     )

   metadata = engine.get_tokenizer()
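This refactor replaces decode.py's inline, image-only loading (split image_path, load each image, call pre_process_image) with a single dispatch into the new preprocessor module, and takes the image count from the preprocessor output instead of a local list. A minimal sketch of the resulting call site, assuming a pyconfig-style config with the attributes used above; the wrapper function itself is illustrative, not part of the commit:

```python
# Sketch of the refactored multimodal prefill preparation in decode.py.
# Every call below appears verbatim in the diff; only the wrapper function
# and its signature are illustrative.
from MaxText import multimodal_utils
from MaxText.multimodal import preprocessor


def prepare_multimodal_inputs(config, text, prefill_length):
  """Preprocess media and rewrite the prompt before prefill."""
  processor_outputs = multimodal_utils.PreprocessorOutput()
  if config.use_multimodal:
    # One entry point for image/video/audio data, routed by model name.
    processor_outputs = preprocessor.preprocess_mm_data(config)
    # Media tokens consume part of the prefill budget.
    image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)
    prefill_length -= image_offsets
    text = multimodal_utils.reformat_prompt(
        text,
        image_placeholder=config.image_placeholder,
        model_name=config.model_name,
        num_images=processor_outputs.num_images,
    )
  return text, prefill_length, processor_outputs
```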

src/MaxText/layers/decoders.py (2 additions, 0 deletions)

@@ -566,6 +566,8 @@ def _apply_embedding(
           image_masks=image_masks,
       )
     # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
+    elif cfg.model_name in ["qwen3-omni-30b-a3b"]:
+      pass
     else:
       raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")

src/MaxText/layers/encoders.py (4 additions, 0 deletions)

@@ -44,6 +44,10 @@ def get_vision_encoder_layers(self):
       from MaxText.layers import llama4  # pylint: disable=import-outside-toplevel

       return [llama4.llama4visionmodel_as_linen, llama4.llama4multimodalprojector_as_linen]
+    elif self.config.model_name in ["qwen3-omni-30b-a3b"]:
+      from MaxText.layers import gemma3  # pylint: disable=import-outside-toplevel
+
+      return [gemma3.gemma3visionencoder_as_linen, gemma3.visionembedder_as_linen]
     else:
       raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")
4953

src/MaxText/multimodal/__init__.py (13 additions, 0 deletions)

@@ -0,0 +1,13 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
src/MaxText/multimodal/preprocessor.py (45 additions, 0 deletions)

@@ -0,0 +1,45 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from MaxText import multimodal_utils  # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py
+
+
+def preprocess_mm_data(config):
+  """Preprocesses multimodal data based on the provided configuration.
+  Routes to the appropriate preprocessing function based on the model name.
+
+  Args:
+    config: A `pyconfig.Config` object containing configuration parameters.
+
+  Returns:
+    A `PreprocessorOutput` object containing the processed multimodal data.
+  """
+  processor_outputs = multimodal_utils.PreprocessorOutput()
+
+  if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
+
+    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
+    processor_outputs = multimodal_utils.pre_process_gemma3_image(images)
+  elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+
+    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
+    processor_outputs = multimodal_utils.pre_process_llama4_image(images)
+  elif config.model_name in ["qwen3-omni-30b-a3b"]:
+    from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni  # pylint: disable=import-outside-toplevel
+
+    processor_outputs = preprocess_mm_data_qwen3_omni(config)
+  else:
+    raise ValueError(f"Model {config.model_name} not supported for multimodal preprocessing.")
+
+  return processor_outputs
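The router keeps the existing Gemma3 and Llama4 image-only paths and adds a qwen3-omni branch that delegates to a dedicated qwen3_omni_processor, which is presumably where the new video_path/audio_path/use_audio_in_video configs get consumed. A hedged usage sketch: the stub config below only mimics the pyconfig attributes this function reads, and the placeholder paths would need to point at real files on disk.

```python
# Hedged usage sketch for MaxText.multimodal.preprocessor. FakeConfig is a
# stand-in for pyconfig.Config exposing only what preprocess_mm_data reads;
# the image paths are placeholders and must exist for the loads to succeed.
from dataclasses import dataclass

from MaxText.multimodal import preprocessor


@dataclass
class FakeConfig:
  model_name: str = "gemma3-4b"
  image_path: str = "/path/image1.jpg,/path/image2.jpg"


outputs = preprocessor.preprocess_mm_data(FakeConfig())
print(outputs.num_images)  # decode.py reads this field after preprocessing
```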
