diff --git a/supervision/utils/video.py b/supervision/utils/video.py index 3b281b4e2..0623e1193 100644 --- a/supervision/utils/video.py +++ b/supervision/utils/video.py @@ -230,29 +230,24 @@ def callback(scene: np.ndarray, index: int) -> np.ndarray: ``` """ source_video_info = VideoInfo.from_video_path(video_path=source_path) + max_frames = max_frames or source_video_info.total_frames + if source_video_info.total_frames is not None and max_frames is not None: + max_frames = min(max_frames, source_video_info.total_frames) + video_frames_generator = get_video_frames_generator( source_path=source_path, end=max_frames ) with VideoSink(target_path=target_path, video_info=source_video_info) as sink: - total_frames = ( - min(source_video_info.total_frames, max_frames) - if max_frames is not None - else source_video_info.total_frames - ) for index, frame in enumerate( tqdm( video_frames_generator, - total=total_frames, + total=max_frames, disable=not show_progress, desc=progress_message, ) ): result_frame = callback(frame, index) sink.write_frame(frame=result_frame) - else: - for index, frame in enumerate(video_frames_generator): - result_frame = callback(frame, index) - sink.write_frame(frame=result_frame) class FPSMonitor: diff --git a/test/utils/test_video.py b/test/utils/test_video.py new file mode 100644 index 000000000..1343154f1 --- /dev/null +++ b/test/utils/test_video.py @@ -0,0 +1,29 @@ +import cv2 +import numpy as np + +from supervision.utils.video import process_video + + +def create_test_video(path, num_frames, width=20, height=10): + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(path, fourcc, 1.0, (width, height)) + + for _ in range(num_frames): + frame = np.zeros((height, width, 3), dtype=np.uint8) + out.write(frame) + + out.release() + + +def test_process_video_max_frames_exceeds_total_frames(tmp_path): + source_path = tmp_path / "source.mp4" + target_path = tmp_path / "target.mp4" + + create_test_video(str(source_path), num_frames=5) + + process_video( + source_path=str(source_path), + target_path=str(target_path), + callback=lambda frame, _: frame, + max_frames=10, + )