Skip to content

Commit b816da6

Browse files
author
Eitan Porat
committed
Add Preprocessing and token placeholders
1 parent 05abc90 commit b816da6

File tree

3 files changed

+272
-17
lines changed

3 files changed

+272
-17
lines changed

src/MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,10 @@ dtype_mm: "float32" # Data type for multimodal model's vision encoder
892892
remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options.
893893
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
894894
image_path: "" # Local image path(s) used for decoding; multiple paths may be given as a comma-separated list, e.g. "/path/image1.jpg,/path/image2.jpg"
895+
audio_path: "" # Local audio path(s) used for decoding; multiple paths may be given as a comma-separated list, e.g. "/path/audio1.wav,/path/audio2.wav"
896+
video_path: "" # Local video path(s) used for decoding; multiple paths may be given as a comma-separated list, e.g. "/path/video1.mp4,/path/video2.mp4"
895897
image_placeholder: "<|image|>"
898+
audio_placeholder: "<|audio|>"
896899
posemb_type_for_vit: "learn"
897900
# max_num_images_per_example only applies for training when your image column is a list of images.
898901
# -1 means no limit, and will pad to the max possible number of images determined by sequence length.

src/MaxText/decode.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,17 @@ def main(argv: Sequence[str]) -> None:
9999
text = config.prompt
100100
prefill_length = config.max_prefill_predict_length
101101
processor_outputs = multimodal_utils.PreprocessorOutput()
102+
102103
if config.use_multimodal:
103-
image_path = config.image_path.split(",")
104-
images = [multimodal_utils.load_image_from_path(p) for p in image_path]
105-
processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
104+
processor_outputs = multimodal_utils.preprocess_mm_data(config)
106105
image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)
107106

108107
prefill_length -= image_offsets
109108
text = multimodal_utils.reformat_prompt(
110-
text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
109+
text,
110+
image_placeholder=config.image_placeholder,
111+
model_name=config.model_name,
112+
num_images=processor_outputs.num_images,
111113
)
112114

113115
metadata = engine.get_tokenizer()

0 commit comments

Comments
 (0)