
Conversation

NicolasHug (Contributor) commented Nov 7, 2025

We have a "skip seeking" logic where we try to minimize the number of seeks we have to do. This logic lives in

bool SingleStreamDecoder::canWeAvoidSeeking() const {

The problem is: canWeAvoidSeeking() itself can be expensive to call! This is especially true in approximate mode, which, as I just found out, can be slower than exact mode for short videos (~10s long).

In this PR, we skip the call to canWeAvoidSeeking() when we can - yes, we "skip the skip-checking logic"! We can do that when we are decoding frames contiguously: 0, 1, 2, 3....
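To make this concrete, here is a minimal sketch of the fast path. It assumes illustrative scaffolding: lastDecodedFrameIndex_ is the member this PR adds and canWeAvoidSeeking() exists in the codebase, but decodeFrameAtIndex, seekToKeyFrameBefore, and decodeNextFrame are made up for the sketch and are not the actual TorchCodec code.

#include <cstdint>

class SingleStreamDecoder {
 public:
  void decodeFrameAtIndex(int64_t frameIndex) {
    if (frameIndex == lastDecodedFrameIndex_ + 1) {
      // Contiguous access (0, 1, 2, 3...): the decoder is already
      // positioned on the next packet, so we skip the seek AND the
      // seek-avoidance check itself.
      decodeNextFrame();
    } else {
      // Random access: run the (potentially expensive) check, and seek
      // only when we have to.
      if (!canWeAvoidSeeking()) {
        seekToKeyFrameBefore(frameIndex);
      }
      decodeNextFrame();
    }
    lastDecodedFrameIndex_ = frameIndex;
  }

 private:
  bool canWeAvoidSeeking() const { return false; }  // stub for the sketch
  void seekToKeyFrameBefore(int64_t) {}             // stub for the sketch
  void decodeNextFrame() {}                         // stub for the sketch
  int64_t lastDecodedFrameIndex_ = -1;
};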

This will provide significant speedups when:

  • seek_mode="approximate"
  • frames are decoded contiguously. That can happen in both get_frames_at() and get_frames_played_at().

Why is canWeAvoidSeeking() slow?

Because it calls getKeyFrameIndexForPts() which, in approximate mode, calls av_index_search_timestamp(). Calling this for all frames can dominate the runtime!
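For reference, here is roughly what that hot path boils down to. This is a hedged sketch, not the actual TorchCodec implementation; av_index_search_timestamp() and AVSEEK_FLAG_BACKWARD are real FFmpeg APIs, but the helper name is made up for illustration.

extern "C" {
#include <libavformat/avformat.h>
}

// Mapping a pts back to an index entry is a search through the stream's
// AVIndexEntry table. Done once per real seek this is cheap; done once
// per decoded frame, it can dominate the runtime.
int getKeyFrameIndexForPtsSketch(AVStream* stream, int64_t pts) {
  // AVSEEK_FLAG_BACKWARD: return the entry at or before pts.
  return av_index_search_timestamp(stream, pts, AVSEEK_FLAG_BACKWARD);
}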

Benchmarks:

Decoding all 300 frames of a short, 10-second H.264 720p video (testsrc2), approximate mode goes from 1249.56ms to 825.74ms (1.5X faster).

~/dev/torchcodec-cuda (avoid_seeking_checks*) » python ~/benchmark_torchcodec_decord.py ~/videos_h264/ --sampling all --num-threads 1
torchcodec.__version__ = '0.9.0a0+afd5aba'
videos: 720x1280, 30.0 fps, 300 frames long
Using 1 thread(s), averaging over 10 runs


# This PR:
=== TorchCodec approx ===
med = 825.74ms +- 1.31, max = 828.85ms
=== TorchCodec exact ===
med = 828.05ms +- 3.69, max = 836.18ms

# On main
=== TorchCodec approx ===
med = 1249.56ms +- 1.81, max = 1253.79ms
=== TorchCodec exact ===
med = 832.75ms +- 0.98, max = 835.03ms

messy (but correct) benchmarking code:

import argparse
from pathlib import Path
from time import perf_counter_ns

import decord
import psutil
import torch
from joblib import delayed, Parallel
import torchcodec
from torchcodec.decoders import VideoDecoder


def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    process = psutil.Process()

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    cpu_utils = []
    mem_usages = []

    for _ in range(num_exp):
        psutil.cpu_percent(interval=None)

        start = perf_counter_ns()
        f(*args, **kwargs)
        end = perf_counter_ns()

        cpu_util = psutil.cpu_percent(interval=None)  # since last call
        mem_end = process.memory_info().rss

        times.append(end - start)
        cpu_utils.append(cpu_util)
        mem_usages.append(mem_end)

    return torch.tensor(times).float(), torch.tensor(cpu_utils).float(), torch.tensor(mem_usages).float()


def report_stats(times, cpu_utils=None, mem_usages=None, unit="ms"):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    max = times.max().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}, {max = :.2f}{unit}")

    if cpu_utils is not None:
        cpu_avg = cpu_utils.mean().item()
        cpu_peak = cpu_utils.max().item()
        print(f"CPU utilization: avg = {cpu_avg:.1f}%, peak = {cpu_peak:.1f}%")

    if mem_usages is not None:
        mem_gb = mem_usages / (1024 ** 3)
        mem_peak = mem_gb.max().item()
        mem_min = mem_gb.min().item()
        mem_delta = mem_peak - mem_min
        print(f"Memory: peak = {mem_peak:.2f}GB, delta = +{mem_delta:.2f}GB")


# Note: these helpers read the global `indices`, computed below from --sampling.
def decode_one_video_torchcodec(video_path, seek_mode="approximate"):
    decoder = VideoDecoder(str(video_path), device="cpu", seek_mode=seek_mode, num_ffmpeg_threads=1)
    return decoder.get_frames_at(indices)

def decode_one_video_decord(video_path):
    vr = decord.VideoReader(str(video_path), ctx=decord.cpu(), num_threads=1)
    return vr.get_batch(indices.tolist())

def decode_videos(library="torchcodec"):
    if library == "torchcodec":
        decode_one_video = decode_one_video_torchcodec
    elif library == "decord":
        decode_one_video = decode_one_video_decord
    else:
        raise ValueError(f"Unknown library: {library}")

    Parallel(n_jobs=args.num_threads, prefer="threads")(
        delayed(decode_one_video)(video_path) for video_path in video_files
    )


def validate(video_path):
    out_tc = decode_one_video_torchcodec(video_path)
    out_dc = decode_one_video_decord(video_path)

    torch.testing.assert_close(out_tc.data, out_dc.permute(0, 3, 1, 2), rtol=0, atol=0)
    print("outputs are the same!")


NUM_EXP = 10
parser = argparse.ArgumentParser()
parser.add_argument("video_folder", help="Folder containing .h264 files")
parser.add_argument(
    "--sampling",
    type=str,
    default="all",
    help="Sampling strategy. 'all' for all frames, or an N (int) for N evenly spaced frames.",
)
parser.add_argument(
    "--num-threads",
    type=int,
    default=1,
    help="Number of threads to spawn. Each thread decodes one single video.",
)
args = parser.parse_args()

video_files = list(Path(args.video_folder).glob("*.mp4"))

# We kinda assume all the videos in the folder have the same number of frames
dummy_dec = VideoDecoder(str(video_files[0]), device="cpu")
if args.sampling.startswith("first"):
    num_frames_to_sample = int(args.sampling[len("first") :])
    indices = torch.arange(num_frames_to_sample)
elif args.sampling == "all":
    indices = torch.arange(len(dummy_dec))
else:
    num_frames_to_sample = int(args.sampling)
    indices = torch.linspace(
        0, len(dummy_dec) - 1, num_frames_to_sample, dtype=torch.int
    )

decord.bridge.set_bridge("torch")
# validate(video_files[0])

print(f"{torchcodec.__version__ = }")

# print(
#     f"Decoding {args.sampling} frames from {len(video_files)} video files in {args.video_folder}"
# )
print(
    f"videos: {dummy_dec.metadata.height}x{dummy_dec.metadata.width}, {dummy_dec.metadata.average_fps} fps, {dummy_dec.metadata.num_frames} frames long"
)
print(f"Using {args.num_threads} thread(s), averaging over {NUM_EXP} runs")

# print("\n=== TorchCodec ===")
# times_tc, cpu_utils_tc, mem_usages_tc = bench(decode_videos, library="torchcodec", warmup=1, num_exp=NUM_EXP)
# report_stats(times_tc, cpu_utils_tc, mem_usages_tc)

# print("\n=== Decord ===")
# times_dc, cpu_utils_dc, mem_usages_dc = bench(decode_videos, library="decord", warmup=1, num_exp=NUM_EXP)
# report_stats(times_dc, cpu_utils_dc, mem_usages_dc)


print("\n=== TorchCodec approx ===")
times_tc, cpu_utils_tc, mem_usages_tc = bench(decode_one_video_torchcodec, video_path=video_files[0], seek_mode="approximate", warmup=1,  num_exp=NUM_EXP)
report_stats(times_tc, cpu_utils_tc, mem_usages_tc)

print("\n=== TorchCodec exact ===")
times_tc, cpu_utils_tc, mem_usages_tc = bench(decode_one_video_torchcodec, video_path=video_files[0], seek_mode="exact", warmup=1, num_exp=NUM_EXP)
report_stats(times_tc, cpu_utils_tc, mem_usages_tc)

# print("\n=== Decord ===")
# times_dc, cpu_utils_dc, mem_usages_dc = bench(decode_one_video_decord, video_path=video_files[0], warmup=1, num_exp=NUM_EXP)
# report_stats(times_dc, cpu_utils_dc, mem_usages_dc)


auto result = getNextFrameInternal(preAllocatedOutputTensor);
lastDecodedFrameIndex_ = frameIndex;
return result;
NicolasHug (Author) commented:

TODO for myself: explain why that's crucially important.

Contributor commented:

I'm surprised that changing this logic avoids calling canWeAvoidSeeking() because this logic doesn't actually call that function. Does not updating the cursor change the conditions for which we call canWeAvoidSeeking()? Is it maybe clearer to directly update those conditions? These calls don't seem that expensive (although of course they'll cost something). I'm a little worried that we have some assumption elsewhere that we're always updating the cursor, and now we're not.

NicolasHug (Author) replied Nov 7, 2025:

It is very indirect and yes, based on assumptions.

canWeAvoidSeeking() is only called if the cursor was set:

if (cursorWasJustSet_) {
  maybeSeekToBeforeDesiredPts();
  cursorWasJustSet_ = false;
}

and the cursor is only set when we go through setCursorPtsInSeconds. The contiguous fast path never sets the cursor, so this branch is never entered and canWeAvoidSeeking() is never called.

I'm happy to revisit the core logic of how we handle this. TBH, I'm a bit in speed-running mode at the moment.

Half of the reason I'm submitting the PR as-is is to check the CI and make this issue visible as early as possible.

scotts (Contributor) commented Nov 7, 2025

Digging into canWeAvoidSeeking() more, we already have a few optimizations where we back out before we do the mapping. But it also seems like the whole point of the mapping is to know the index of the last decoded frame:

int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_);

Can't we now replace that call with your newly added lastDecodedFrameIndex_? That would actually speed up all scenarios. We may still be able to apply parts of this "skip if sequential" logic, but I think you've now made the most expensive step unneeded in general.
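A hedged sketch of what that swap could look like; only getKeyFrameIndexForPts, lastDecodedAvFramePts_, and lastDecodedFrameIndex_ are from this thread, while the scaffolding and the placeholder return are illustrative.

#include <cstdint>

class SingleStreamDecoder {
 public:
  bool canWeAvoidSeeking() const {
    // Before: recompute the last decoded frame's index from its pts,
    // which in approximate mode calls av_index_search_timestamp():
    //   int lastDecodedAvFrameIndex =
    //       getKeyFrameIndexForPts(lastDecodedAvFramePts_);

    // After: reuse the index we already track, in every seek mode.
    int64_t lastDecodedAvFrameIndex = lastDecodedFrameIndex_;

    // The rest of the checks would use lastDecodedAvFrameIndex as before;
    // placeholder return to keep the sketch self-contained.
    return lastDecodedAvFrameIndex >= 0;
  }

 private:
  int getKeyFrameIndexForPts(int64_t) const { return 0; }  // stub
  int64_t lastDecodedAvFramePts_ = 0;
  int64_t lastDecodedFrameIndex_ = -1;
};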
