 QWEN3_OMNI_IMAGE_TOKEN = 151655
 QWEN3_OMNI_VIDEO_TOKEN = 151656
 QWEN3_OMNI_AUDIO_TOKEN = 151675
+QWEN3_OMNI_VISION_START_TOKEN = 151652
+QWEN3_OMNI_VISION_END_TOKEN = 151653
+QWEN3_OMNI_AUDIO_START_TOKEN = 151669
+QWEN3_OMNI_AUDIO_END_TOKEN = 151670
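+# NOTE (assumed mapping): the *_START/*_END IDs above are taken to be the
+# <|vision_start|>, <|vision_end|>, <|audio_start|>, <|audio_end|> special
+# tokens of the Qwen3-Omni tokenizer vocabulary.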
 QWEN3_TEMPORAL_PATCH_SIZE = 2
 QWEN3_OMNI_IMAGE_SIZE = 768
 
@@ -93,6 +97,7 @@ class PreprocessorOutput:
 
 
 def convert_to_RGB(image):
+  """Convert image to RGB format."""
   if image.mode != "RGB":
     image = image.convert("RGB")
   return image
@@ -372,6 +377,7 @@ def pre_process_gemma3_image(image: np.ndarray | list[np.ndarray]) -> Preprocess
   processor_output = PreprocessorOutput(
       pixel_values=np.stack(images_out, axis=0).astype(np.float32),  # (N, H, W, C)
   )
+  processor_output.num_images = len(image)
   return processor_output
 
 
@@ -452,10 +458,11 @@ def pre_process_llama4_image(image: np.ndarray | list[np.ndarray]) -> Preprocess
       pixel_mask=image_mask,
       aspect_ratios=aspect_ratios_array,
   )
+  processor_output.num_images = len(image)
   return processor_output
 
 
-def pre_process_image(image, model_name):
+def pre_process_image(image, model_name, config=None):
   """Pre-process image according to different model's requirements.
   Args:
     image: The np.array image [H, W, C] or images [N, H, W, C] to pre-process.
@@ -471,6 +478,28 @@ def pre_process_image(image, model_name):
     raise ValueError(f"Model {model_name} does not support multimodal inference.")
 
 
+def preprocess_mm_data(config):
+  """Preprocess multimodal data according to model requirements.
+
+  Args:
+    config: The configuration object containing model_name and data paths.
+
+  Returns:
+    PreprocessorOutput with preprocessed multimodal data.
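+
+  Example (illustrative; attribute names assumed from usage in this module):
+    config.model_name = "gemma3-4b"
+    config.image_path = "img_a.png,img_b.png"
+    out = preprocess_mm_data(config)  # PreprocessorOutput with stacked pixel_values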
489+ """
490+ if config .model_name .startswith ("qwen3-omni" ):
491+ from MaxText .multimodal .qwen3_omni_processor import preprocess_mm_data_qwen3_omni
492+ return preprocess_mm_data_qwen3_omni (config )
493+ elif config .model_name in ["gemma3-4b" , "gemma3-12b" , "gemma3-27b" ]:
494+ images = [load_image_from_path (p ) for p in config .image_path .split ("," )]
495+ return pre_process_image (images , model_name = config .model_name )
496+ elif config .model_name in ["llama4-17b-16e" , "llama4-17b-128e" ]:
497+ images = [load_image_from_path (p ) for p in config .image_path .split ("," )]
498+ return pre_process_image (images , model_name = config .model_name )
499+ else :
500+ raise ValueError (f"Model { config .model_name } does not support multimodal preprocessing." )
501+
502+
 def reformat_prompt(prompt, image_placeholder, model_name, num_images):
   """Reformat prompt for different models."""
   if model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
@@ -492,6 +521,17 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images):
         f"{prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n"
     )
     return formatted_prompt
+  elif model_name in ["qwen3-omni-30b-a3b"]:
+    # Qwen3-Omni vision format: <|vision_start|><|image_pad|><|vision_end|>
+    qwen3_image_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"
+    if image_placeholder in prompt:
+      prompt = prompt.replace(image_placeholder, qwen3_image_placeholder)
+    image_placeholder_count = prompt.count(qwen3_image_placeholder)
+    if image_placeholder_count < num_images:
+      prompt = qwen3_image_placeholder * (num_images - image_placeholder_count) + prompt
+    # Qwen chat template
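+    # Illustrative: with num_images=1 and no placeholder present, "Describe this." becomes
+    # "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this.<|im_end|>\n<|im_start|>assistant\n".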
+    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+    return formatted_prompt
   else:
     return prompt
 
@@ -504,6 +544,9 @@ def reformat_response(response, model_name):
   elif model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
     formatted_response = f"{response}<end_of_turn>"
     return formatted_response
+  elif model_name in ["qwen3-omni-30b-a3b"]:
+    formatted_response = f"{response}<|im_end|>"
+    return formatted_response
   else:
     return response
 
@@ -516,7 +559,7 @@ def get_image_offsets(model_name, processor_output: PreprocessorOutput | None):
     return (
         GEMMA_NUM_TOKENS_PER_MEDIA - 1
     ) * num_images  # -1 because <start_of_image> is already present in the input tokens.
-  if model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+  elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
     assert processor_output is not None, "Processor output must be provided for Llama4 image fusion."
     assert processor_output.aspect_ratios is not None, "Aspect ratio must be provided for Llama4 image fusion."
     image_height, image_width = LLAMA4_TILE_SIZE, LLAMA4_TILE_SIZE
@@ -532,6 +575,36 @@ def get_image_offsets(model_name, processor_output: PreprocessorOutput | None):
     )
     images_offsets = image_tokens_count - num_images
     return images_offsets  # -num_images because every <|image|> token is replaced.
+  elif model_name.startswith("qwen3-omni"):
+    # Calculate token expansion for Qwen3-Omni multimodal inputs.
+    if processor_output is None:
+      return 0
+
+    total_offset = 0
+    spatial_merge_size = 2  # Default for Qwen3-Omni
+    merge_length = spatial_merge_size**2
+
+    # Image tokens: <|image_pad|> expands to multiple image tokens.
+    if processor_output.pixel_grid_thw is not None:
+      for grid in processor_output.pixel_grid_thw:
+        num_image_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+        total_offset += num_image_tokens - 1  # -1 for the original <|image_pad|> token
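+        # Worked example (illustrative): a grid of (t=1, h=24, w=24) has 576 patches;
+        # with merge_length = 4 that is 144 image tokens, adding 143 to the offset.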
+
+    # Video tokens: <|video_pad|> expands to multiple video tokens.
+    if processor_output.video_grid_thw is not None:
+      for grid in processor_output.video_grid_thw:
+        num_video_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+        total_offset += num_video_tokens - 1  # -1 for the original <|video_pad|> token
+
+    # Audio tokens: <|audio_pad|> expands based on audio_lengths.
+    if processor_output.audio_lengths is not None:
+      for audio_len in processor_output.audio_lengths:
+        total_offset += int(audio_len) - 1  # -1 for the original <|audio_pad|> token
+
+    return total_offset
   else:
     return 0
 
@@ -568,12 +641,61 @@ def get_dummy_image_shape_for_init(
   return image_shape
 
 
-def prepare_text_for_image_fusion(texts, model_name, processor_output=None):
+def prepare_text_for_image_fusion(
+    texts,
+    model_name,
+    processor_output=None,
+    image_grid_thw=None,
+    video_grid_thw=None,
+    audio_lengths=None,
+    spatial_merge_size=2,
+    use_audio_in_video=False,
+    second_per_grids=None,
+    position_id_per_seconds=25,
+):
+  """Prepare text tokens for multimodal fusion by expanding special tokens.
+
+  Args:
+    texts: Input token sequence.
+    model_name: Model name to determine processing logic.
+    processor_output: Preprocessor output for Gemma3/Llama4 (contains pixel_values, aspect_ratios).
+    image_grid_thw: Image dimensions for Qwen3-Omni (num_images, 3).
+    video_grid_thw: Video dimensions for Qwen3-Omni (num_videos, 3).
+    audio_lengths: Audio sequence lengths for Qwen3-Omni (num_audios,).
+    spatial_merge_size: Spatial merge size for Qwen3-Omni.
+    use_audio_in_video: Whether to interleave audio with video for Qwen3-Omni.
+    second_per_grids: Time per grid for Qwen3-Omni videos (num_videos,).
+    position_id_per_seconds: Temporal granularity for Qwen3-Omni.
+
+  Returns:
+    Expanded token sequence with multimodal tokens inserted.
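+
+  Example (illustrative):
+    tokens = prepare_text_for_image_fusion(tokens, "gemma3-4b", processor_output=out)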
672+ """
572673 if model_name in ["gemma3-4b" , "gemma3-12b" , "gemma3-27b" ]:
573674 num_images = processor_output .pixel_values .shape [0 ] if processor_output else 1
574675 return add_extra_tokens_for_images_gemma3 (texts , max_num_images = num_images )
-  if model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
+  elif model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
     return add_extra_tokens_for_images_llama4(texts, processor_output)
+  elif model_name.startswith("qwen3-omni"):
+    # Extract Qwen3-Omni specific parameters from processor_output if not provided.
+    if image_grid_thw is None and processor_output is not None:
+      image_grid_thw = getattr(processor_output, "pixel_grid_thw", None)
+    if video_grid_thw is None and processor_output is not None:
+      video_grid_thw = getattr(processor_output, "video_grid_thw", None)
+    if audio_lengths is None and processor_output is not None:
+      audio_lengths = getattr(processor_output, "audio_lengths", None)
+    if second_per_grids is None and processor_output is not None:
+      second_per_grids = getattr(processor_output, "video_second_per_grid", None)
+
+    return add_extra_tokens_for_qwen3_omni(
+        tokens=texts,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        audio_lengths=audio_lengths,
+        spatial_merge_size=spatial_merge_size,
+        use_audio_in_video=use_audio_in_video,
+        second_per_grids=second_per_grids,
+        position_id_per_seconds=position_id_per_seconds,
+    )
   else:
     raise ValueError(f"Model {model_name} does not support multimodal inference.")
 
@@ -700,6 +822,142 @@ def get_num_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk):
   return int(num_img_tokens)
 
 
+def add_extra_tokens_for_qwen3_omni(
+    tokens: np.ndarray | list,
+    image_grid_thw: np.ndarray | None = None,
+    video_grid_thw: np.ndarray | None = None,
+    audio_lengths: np.ndarray | None = None,
+    spatial_merge_size: int = 2,
+    use_audio_in_video: bool = False,
+    second_per_grids: np.ndarray | None = None,
+    position_id_per_seconds: int = 25,
+):
+  """Add extra tokens for Qwen3-Omni multimodal sequences.
+
+  Expands special tokens (<|image_pad|>, <|video_pad|>, <|audio_pad|>) into
+  the correct number of placeholder tokens based on grid dimensions and merge size.
+
+  For audio-in-video mode, interleaves audio and video tokens based on temporal ordering.
+
+  Args:
+    tokens: Input token sequence (1D array or list).
+    image_grid_thw: Image dimensions (num_images, 3) with [temporal, height, width].
+    video_grid_thw: Video dimensions (num_videos, 3) with [temporal, height, width].
+    audio_lengths: Pre-computed audio token counts (num_audios,).
+    spatial_merge_size: Number of patches merged spatially (e.g., 2 for 2x2→1).
+    use_audio_in_video: If True, interleave audio and video tokens.
+    second_per_grids: Time interval per temporal grid (num_videos,).
+    position_id_per_seconds: Temporal granularity (tokens per second).
+
+  Returns:
+    Expanded token sequence with correct number of image/video/audio tokens.
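+
+  Example (illustrative): a single <|image_pad|> with image_grid_thw=[[1, 4, 4]]
+  and spatial_merge_size=2 expands to (1 * 4 * 4) // 4 = 4 <|image_pad|> tokens.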
854+ """
855+ if not isinstance (tokens , np .ndarray ):
856+ tokens = np .asarray (tokens )
857+
858+ tokens = tokens .flatten () # Ensure 1D
859+
860+ # Merge lengths for computing number of tokens
861+ merge_length = spatial_merge_size ** 2
862+
863+ # Convert to list for easier manipulation
864+ token_list = tokens .tolist ()
865+ new_tokens = []
866+
867+ image_idx = 0
868+ video_idx = 0
869+ audio_idx = 0
870+
871+ i = 0
+  while i < len(token_list):
+    token = token_list[i]
+
+    # Handle image tokens.
+    if token == QWEN3_OMNI_IMAGE_TOKEN and image_grid_thw is not None and image_idx < len(image_grid_thw):
+      grid = image_grid_thw[image_idx]
+      num_image_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+      new_tokens.extend([QWEN3_OMNI_IMAGE_TOKEN] * num_image_tokens)
+      image_idx += 1
+
+    # Handle audio-in-video: <|vision_start|><|video_pad|><|vision_end|>.
+    elif (
+        use_audio_in_video
+        and token == QWEN3_OMNI_VISION_START_TOKEN
+        and i + 2 < len(token_list)
+        and token_list[i + 1] == QWEN3_OMNI_VIDEO_TOKEN
+        and token_list[i + 2] == QWEN3_OMNI_VISION_END_TOKEN
+        and video_grid_thw is not None
+        and video_idx < len(video_grid_thw)
+    ):
+      if audio_lengths is None or audio_idx >= len(audio_lengths):
+        raise ValueError("audio_lengths required for audio-in-video mode")
+      if second_per_grids is None or video_idx >= len(second_per_grids):
+        raise ValueError("second_per_grids required for audio-in-video mode")
+
+      audio_length = audio_lengths[audio_idx]
+      audio_token_indices = np.arange(audio_length)
+
+      curr_video_grid = video_grid_thw[video_idx]
+      height = curr_video_grid[1] // spatial_merge_size
+      width = curr_video_grid[2] // spatial_merge_size
+      num_frames = curr_video_grid[0]
+
+      # Temporal index of each video token, scaled to position-id units.
+      video_token_indices = np.arange(num_frames).reshape(-1, 1, 1)
+      video_token_indices = np.broadcast_to(video_token_indices, (num_frames, height, width)).flatten()
+      video_token_indices = video_token_indices * second_per_grids[video_idx] * position_id_per_seconds
+
+      new_tokens.append(QWEN3_OMNI_VISION_START_TOKEN)
+      new_tokens.append(QWEN3_OMNI_AUDIO_START_TOKEN)
+
+      video_data_idx = 0
+      audio_data_idx = 0
+
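+      # Two-pointer merge: at each step emit the stream whose next temporal index
+      # is smaller, so audio and video placeholders interleave in timestamp order.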
+      while video_data_idx < len(video_token_indices) and audio_data_idx < len(audio_token_indices):
+        if video_token_indices[video_data_idx] <= audio_token_indices[audio_data_idx]:
+          new_tokens.append(QWEN3_OMNI_VIDEO_TOKEN)
+          video_data_idx += 1
+        else:
+          new_tokens.append(QWEN3_OMNI_AUDIO_TOKEN)
+          audio_data_idx += 1
+
+      # Drain whichever stream still has tokens left.
+      while video_data_idx < len(video_token_indices):
+        new_tokens.append(QWEN3_OMNI_VIDEO_TOKEN)
+        video_data_idx += 1
+
+      while audio_data_idx < len(audio_token_indices):
+        new_tokens.append(QWEN3_OMNI_AUDIO_TOKEN)
+        audio_data_idx += 1
+
+      new_tokens.append(QWEN3_OMNI_AUDIO_END_TOKEN)
+      new_tokens.append(QWEN3_OMNI_VISION_END_TOKEN)
+
+      video_idx += 1
+      audio_idx += 1
+      i += 2  # Skip the consumed <|video_pad|> and <|vision_end|>; the i += 1 below passes <|vision_start|>.
+
+    # Handle video tokens (without audio-in-video).
+    elif token == QWEN3_OMNI_VIDEO_TOKEN and video_grid_thw is not None and video_idx < len(video_grid_thw):
+      grid = video_grid_thw[video_idx]
+      num_video_tokens = int((grid[0] * grid[1] * grid[2]) // merge_length)
+      new_tokens.extend([QWEN3_OMNI_VIDEO_TOKEN] * num_video_tokens)
+      video_idx += 1
+
+    # Handle audio tokens (standalone, not in video).
+    elif token == QWEN3_OMNI_AUDIO_TOKEN and audio_lengths is not None and audio_idx < len(audio_lengths):
+      num_audio_tokens = int(audio_lengths[audio_idx])
+      new_tokens.extend([QWEN3_OMNI_AUDIO_TOKEN] * num_audio_tokens)
+      audio_idx += 1
+
+    # All other tokens pass through unchanged.
+    else:
+      new_tokens.append(token)
+
+    i += 1
+
+  return np.array(new_tokens, dtype=np.int32)
+
+
 def add_extra_tokens_for_images_gemma3(
     tokens: np.ndarray | list,
     *,
@@ -956,4 +1214,4 @@ def _merge_mm_embeddings_inner(
   # Restore the first position's embedding, in case it was overwritten.
   merged = merged.at[0].set(first_pos_embedding)
 
-  return merged
+  return merged