
Commit 72cbc23

Author: Eitan Porat (committed)
Add Preprocessing and token placeholders
1 parent 05abc90 commit 72cbc23

File tree

3 files changed: +272 additions, -9 deletions


src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -892,7 +892,10 @@ dtype_mm: "float32" # Data type for multimodal model's vision encoder
 remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
 image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
+audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
+video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
 image_placeholder: "<|image|>"
+audio_placeholder: "<|audio|>"
 posemb_type_for_vit: "learn"
 # max_num_images_per_example only applies for training when your image column is a list of images.
 # -1 means no limit, and will pad to the max possible number of images determined by sequence length.

src/MaxText/decode.py

Lines changed: 6 additions & 4 deletions
@@ -99,15 +99,17 @@ def main(argv: Sequence[str]) -> None:
   text = config.prompt
   prefill_length = config.max_prefill_predict_length
   processor_outputs = multimodal_utils.PreprocessorOutput()
+
   if config.use_multimodal:
-    image_path = config.image_path.split(",")
-    images = [multimodal_utils.load_image_from_path(p) for p in image_path]
-    processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
+    processor_outputs = multimodal_utils.preprocess_mm_data(config)
     image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)
 
     prefill_length -= image_offsets
     text = multimodal_utils.reformat_prompt(
-        text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
+        text,
+        image_placeholder=config.image_placeholder,
+        model_name=config.model_name,
+        num_images=processor_outputs.num_images,
     )
 
   metadata = engine.get_tokenizer()
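
Note: decode.py now delegates all media loading to multimodal_utils.preprocess_mm_data, which dispatches on config.model_name, and the image count is read from the preprocessor output. A minimal sketch of that call outside decode.py; the SimpleNamespace is only a stand-in for the real pyconfig object, and the paths are placeholders:

    from types import SimpleNamespace

    from MaxText import multimodal_utils

    # Stand-in config with just the fields preprocess_mm_data reads for Gemma3.
    cfg = SimpleNamespace(model_name="gemma3-4b", image_path="/path/image1.jpg,/path/image2.jpg")
    outputs = multimodal_utils.preprocess_mm_data(cfg)  # loads and preprocesses the images
    offsets = multimodal_utils.get_image_offsets(cfg.model_name, processor_output=outputs)
    print(outputs.num_images, offsets)  # num_images now lives on the preprocessor output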

src/MaxText/multimodal_utils.py

Lines changed: 263 additions & 5 deletions
@@ -66,6 +66,10 @@
 QWEN3_OMNI_IMAGE_TOKEN = 151655
 QWEN3_OMNI_VIDEO_TOKEN = 151656
 QWEN3_OMNI_AUDIO_TOKEN = 151675
+QWEN3_OMNI_VISION_START_TOKEN = 151652
+QWEN3_OMNI_VISION_END_TOKEN = 151653
+QWEN3_OMNI_AUDIO_START_TOKEN = 151669
+QWEN3_OMNI_AUDIO_END_TOKEN = 151670
 QWEN3_TEMPORAL_PATCH_SIZE = 2
 QWEN3_OMNI_IMAGE_SIZE = 768

@@ -93,6 +97,7 @@ class PreprocessorOutput:
 
 
 def convert_to_RGB(image):
+  """Convert image to RGB format."""
   if image.mode != "RGB":
     image = image.convert("RGB")
   return image
@@ -372,6 +377,7 @@ def pre_process_gemma3_image(image: np.ndarray | list[np.ndarray]) -> Preprocess
   processor_output = PreprocessorOutput(
       pixel_values=np.stack(images_out, axis=0).astype(np.float32),  # (N, H, W, C)
   )
+  processor_output.num_images = len(image)
   return processor_output
 

@@ -452,10 +458,11 @@ def pre_process_llama4_image(image: np.ndarray | list[np.ndarray]) -> Preprocess
       pixel_mask=image_mask,
       aspect_ratios=aspect_ratios_array,
   )
+  processor_output.num_images = len(image)
   return processor_output
 
 
-def pre_process_image(image, model_name):
+def pre_process_image(image, model_name, config=None):
   """Pre-process image according to different model's requirements.
   Args:
     image: The np.array image [H, W, C] or images [N, H, W, C] to pre-process.
@@ -471,6 +478,28 @@ def pre_process_image(image, model_name):
     raise ValueError(f"Model {model_name} does not support multimodal inference.")
 
 
+def preprocess_mm_data(config):
+  """Preprocess multimodal data according to model requirements.
+
+  Args:
+    config: The configuration object containing model_name and data paths.
+
+  Returns:
+    PreprocessorOutput with preprocessed multimodal data.
+  """
+  if config.model_name.startswith("qwen3-omni"):
+    from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni
+    return preprocess_mm_data_qwen3_omni(config)
+  elif config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
+    images = [load_image_from_path(p) for p in config.image_path.split(",")]
+    return pre_process_image(images, model_name=config.model_name)
+  elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+    images = [load_image_from_path(p) for p in config.image_path.split(",")]
+    return pre_process_image(images, model_name=config.model_name)
+  else:
+    raise ValueError(f"Model {config.model_name} does not support multimodal preprocessing.")
+
+
 def reformat_prompt(prompt, image_placeholder, model_name, num_images):
   """Reformat prompt for different models."""
   if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
@@ -492,6 +521,17 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images):
         f"{prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n"
     )
     return formatted_prompt
+  elif model_name in ["qwen3-omni-30b-a3b"]:
+    # Qwen3-Omni vision format: <|vision_start|><|image_pad|><|vision_end|>
+    qwen3_image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
+    if image_placeholder in prompt:
+      prompt = prompt.replace(image_placeholder, qwen3_image_placeholder)
+    image_placeholder_count = prompt.count(qwen3_image_placeholder)
+    if image_placeholder_count < num_images:
+      prompt = qwen3_image_placeholder * (num_images - image_placeholder_count) + prompt
+    # Qwen chat template
+    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+    return formatted_prompt
   else:
     return prompt
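
For reference, the new Qwen3-Omni branch is pure string work; based on the code above, a prompt with the default placeholder comes out roughly like this (sketch, not part of the commit):

    from MaxText import multimodal_utils

    prompt = multimodal_utils.reformat_prompt(
        "<|image|> Describe this picture.",
        image_placeholder="<|image|>",
        model_name="qwen3-omni-30b-a3b",
        num_images=1,
    )
    # -> '<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|> Describe this picture.<|im_end|>\n<|im_start|>assistant\n'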

@@ -504,6 +544,9 @@ def reformat_response(response, model_name):
   elif model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
     formatted_response = f"{response}<end_of_turn>"
     return formatted_response
+  elif model_name in ["qwen3-omni-30b-a3b"]:
+    formatted_response = f"{response}<|im_end|>"
+    return formatted_response
   else:
     return response

@@ -516,7 +559,7 @@ def get_image_offsets(model_name, processor_output: PreprocessorOutput | None):
     return (
         GEMMA_NUM_TOKENS_PER_MEDIA - 1
     ) * num_images  # -1 because <start_of_image> is already present in the input tokens.
-  if model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+  elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
     assert processor_output is not None, "Processor output must be provided for Llama4 image fusion."
     assert processor_output.aspect_ratios is not None, "Aspect ratio must be provided for Llama4 image fusion."
     image_height, image_width = LLAMA4_TILE_SIZE, LLAMA4_TILE_SIZE
@@ -532,6 +575,36 @@ def get_image_offsets(model_name, processor_output: PreprocessorOutput | None):
     )
     images_offsets = image_tokens_count - num_images
     return images_offsets  # -num_images because replacing every <|image|> tokens.
+  elif model_name.startswith("qwen3-omni"):
+    # Calculate token expansion for Qwen3-Omni multimodal inputs
+    if processor_output is None:
+      return 0
+
+    total_offset = 0
+    spatial_merge_size = 2  # Default for Qwen3-Omni
+    merge_length = spatial_merge_size**2
+
+    # Image tokens: <|image_pad|> expands to multiple image tokens
+    if processor_output.pixel_grid_thw is not None:
+      image_grid_thw = processor_output.pixel_grid_thw
+      for _i, grid in enumerate(image_grid_thw):
+        num_image_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+        total_offset += num_image_tokens - 1  # -1 for the original <|image_pad|> token
+
+    # Video tokens: <|video_pad|> expands to multiple video tokens
+    if processor_output.video_grid_thw is not None:
+      video_grid_thw = processor_output.video_grid_thw
+      for _i, grid in enumerate(video_grid_thw):
+        num_video_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+        total_offset += num_video_tokens - 1  # -1 for the original <|video_pad|> token
+
+    # Audio tokens: <|audio_pad|> expands based on audio_lengths
+    if processor_output.audio_lengths is not None:
+      audio_lengths = processor_output.audio_lengths
+      for audio_len in audio_lengths:
+        total_offset += int(audio_len) - 1  # -1 for the original <|audio_pad|> token
+
+    return total_offset
   else:
     return 0
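
Per the branch above, the offset contributed by each image or video is t*h*w // spatial_merge_size**2 minus one (for the placeholder it replaces), and each audio clip contributes its audio length minus one. A worked example with made-up numbers, a 48x48-patch single-frame image and a clip that maps to 75 audio tokens:

    import numpy as np

    spatial_merge_size = 2
    merge_length = spatial_merge_size**2  # 2x2 patches merge into one token

    image_grid_thw = np.array([[1, 48, 48]])  # hypothetical (t, h, w) in patches
    audio_lengths = np.array([75])            # hypothetical audio token count

    image_tokens = int(np.prod(image_grid_thw[0]) // merge_length)  # 2304 // 4 = 576
    offset = (image_tokens - 1) + (int(audio_lengths[0]) - 1)       # 575 + 74 = 649
    print(image_tokens, offset)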

@@ -568,12 +641,61 @@ def get_dummy_image_shape_for_init(
   return image_shape
 
 
-def prepare_text_for_image_fusion(texts, model_name, processor_output=None):
+def prepare_text_for_image_fusion(
+    texts,
+    model_name,
+    processor_output=None,
+    image_grid_thw=None,
+    video_grid_thw=None,
+    audio_lengths=None,
+    spatial_merge_size=2,
+    use_audio_in_video=False,
+    second_per_grids=None,
+    position_id_per_seconds=25,
+):
+  """Prepare text tokens for multimodal fusion by expanding special tokens.
+
+  Args:
+    texts: Input token sequence.
+    model_name: Model name to determine processing logic.
+    processor_output: Preprocessor output for Gemma3/Llama4 (contains pixel_values, aspect_ratios).
+    image_grid_thw: Image dimensions for Qwen3-Omni (num_images, 3).
+    video_grid_thw: Video dimensions for Qwen3-Omni (num_videos, 3).
+    audio_lengths: Audio sequence lengths for Qwen3-Omni (num_audios,).
+    spatial_merge_size: Spatial merge size for Qwen3-Omni.
+    use_audio_in_video: Whether to interleave audio with video for Qwen3-Omni.
+    second_per_grids: Time per grid for Qwen3-Omni videos (num_videos,).
+    position_id_per_seconds: Temporal granularity for Qwen3-Omni.
+
+  Returns:
+    Expanded token sequence with multimodal tokens inserted.
+  """
   if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
     num_images = processor_output.pixel_values.shape[0] if processor_output else 1
     return add_extra_tokens_for_images_gemma3(texts, max_num_images=num_images)
-  if model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+  elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
     return add_extra_tokens_for_images_llama4(texts, processor_output)
+  elif model_name.startswith("qwen3-omni"):
+    # Extract Qwen3-Omni specific parameters from processor_output if not provided
+    if image_grid_thw is None and processor_output is not None:
+      image_grid_thw = getattr(processor_output, "pixel_grid_thw", None)
+    if video_grid_thw is None and processor_output is not None:
+      video_grid_thw = getattr(processor_output, "video_grid_thw", None)
+    if audio_lengths is None and processor_output is not None:
+      audio_lengths = getattr(processor_output, "audio_lengths", None)
+    if second_per_grids is None and processor_output is not None:
+      second_per_grids = getattr(processor_output, "video_second_per_grid", None)
+
+    return add_extra_tokens_for_qwen3_omni(
+        tokens=texts,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        audio_lengths=audio_lengths,
+        spatial_merge_size=spatial_merge_size,
+        use_audio_in_video=use_audio_in_video,
+        second_per_grids=second_per_grids,
+        position_id_per_seconds=position_id_per_seconds,
+    )
   else:
     raise ValueError(f"Model {model_name} does not support multimodal inference.")

@@ -700,6 +822,142 @@ def get_num_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk):
   return int(num_img_tokens)
 
 
+def add_extra_tokens_for_qwen3_omni(
+    tokens: np.ndarray | list,
+    image_grid_thw: np.ndarray | None = None,
+    video_grid_thw: np.ndarray | None = None,
+    audio_lengths: np.ndarray | None = None,
+    spatial_merge_size: int = 2,
+    use_audio_in_video: bool = False,
+    second_per_grids: np.ndarray | None = None,
+    position_id_per_seconds: int = 25,
+):
+  """Add extra tokens for Qwen3-Omni multimodal sequences.
+
+  Expands special tokens (<|image_pad|>, <|video_pad|>, <|audio_pad|>) into
+  the correct number of placeholder tokens based on grid dimensions and merge size.
+
+  For audio-in-video mode, interleaves audio and video tokens based on temporal ordering.
+
+  Args:
+    tokens: Input token sequence (1D array or list).
+    image_grid_thw: Image dimensions (num_images, 3) with [temporal, height, width].
+    video_grid_thw: Video dimensions (num_videos, 3) with [temporal, height, width].
+    audio_lengths: Pre-computed audio token counts (num_audios,).
+    spatial_merge_size: Number of patches merged spatially (e.g., 2 for 2x2→1).
+    use_audio_in_video: If True, interleave audio and video tokens.
+    second_per_grids: Time interval per temporal grid (num_videos,).
+    position_id_per_seconds: Temporal granularity (tokens per second).
+
+  Returns:
+    Expanded token sequence with correct number of image/video/audio tokens.
+  """
+  if not isinstance(tokens, np.ndarray):
+    tokens = np.asarray(tokens)
+
+  tokens = tokens.flatten()  # Ensure 1D
+
+  # Merge lengths for computing number of tokens
+  merge_length = spatial_merge_size**2
+
+  # Convert to list for easier manipulation
+  token_list = tokens.tolist()
+  new_tokens = []
+
+  image_idx = 0
+  video_idx = 0
+  audio_idx = 0
+
+  i = 0
+  while i < len(token_list):
+    token = token_list[i]
+
+    # Handle image tokens
+    if token == QWEN3_OMNI_IMAGE_TOKEN and image_grid_thw is not None and image_idx < len(image_grid_thw):
+      grid = image_grid_thw[image_idx]
+      num_image_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+      new_tokens.extend([QWEN3_OMNI_IMAGE_TOKEN] * num_image_tokens)
+      image_idx += 1
+
+    # Handle audio-in-video: <|vision_start|><|video_pad|><|vision_end|>
+    elif (
+        use_audio_in_video
+        and token == QWEN3_OMNI_VISION_START_TOKEN
+        and i + 2 < len(token_list)
+        and token_list[i + 1] == QWEN3_OMNI_VIDEO_TOKEN
+        and token_list[i + 2] == QWEN3_OMNI_VISION_END_TOKEN
+        and video_grid_thw is not None
+        and video_idx < len(video_grid_thw)
+    ):
+
+      if audio_lengths is None or audio_idx >= len(audio_lengths):
+        raise ValueError("audio_lengths required for audio-in-video mode")
+      if second_per_grids is None or video_idx >= len(second_per_grids):
+        raise ValueError("second_per_grids required for audio-in-video mode")
+
+      audio_length = audio_lengths[audio_idx]
+      audio_token_indices = np.arange(audio_length)
+
+      curr_video_grid = video_grid_thw[video_idx]
+      height = curr_video_grid[1] // spatial_merge_size
+      width = curr_video_grid[2] // spatial_merge_size
+      num_frames = curr_video_grid[0]
+
+      video_token_indices = np.arange(num_frames).reshape(-1, 1, 1)
+      video_token_indices = np.broadcast_to(video_token_indices, (num_frames, height, width)).flatten()
+      video_token_indices = video_token_indices * second_per_grids[video_idx] * position_id_per_seconds
+
+      new_tokens.append(QWEN3_OMNI_VISION_START_TOKEN)
+      new_tokens.append(QWEN3_OMNI_AUDIO_START_TOKEN)
+
+      video_data_idx = 0
+      audio_data_idx = 0
+
+      while video_data_idx < len(video_token_indices) and audio_data_idx < len(audio_token_indices):
+        if video_token_indices[video_data_idx] <= audio_token_indices[audio_data_idx]:
+          new_tokens.append(QWEN3_OMNI_VIDEO_TOKEN)
+          video_data_idx += 1
+        else:
+          new_tokens.append(QWEN3_OMNI_AUDIO_TOKEN)
+          audio_data_idx += 1
+
+      while video_data_idx < len(video_token_indices):
+        new_tokens.append(QWEN3_OMNI_VIDEO_TOKEN)
+        video_data_idx += 1
+
+      while audio_data_idx < len(audio_token_indices):
+        new_tokens.append(QWEN3_OMNI_AUDIO_TOKEN)
+        audio_data_idx += 1
+
+      new_tokens.append(QWEN3_OMNI_AUDIO_END_TOKEN)
+      new_tokens.append(QWEN3_OMNI_VISION_END_TOKEN)
+
+      video_idx += 1
+      audio_idx += 1
+      i += 2
+
+    # Handle video tokens (without audio-in-video)
+    elif token == QWEN3_OMNI_VIDEO_TOKEN and video_grid_thw is not None and video_idx < len(video_grid_thw):
+      grid = video_grid_thw[video_idx]
+      num_video_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+      new_tokens.extend([QWEN3_OMNI_VIDEO_TOKEN] * num_video_tokens)
+      video_idx += 1
+
+    # Handle audio tokens (standalone, not in video)
+    elif token == QWEN3_OMNI_AUDIO_TOKEN and audio_lengths is not None and audio_idx < len(audio_lengths):
+      num_audio_tokens = int(audio_lengths[audio_idx])
+      new_tokens.extend([QWEN3_OMNI_AUDIO_TOKEN] * num_audio_tokens)
+      audio_idx += 1
+
+    # All other tokens pass through unchanged
+    else:
+      new_tokens.append(token)
+
+    i += 1
+
+  return np.array(new_tokens, dtype=np.int32)
+
+
 def add_extra_tokens_for_images_gemma3(
     tokens: np.ndarray | list,
     *,
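
A small usage sketch of the expansion itself (not part of the commit), with arbitrary text token ids and a deliberately tiny (1, 4, 4) image grid so the result stays short:

    import numpy as np

    from MaxText import multimodal_utils as mm

    tokens = [100, mm.QWEN3_OMNI_VISION_START_TOKEN, mm.QWEN3_OMNI_IMAGE_TOKEN,
              mm.QWEN3_OMNI_VISION_END_TOKEN, 200]
    expanded = mm.add_extra_tokens_for_qwen3_omni(tokens, image_grid_thw=np.array([[1, 4, 4]]))
    # The single <|image_pad|> id expands to (1*4*4)//4 = 4 image token ids; all other ids pass through.
    assert list(expanded).count(mm.QWEN3_OMNI_IMAGE_TOKEN) == 4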
@@ -956,4 +1214,4 @@ def _merge_mm_embeddings_inner(
   # Restore the first position's embedding, in case it was overwritten.
   merged = merged.at[0].set(first_pos_embedding)
 
-  return merged
+  return merged
