diff --git a/supervision/utils/video.py b/supervision/utils/video.py
index 3b281b4e2..5fb2c4bf9 100644
--- a/supervision/utils/video.py
+++ b/supervision/utils/video.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
+import threading
 import time
 from collections import deque
 from collections.abc import Callable, Generator
 from dataclasses import dataclass
+from queue import Queue
 
 import cv2
 import numpy as np
@@ -255,6 +257,130 @@ def callback(scene: np.ndarray, index: int) -> np.ndarray:
             sink.write_frame(frame=result_frame)
 
 
+def process_video_threads(
+    source_path: str,
+    target_path: str,
+    callback: Callable[[np.ndarray, int], np.ndarray],
+    *,
+    max_frames: int | None = None,
+    prefetch: int = 32,
+    writer_buffer: int = 32,
+    show_progress: bool = False,
+    progress_message: str = "Processing video (with threads)",
+) -> None:
+    """
+    Process a video with a threaded pipeline that asynchronously reads
+    frames, applies a callback to each, and writes the results to an
+    output file.
+
+    Overview:
+        This function implements a three-stage pipeline designed to maximize
+        frame throughput.
+
+        │ Reader │ >> │ Processor │ >> │ Writer │
+        (thread)        (main)          (thread)
+
+        - Reader thread: reads frames from disk into a bounded queue
+          ('read_q') until it is full, then blocks. This ensures no more
+          than 'prefetch' frames are ever held in memory at once.
+
+        - Main thread: dequeues frames, applies 'callback(frame, idx)', and
+          enqueues the processed result into 'write_q'. This is the compute
+          stage. Note that it is not threaded, so detectors, trackers, and
+          other stateful objects can be used safely without synchronization.
+
+        - Writer thread: dequeues processed frames and writes them to disk.
+
+        Both queues are bounded to enforce back-pressure:
+        - The reader cannot outpace processing (avoids unbounded RAM usage).
+        - The processor cannot outpace writing (avoids output buffer bloat).
+
+    Summary:
+        - It is thread-safe: because the callback runs only in the main
+          thread, using a single stateful detector/tracker inside the
+          callback requires no synchronization with the reader/writer
+          threads.
+
+        - While the main thread processes frame N, the reader is already
+          decoding frame N+1 and the writer is encoding frame N-1. The three
+          stages run concurrently without blocking each other.
+
+        - When is it fastest?
+            - When the callback does heavy computation that releases the
+              Python GIL (for example, OpenCV filters, resizes, or color
+              conversions).
+            - When using CUDA or GPU-accelerated inference.
+
+        - When is it better not to use it?
+            - When the callback is Python-heavy and GIL-bound. In that case,
+              a process-based approach is more effective.
+
+    Args:
+        source_path (str): The path to the source video file.
+        target_path (str): The path to the target video file.
+        callback (Callable[[np.ndarray, int], np.ndarray]): A function that
+            takes in a numpy ndarray representation of a video frame and an
+            int index of the frame, and returns a processed numpy ndarray
+            representation of the frame.
+        max_frames (int | None): The maximum number of frames to process.
+        prefetch (int): The maximum number of frames buffered by the reader
+            thread.
+        writer_buffer (int): The maximum number of frames buffered before
+            writing.
+        show_progress (bool): Whether to show a progress bar.
+        progress_message (str): The message to display in the progress bar.
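+
+    Examples:
+        A minimal usage sketch; the paths and the pass-through callback are
+        placeholders for illustration:
+
+        ```python
+        import numpy as np
+        import supervision as sv
+
+        def callback(frame: np.ndarray, index: int) -> np.ndarray:
+            # Replace with real per-frame work (detection, annotation, ...).
+            return frame
+
+        sv.utils.video.process_video_threads(
+            source_path="source.mp4",
+            target_path="target.mp4",
+            callback=callback,
+        )
+        ```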
+ """ + + source_video_info = VideoInfo.from_video_path(video_path=source_path) + total_frames = ( + min(source_video_info.total_frames, max_frames) + if max_frames is not None + else source_video_info.total_frames + ) + + # Each queue includes frames + sentinel + read_q: Queue[tuple[int, np.ndarray] | None] = Queue(maxsize=prefetch) + write_q: Queue[np.ndarray | None] = Queue(maxsize=writer_buffer) + + def reader_thread(): + gen = get_video_frames_generator(source_path=source_path, end=max_frames) + for idx, frame in enumerate(gen): + read_q.put((idx, frame)) + read_q.put(None) # sentinel + + def writer_thread(video_sink: VideoSink): + while True: + frame = write_q.get() + if frame is None: + break + video_sink.write_frame(frame=frame) + + # Heads up! We set 'daemon=True' so this thread won't block program exit + # if the main thread finishes first. + t_reader = threading.Thread(target=reader_thread, daemon=True) + with VideoSink(target_path=target_path, video_info=source_video_info) as sink: + t_writer = threading.Thread(target=writer_thread, args=(sink,), daemon=True) + t_reader.start() + t_writer.start() + + process_bar = tqdm( + total=total_frames, disable=not show_progress, desc=progress_message + ) + + # Main thread: we take a frame, apply function and update process bar. + while True: + item = read_q.get() + if item is None: + break + idx, frame = item + out = callback(frame, idx) + write_q.put(out) + if total_frames is not None: + process_bar.update(1) + + write_q.put(None) + t_reader.join() + t_writer.join() + process_bar.close() + + class FPSMonitor: """ A class for monitoring frames per second (FPS) to benchmark latency. diff --git a/test/utils/test_process_video.py b/test/utils/test_process_video.py new file mode 100644 index 000000000..2c3c68a9b --- /dev/null +++ b/test/utils/test_process_video.py @@ -0,0 +1,95 @@ +from pathlib import Path + +import cv2 +import numpy as np +import pytest + +import supervision as sv + + +def make_video( + path: Path, w: int = 160, h: int = 96, fps: int = 20, frames: int = 24 +) -> None: + """Create a small synthetic test video with predictable frame-colors.""" + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter(str(path), fourcc, fps, (w, h)) + assert writer.isOpened(), "Failed to open VideoWriter" + for i in range(frames): + v = (i * 11) % 250 + frame = np.full((h, w, 3), (v, 255 - v, (2 * v) % 255), np.uint8) + writer.write(frame) + writer.release() + + +def read_frames(path: Path) -> list[np.ndarray]: + """Read all frames from a video into memory.""" + cap = cv2.VideoCapture(str(path)) + assert cap.isOpened(), f"Cannot open video: {path}" + out = [] + while True: + ok, frame = cap.read() + if not ok: + break + out.append(frame) + cap.release() + return out + + +def frames_equal(a: np.ndarray, b: np.ndarray, max_abs_tol: int = 0) -> bool: + """Return True if frames are the same within acertain tolerance.""" + if a.shape != b.shape: + return False + diff = np.abs(a.astype(np.int16) - b.astype(np.int16)) + return diff.max() <= max_abs_tol + + +def callback_noop(frame: np.ndarray, idx: int) -> np.ndarray: + """No-op callback: validates pure pipeline correctness.""" + return frame + + +def callbackb_opencv(frame: np.ndarray, idx: int) -> np.ndarray: + """ + Simulations some cv2 task... 
+ """ + g = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + return cv2.cvtColor(g, cv2.COLOR_GRAY2BGR) + + +@pytest.mark.parametrize( + "callback", [callback_noop, callbackb_opencv], ids=["identity", "opencv"] +) +def test_process_video_vs_threads_same_output(callback, tmp_path: Path): + """ + Ensure that process_video() and process_video_threads() produce identical + results for the same synthetic source video and callback. + """ + name = callback.__name__ + src = tmp_path / f"src_{name}.mp4" + dst_single = tmp_path / f"out_single_{name}.mp4" + dst_threads = tmp_path / f"out_threads_{name}.mp4" + + make_video(src, frames=24) + + sv.utils.video.process_video( + source_path=str(src), + target_path=str(dst_single), + callback=callback, + show_progress=False, + ) + sv.utils.video.process_video_threads( + source_path=str(src), + target_path=str(dst_threads), + callback=callback, + prefetch=4, + writer_buffer=4, + show_progress=False, + ) + + frames_single = read_frames(dst_single) + frames_threads = read_frames(dst_threads) + + assert len(frames_single) == len(frames_threads) != 0, "Frame count mismatch." + + for i, (fs, ft) in enumerate(zip(frames_single, frames_threads)): + assert frames_equal(fs, ft), f"Frame {i} is different."