Merged
Changes from 7 commits
1 change: 1 addition & 0 deletions src/torchcodec/_core/CMakeLists.txt
@@ -96,6 +96,7 @@ function(make_torchcodec_libraries
Encoder.cpp
ValidationUtils.cpp
Transform.cpp
Metadata.cpp
)

if(ENABLE_CUDA)
110 changes: 110 additions & 0 deletions src/torchcodec/_core/Metadata.cpp
@@ -0,0 +1,110 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "Metadata.h"

namespace facebook::torchcodec {

std::optional<double> StreamMetadata::getDurationSeconds(
SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
// In exact mode, use the scanned content value
if (endStreamPtsSecondsFromContent.has_value() &&
beginStreamPtsSecondsFromContent.has_value()) {
return endStreamPtsSecondsFromContent.value() -
beginStreamPtsSecondsFromContent.value();
}
return std::nullopt;
Contributor:

One of the benefits of using the seek mode is that we can know that a scan should have been performed, and the metadata from the scan should definitely be there. I think we should instead make this:

  case SeekMode::exact:
    return endStreamPtsSecondsFromContent.value() -
            beginStreamPtsSecondsFromContent.value();

That is, we don't check if the value is present, and we don't return nullopt. The unconditional call to .value() will throw an exception if we somehow get into a situation where we're in exact seek mode and there is no value. But that is what we want. Such a situation should never happen, and we want an exception thrown.

Contributor Author (@mollyxu, Nov 5, 2025):

Would you recommend modifying the logic such that exact mode accesses the values directly but we keep the checks for custom_frame_mappings? Previously the Python logic had no knowledge of seek mode, so custom_frame_mappings was often grouped with exact because scans are performed.

In custom_frame_mappings, a scan only happens on specific streams when add_video_stream() is called (see test_get_metadata in test_metadata.py). Removing the checks would lead to cases in which not all streams in the container have scanned metadata.

Or should custom_frame_mappings streams that are missing the scanned information behave the same as in approximate mode?

Contributor (@Dan-Flores, Nov 6, 2025):

We only allow one stream to be added and scanned via add_video_stream, and I believe (but do correct me if I am mistaken!) that when we access any stream metadata, it's from the stream currently stored as activeStreamIndex_. This stream index is updated by the custom_frame_mappings equivalent "scan" function.

> Removing the checks would lead to cases in which not all streams in the container would have scanned metadata.

Is it possible to access the stream metadata of a stream other than activeStreamIndex_? If not, it should be safe to assume that the metadata will be present, and treat it as exact mode.

Contributor Author:

> when we access any stream metadata, it's from the stream currently stored as activeStreamIndex_

I think in get_container_metadata in _metadata.py we call _get_stream_json_metadata for each stream in the container. It is at this step that we trigger the custom_frame_mappings logic in Metadata.cpp for streams whose scanned metadata doesn't exist because they are not at activeStreamIndex_:

for stream_index in range(container_dict["numStreams"]):
    stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))

The error seems to come not from accessing a non-active stream but rather from how we store metadata.
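To illustrate the situation being described — only the active stream carries scanned metadata, yet the container-metadata loop visits every stream — here is a toy Python sketch; the function and field names are illustrative stand-ins, not the real torchcodec API:

```python
import json

# Toy stand-in for _get_stream_json_metadata: only the "active" stream has
# scanned (FromContent) metadata; the others only have header metadata.
def get_stream_json_metadata(stream_index, active_stream_index):
    meta = {"streamIndex": stream_index, "durationSecondsFromHeader": 10.0}
    if stream_index == active_stream_index:
        meta["beginStreamPtsSecondsFromContent"] = 0.0
        meta["endStreamPtsSecondsFromContent"] = 10.0
    return json.dumps(meta)

num_streams, active = 3, 1
streams = [
    json.loads(get_stream_json_metadata(i, active)) for i in range(num_streams)
]
# The per-container loop still visits every stream, but only the active one
# has the scanned fields that unconditional .value() access would require.
scanned = [s for s in streams if "endStreamPtsSecondsFromContent" in s]
assert len(scanned) == 1 and scanned[0]["streamIndex"] == active
```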

case SeekMode::approximate:
if (durationSecondsFromHeader.has_value()) {
return durationSecondsFromHeader.value();
}
if (numFramesFromHeader.has_value() && averageFpsFromHeader.has_value() &&
averageFpsFromHeader.value() != 0.0) {
return static_cast<double>(numFramesFromHeader.value()) /
averageFpsFromHeader.value();
}
return std::nullopt;
}
return std::nullopt;
}

double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
if (beginStreamPtsSecondsFromContent.has_value()) {
return beginStreamPtsSecondsFromContent.value();
}
return 0.0;
case SeekMode::approximate:
return 0.0;
}
return 0.0;
}

std::optional<double> StreamMetadata::getEndStreamSeconds(
SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
if (endStreamPtsSecondsFromContent.has_value()) {
return endStreamPtsSecondsFromContent.value();
}
return getDurationSeconds(seekMode);
case SeekMode::approximate:
return getDurationSeconds(seekMode);
}
return std::nullopt;
}

std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
if (numFramesFromContent.has_value()) {
return numFramesFromContent.value();
}
return std::nullopt;
case SeekMode::approximate: {
if (numFramesFromHeader.has_value()) {
return numFramesFromHeader.value();
}
if (averageFpsFromHeader.has_value() &&
durationSecondsFromHeader.has_value()) {
return static_cast<int64_t>(
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
}
return std::nullopt;
}
}
return std::nullopt;
}

std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
if (getNumFrames(seekMode).has_value() &&
Contributor:

Nit: let's define:

auto numFrames = getNumFrames(seekMode);

And use it inside the expressions.

beginStreamPtsSecondsFromContent.has_value() &&
endStreamPtsSecondsFromContent.has_value() &&
(beginStreamPtsSecondsFromContent.value() !=
endStreamPtsSecondsFromContent.value())) {
return static_cast<double>(
getNumFrames(seekMode).value() /
(endStreamPtsSecondsFromContent.value() -
beginStreamPtsSecondsFromContent.value()));
}
return averageFpsFromHeader;
case SeekMode::approximate:
return averageFpsFromHeader;
}
return std::nullopt;
}

} // namespace facebook::torchcodec
9 changes: 9 additions & 0 deletions src/torchcodec/_core/Metadata.h
Expand Up @@ -18,6 +18,8 @@ extern "C" {

namespace facebook::torchcodec {

enum class SeekMode { exact, approximate, custom_frame_mappings };

struct StreamMetadata {
// Common (video and audio) fields derived from the AVStream.
int streamIndex;
@@ -52,6 +54,13 @@ struct StreamMetadata {
std::optional<int64_t> sampleRate;
std::optional<int64_t> numChannels;
std::optional<std::string> sampleFormat;

// Computed methods with fallback logic
std::optional<double> getDurationSeconds(SeekMode seekMode) const;
double getBeginStreamSeconds(SeekMode seekMode) const;
std::optional<double> getEndStreamSeconds(SeekMode seekMode) const;
std::optional<int64_t> getNumFrames(SeekMode seekMode) const;
std::optional<double> getAverageFps(SeekMode seekMode) const;
};

struct ContainerMetadata {
61 changes: 13 additions & 48 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -367,6 +367,10 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
return containerMetadata_;
}

SeekMode SingleStreamDecoder::getSeekMode() const {
return seekMode_;
}

torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
validateScannedAllStreams("getKeyFrameIndices");
@@ -611,7 +615,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];

std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
if (numFrames.has_value()) {
// If the frameIndex is negative, we convert it to a positive index
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
@@ -705,7 +709,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(

// Note that if we do not have the number of frames available in our
// metadata, then we assume that the upper part of the range is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
if (numFrames.has_value()) {
TORCH_CHECK(
stop <= numFrames.value(),
@@ -779,8 +783,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];

double minSeconds = getMinSeconds(streamMetadata);
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
std::optional<double> maxSeconds =
streamMetadata.getEndStreamSeconds(seekMode_);

// The frame played at timestamp t and the one played at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
@@ -857,7 +862,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
return frameBatchOutput;
}

double minSeconds = getMinSeconds(streamMetadata);
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
TORCH_CHECK(
startSeconds >= minSeconds,
"Start seconds is " + std::to_string(startSeconds) +
@@ -866,7 +871,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(

// Note that if we can't determine the maximum seconds from the metadata,
// then we assume upper range is valid.
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
std::optional<double> maxSeconds =
streamMetadata.getEndStreamSeconds(seekMode_);
if (maxSeconds.has_value()) {
TORCH_CHECK(
startSeconds < maxSeconds.value(),
@@ -1439,47 +1445,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
// STREAM AND METADATA APIS
// --------------------------------------------------------------------------

std::optional<int64_t> SingleStreamDecoder::getNumFrames(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.numFramesFromContent.value();
case SeekMode::approximate: {
return streamMetadata.numFramesFromHeader;
}
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

double SingleStreamDecoder::getMinSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.beginStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
return 0;
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

std::optional<double> SingleStreamDecoder::getMaxSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.endStreamPtsSecondsFromContent.value();
case SeekMode::approximate: {
return streamMetadata.durationSecondsFromHeader;
}
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

// --------------------------------------------------------------------------
// VALIDATION UTILS
// --------------------------------------------------------------------------
@@ -1529,7 +1494,7 @@ void SingleStreamDecoder::validateFrameIndex(

// Note that if we do not have the number of frames available in our
// metadata, then we assume that the frameIndex is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
if (numFrames.has_value()) {
if (frameIndex >= numFrames.value()) {
throw std::out_of_range(
10 changes: 4 additions & 6 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -16,6 +16,7 @@
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "Frame.h"
#include "Metadata.h"
#include "StreamOptions.h"
#include "Transform.h"

@@ -30,8 +31,6 @@ class SingleStreamDecoder {
// CONSTRUCTION API
// --------------------------------------------------------------------------

enum class SeekMode { exact, approximate, custom_frame_mappings };

// Creates a SingleStreamDecoder from the video at videoFilePath.
explicit SingleStreamDecoder(
const std::string& videoFilePath,
@@ -60,6 +59,9 @@ class SingleStreamDecoder {
// Returns the metadata for the container.
ContainerMetadata getContainerMetadata() const;

// Returns the seek mode of this decoder.
SeekMode getSeekMode() const;

// Returns the key frame indices as a tensor. The tensor is 1D and contains
// int64 values, where each value is the frame index for a key frame.
torch::Tensor getKeyFrameIndices();
@@ -312,10 +314,6 @@
// index. Note that this index may be truncated for some files.
int getBestStreamIndex(AVMediaType mediaType);

std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
double getMinSeconds(const StreamMetadata& streamMetadata);
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);

// --------------------------------------------------------------------------
// VALIDATION UTILS
// --------------------------------------------------------------------------