Commit e6a8e7d

Fix binding of video frames to video placeholder in InternVL model (#41237)
* Fix binding video frames to video placeholder in prompt
* Add test on binding video frames to prompt
* Fix code style issues
* Fix broken tests on `InternVLProcessor`
* Add `return_tensors` to video processor defaults

Signed-off-by: Daniel Bershatsky <daniel.bershatsky@gmail.com>
1 parent 30b79ef commit e6a8e7d

File tree: 2 files changed, +60 -22 lines

src/transformers/models/internvl/processing_internvl.py

Lines changed: 25 additions & 16 deletions
@@ -40,7 +40,9 @@ class InternVLProcessorKwargs(ProcessingKwargs, total=False):
         "images_kwargs": {
             "crop_to_patches": True,
         },
-        "videos_kwargs": {},
+        "videos_kwargs": {
+            "return_tensors": "pt",
+        },
     }

@@ -132,10 +134,10 @@ def _insert_media_placeholders(
                 # Get the slice of patches corresponding to the current video
                 # Here we need to account for both the multiple video frames and the potential multiple patches per frame
                 # As of now, InternVL only supports one patch per frame, but we keep the code flexible for future updates
-                current_patch_index = video_patch_indices[video_index - 1] if video_index > 0 else 0
-                end_patch_index = video_patch_indices[video_index]
-                start_index = video_num_patches_indices[current_patch_index] if video_index > 0 else 0
-                end_index = video_num_patches_indices[end_patch_index - 1]
+                current_patch_index = video_patch_indices[video_index]
+                end_patch_index = video_patch_indices[video_index + 1]
+                start_index = video_num_patches_indices[current_patch_index]
+                end_index = video_num_patches_indices[end_patch_index]
                 image_video_patches.append(video_pixel_values[start_index:end_index])
                 # Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
                 num_patches = list(video_num_patches[current_patch_index:end_patch_index])
@@ -206,31 +208,38 @@ def __call__(

         # Process images and videos separately, as videos don't support crop_to_patches
         image_num_patches = []
-        video_num_patches = []
-        image_videos_inputs = {}
         image_pixel_values = None
-        video_pixel_values = None
         image_num_patches_indices = np.array([0])
-        video_patch_indices = np.array([0])
-        video_num_patches_indices = np.array([0])
         if images is not None:
             images = self.image_processor.fetch_images(images)
             images = make_flat_list_of_images(images)
             image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
             image_num_patches = image_inputs.pop("num_patches")
             image_pixel_values = image_inputs.pop("pixel_values")
             image_num_patches_indices = np.cumsum(image_num_patches)
+
+        video_num_patches = []  # per frame
+        video_pixel_values = None
+        video_patch_indices = np.array([0])
+        video_num_patches_indices = np.array([0])
         if videos is not None:
-            video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+            video_kwargs = output_kwargs["videos_kwargs"]
+            video_inputs = self.video_processor(videos=videos, **video_kwargs)
             video_pixel_values = video_inputs.pop("pixel_values_videos")

-            # Obtain per frame information first and then flatten to (BS * T, ...)
-            num_frames_per_video = [len(video) for video in video_pixel_values]
-            video_num_patches = [1 for frames in num_frames_per_video for _ in range(frames)]
-            video_patch_indices = np.cumsum(num_frames_per_video)
-            video_num_patches_indices = np.cumsum(video_num_patches)
+            batch_size, num_frames, *_ = video_pixel_values.shape
+            num_frames_per_video = np.full(batch_size, num_frames)
+            num_frames = sum(num_frames_per_video)  # total
+            video_patch_indices = np.empty(batch_size + 1, int)
+            video_patch_indices[0] = 0
+            video_patch_indices[1:] = np.cumsum(num_frames_per_video)
+            video_num_patches = [1] * num_frames
+            video_num_patches_indices = np.empty(num_frames + 1, int)
+            video_num_patches_indices[0] = 0
+            video_num_patches_indices[1:] = np.cumsum(video_num_patches)
             video_pixel_values = video_pixel_values.flatten(0, 1)

+        image_videos_inputs = {}
         if images is not None or videos is not None:
             text, image_video_patches, image_index, video_index = self._insert_media_placeholders(
                 text,
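
The heart of the fix is in the two hunks above: both index arrays are now built as boundary arrays with a leading zero, so the frames of video `i` are always the slice `indices[i]:indices[i+1]`. The old code special-cased `video_index > 0` and read `video_num_patches_indices[end_patch_index - 1]`, which shifted the slice by one and dropped a frame of every video after the first in a batch. A minimal standalone sketch of the boundary-index scheme (the frame counts are made up, not the processor's real inputs):

# Boundary-index sketch; the frame counts below are hypothetical.
import numpy as np

num_frames_per_video = [2, 3, 2]  # three videos in a batch

# A leading 0 followed by the cumulative sum, as in the patched processor.
video_patch_indices = np.empty(len(num_frames_per_video) + 1, int)
video_patch_indices[0] = 0
video_patch_indices[1:] = np.cumsum(num_frames_per_video)
print(video_patch_indices)  # [0 2 5 7]

# Every video, including the first, is sliced uniformly.
for video_index in range(len(num_frames_per_video)):
    start = video_patch_indices[video_index]
    end = video_patch_indices[video_index + 1]
    print(f"video {video_index}: frames [{start}, {end})")
# video 0: frames [0, 2)
# video 1: frames [2, 5)
# video 2: frames [5, 7)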

tests/models/internvl/test_processing_internvl.py

Lines changed: 35 additions & 6 deletions
@@ -17,6 +17,8 @@
 import tempfile
 import unittest

+from parameterized import parameterized
+
 from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
 from transformers.testing_utils import require_av, require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
@@ -345,25 +347,30 @@ def _test_apply_chat_template(
         for idx, url in enumerate(input_data[:batch_size]):
             batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]

+        num_frames = 2  # by default no more than 2 frames, otherwise too slow
         out_dict = processor.apply_chat_template(
             batch_messages,
             add_generation_prompt=True,
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
-            num_frames=2,  # by default no more than 2 frames, otherwise too slow
+            num_frames=num_frames,
         )
         self.assertTrue(self.videos_input_name in out_dict)
         self.assertEqual(len(out_dict["input_ids"]), batch_size)
         self.assertEqual(len(out_dict["attention_mask"]), batch_size)

         # InternVL internally collects frames from all the videos in a batch and flattens the batch dimension (B T C H W) -> (B*T C H W) then patches and removes the frames
         # hence output length does not equal batch size
-        # removed hardcoded video length check video_len = 2 if batch_size == 1 else 3
-        # from experiment video_len looks like batch_size + 1
-        # TODO: update expected video_len calculation based on the internal processing logic of InternVLProcessor
-        output_len = batch_size + 1 if modality == "video" else batch_size
-        self.assertEqual(len(out_dict[self.videos_input_name]), output_len)
+        num_pixel_planes = 0  # i.e. images + video frames
+        for message_thread in batch_messages:
+            for message in message_thread:
+                for content in message.get("content", []):
+                    if (content_type := content.get("type")) == "image":
+                        num_pixel_planes += 1
+                    elif content_type == "video":
+                        num_pixel_planes += num_frames
+        self.assertEqual(len(out_dict[self.videos_input_name]), num_pixel_planes)
         for k in out_dict:
             self.assertIsInstance(out_dict[k], torch.Tensor)
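
The rewritten assertion above derives the expected number of pixel planes from the messages themselves rather than the old hardcoded `batch_size + 1` guess. A standalone sketch of that counting logic, with hypothetical messages and `num_frames`:

# Pixel-plane counting sketch; the messages and num_frames are hypothetical.
num_frames = 2

batch_messages = [
    [{"role": "user", "content": [{"type": "text", "text": "hi"}, {"type": "video", "url": "v.mp4"}]}],
    [{"role": "user", "content": [{"type": "image", "url": "img.png"}]}],
]

num_pixel_planes = 0  # an image contributes 1 plane, a video num_frames planes
for message_thread in batch_messages:
    for message in message_thread:
        for content in message.get("content", []):
            content_type = content.get("type")
            if content_type == "image":
                num_pixel_planes += 1
            elif content_type == "video":
                num_pixel_planes += num_frames

assert num_pixel_planes == 3  # one 2-frame video + one image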

@@ -377,3 +384,25 @@ def _test_apply_chat_template(
         continue_prompt = processor.apply_chat_template(batch_messages, continue_final_message=True, tokenize=False)
         for prompt in continue_prompt:
             self.assertTrue(prompt.endswith("It is the sound of"))  # no `eos` token at the end
+
+    @parameterized.expand([(1,), (2,)])
+    @require_torch
+    def test_frames_binding(self, batch_size: int):
+        texts = [
+            "<video>\nAre there any cyan objects that enter the scene?\nno",
+            "<video>\nAre there any red spheres that enter the scene?\nno",
+        ]
+        frames = torch.ones((4, 448, 448, 3), dtype=torch.float32)
+        videos = [frames, frames]
+
+        processor = self.get_processor()
+        inputs = processor(
+            text=texts[:batch_size],
+            return_tensors="pt",
+            videos=videos[:batch_size],
+            videos_kwargs={"size": (448, 448)},
+        )
+
+        actual_num_frames = inputs.pixel_values.shape[0]
+        expected_num_frames = sum(x.shape[0] for x in videos[:batch_size])
+        assert actual_num_frames == expected_num_frames
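
The final assertion holds because the processor flattens the batch and time dimensions into a single axis of pixel planes, (B, T, C, H, W) -> (B*T, C, H, W). A minimal sketch with deliberately small, made-up tensor sizes:

# Batch/time flattening sketch; sizes are made up (real frames are 448x448).
import torch

video_pixel_values = torch.ones((2, 4, 3, 8, 8))  # (B, T, C, H, W)
flattened = video_pixel_values.flatten(0, 1)      # -> (B*T, C, H, W)
assert flattened.shape == (8, 3, 8, 8)            # 2 videos x 4 frames = 8 planes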
