|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Multimodal data preprocessor router.""" |
| 16 | + |
| 17 | +from MaxText import multimodal_utils # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py |
| 18 | + |
| 19 | + |
def preprocess_mm_data(config):
  """Dispatches multimodal preprocessing to the handler for `config.model_name`.

  Gemma3 and Llama4 models load every image named in `config.image_path`
  (a comma-separated string) and pass the resulting list to their image
  preprocessor. Qwen3-Omni hands the whole config to its own processor module,
  which is imported only when that branch is taken.

  Args:
    config: A `pyconfig.Config` object containing configuration parameters.
      `model_name` is always read; `image_path` is read for the Gemma3 and
      Llama4 models.

  Returns:
    The output of the selected model-specific preprocessor (a
    `PreprocessorOutput` object containing the processed multimodal data).

  Raises:
    ValueError: If `config.model_name` is not one of the supported models.
  """
  gemma3_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b")
  llama4_models = ("llama4-17b-16e", "llama4-17b-128e")
  qwen3_omni_models = ("qwen3-omni-30b-a3b",)

  # Default container created up front; every supported branch replaces it.
  result = multimodal_utils.PreprocessorOutput()
  model_name = config.model_name

  def load_images():
    """Loads one image per comma-separated entry of `config.image_path`."""
    loaded = []
    for path in config.image_path.split(","):
      loaded.append(multimodal_utils.load_image_from_path(path))
    return loaded

  if model_name in gemma3_models:
    result = multimodal_utils.pre_process_gemma3_image(load_images())
  elif model_name in llama4_models:
    result = multimodal_utils.pre_process_llama4_image(load_images())
  elif model_name in qwen3_omni_models:
    # Imported lazily so the Qwen3-Omni processor module is only loaded when needed.
    from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni  # pylint: disable=import-outside-toplevel

    result = preprocess_mm_data_qwen3_omni(config)
  else:
    raise ValueError(f"Model {model_name} not supported for multimodal preprocessing.")

  return result
0 commit comments