1 change: 1 addition & 0 deletions src/torchcodec/_core/CMakeLists.txt
@@ -96,6 +96,7 @@ function(make_torchcodec_libraries
Encoder.cpp
ValidationUtils.cpp
Transform.cpp
Metadata.cpp
)

if(ENABLE_CUDA)
121 changes: 121 additions & 0 deletions src/torchcodec/_core/Metadata.cpp
@@ -0,0 +1,121 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "Metadata.h"
#include "torch/types.h"

namespace facebook::torchcodec {

std::optional<double> StreamMetadata::getDurationSeconds(
SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
TORCH_CHECK(
endStreamPtsSecondsFromContent.has_value() &&
beginStreamPtsSecondsFromContent.has_value(),
"Missing beginStreamPtsSecondsFromContent or endStreamPtsSecondsFromContent");
return endStreamPtsSecondsFromContent.value() -
beginStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
if (durationSecondsFromHeader.has_value()) {
return durationSecondsFromHeader.value();
}
if (numFramesFromHeader.has_value() && averageFpsFromHeader.has_value() &&
averageFpsFromHeader.value() != 0.0) {
return static_cast<double>(numFramesFromHeader.value()) /
averageFpsFromHeader.value();
}
return std::nullopt;
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
TORCH_CHECK(
beginStreamPtsSecondsFromContent.has_value(),
"Missing beginStreamPtsSecondsFromContent");
return beginStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
if (beginStreamPtsSecondsFromContent.has_value()) {
return beginStreamPtsSecondsFromContent.value();
}
return 0.0;
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

std::optional<double> StreamMetadata::getEndStreamSeconds(
SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
TORCH_CHECK(
endStreamPtsSecondsFromContent.has_value(),
"Missing endStreamPtsSecondsFromContent");
return endStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
if (endStreamPtsSecondsFromContent.has_value()) {
return endStreamPtsSecondsFromContent.value();
}
return getDurationSeconds(seekMode);
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
TORCH_CHECK(
numFramesFromContent.has_value(), "Missing numFramesFromContent");
return numFramesFromContent.value();
case SeekMode::approximate: {
if (numFramesFromHeader.has_value()) {
return numFramesFromHeader.value();
}
if (averageFpsFromHeader.has_value() &&
durationSecondsFromHeader.has_value()) {
return static_cast<int64_t>(
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
}
return std::nullopt;
}
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
switch (seekMode) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact: {
auto numFrames = getNumFrames(seekMode);
if (numFrames.has_value() &&
beginStreamPtsSecondsFromContent.has_value() &&
endStreamPtsSecondsFromContent.has_value()) {
double duration = endStreamPtsSecondsFromContent.value() -
beginStreamPtsSecondsFromContent.value();
if (duration != 0.0) {
return static_cast<double>(numFrames.value()) / duration;
}
}
return averageFpsFromHeader;
}
case SeekMode::approximate:
return averageFpsFromHeader;
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

} // namespace facebook::torchcodec
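
As a quick illustration of the fallback behavior implemented above, here is a minimal sketch (not part of the diff) that queries a StreamMetadata instance whose content-derived fields were never populated; the include path and field values are assumptions made for the example.

#include "Metadata.h" // assumed include path for this sketch

#include <cstdint>
#include <optional>

namespace {

void sketchHeaderOnlyFallback() {
  using namespace facebook::torchcodec;

  // Hypothetical stream: the header reports fps and duration, but the file
  // was never scanned, so the *FromContent fields stay empty.
  StreamMetadata streamMetadata;
  streamMetadata.averageFpsFromHeader = 25.0;
  streamMetadata.durationSecondsFromHeader = 10.0;

  // approximate mode falls back to header values:
  // getDurationSeconds -> 10.0, getNumFrames -> 25 fps * 10 s = 250 frames.
  std::optional<double> duration =
      streamMetadata.getDurationSeconds(SeekMode::approximate);
  std::optional<int64_t> numFrames =
      streamMetadata.getNumFrames(SeekMode::approximate);

  // exact and custom_frame_mappings modes require scanned content; with
  // numFramesFromContent unset, getNumFrames(SeekMode::exact) would fail
  // its TORCH_CHECK rather than guess.
  (void)duration;
  (void)numFrames;
}

} // namespace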
9 changes: 9 additions & 0 deletions src/torchcodec/_core/Metadata.h
@@ -18,6 +18,8 @@ extern "C" {

namespace facebook::torchcodec {

enum class SeekMode { exact, approximate, custom_frame_mappings };

struct StreamMetadata {
// Common (video and audio) fields derived from the AVStream.
int streamIndex;
@@ -52,6 +54,13 @@ struct StreamMetadata {
std::optional<int64_t> sampleRate;
std::optional<int64_t> numChannels;
std::optional<std::string> sampleFormat;

// Computed methods with fallback logic
std::optional<double> getDurationSeconds(SeekMode seekMode) const;
double getBeginStreamSeconds(SeekMode seekMode) const;
std::optional<double> getEndStreamSeconds(SeekMode seekMode) const;
std::optional<int64_t> getNumFrames(SeekMode seekMode) const;
std::optional<double> getAverageFps(SeekMode seekMode) const;
};

struct ContainerMetadata {
65 changes: 17 additions & 48 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -367,6 +367,14 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
return containerMetadata_;
}

SeekMode SingleStreamDecoder::getSeekMode() const {
return seekMode_;
}

int SingleStreamDecoder::getActiveStreamIndex() const {
return activeStreamIndex_;
}

torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
validateScannedAllStreams("getKeyFrameIndices");
@@ -611,7 +619,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];

std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
if (numFrames.has_value()) {
// If the frameIndex is negative, we convert it to a positive index
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
@@ -705,7 +713,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(

// Note that if we do not have the number of frames available in our
// metadata, then we assume that the upper part of the range is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
if (numFrames.has_value()) {
TORCH_CHECK(
stop <= numFrames.value(),
@@ -779,8 +787,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];

double minSeconds = getMinSeconds(streamMetadata);
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
std::optional<double> maxSeconds =
streamMetadata.getEndStreamSeconds(seekMode_);

// The frame played at timestamp t and the one played at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
@@ -857,7 +866,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
return frameBatchOutput;
}

double minSeconds = getMinSeconds(streamMetadata);
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
TORCH_CHECK(
startSeconds >= minSeconds,
"Start seconds is " + std::to_string(startSeconds) +
@@ -866,7 +875,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(

// Note that if we can't determine the maximum seconds from the metadata,
// then we assume upper range is valid.
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
std::optional<double> maxSeconds =
streamMetadata.getEndStreamSeconds(seekMode_);
if (maxSeconds.has_value()) {
TORCH_CHECK(
startSeconds < maxSeconds.value(),
@@ -1439,47 +1449,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
// STREAM AND METADATA APIS
// --------------------------------------------------------------------------

std::optional<int64_t> SingleStreamDecoder::getNumFrames(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.numFramesFromContent.value();
case SeekMode::approximate: {
return streamMetadata.numFramesFromHeader;
}
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

double SingleStreamDecoder::getMinSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.beginStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
return 0;
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

std::optional<double> SingleStreamDecoder::getMaxSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::custom_frame_mappings:
case SeekMode::exact:
return streamMetadata.endStreamPtsSecondsFromContent.value();
case SeekMode::approximate: {
return streamMetadata.durationSecondsFromHeader;
}
default:
TORCH_CHECK(false, "Unknown SeekMode");
}
}

// --------------------------------------------------------------------------
// VALIDATION UTILS
// --------------------------------------------------------------------------
@@ -1529,7 +1498,7 @@ void SingleStreamDecoder::validateFrameIndex(

// Note that if we do not have the number of frames available in our
// metadata, then we assume that the frameIndex is valid.
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
if (numFrames.has_value()) {
if (frameIndex >= numFrames.value()) {
throw std::out_of_range(
13 changes: 7 additions & 6 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -16,6 +16,7 @@
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "Frame.h"
#include "Metadata.h"
#include "StreamOptions.h"
#include "Transform.h"

@@ -30,8 +31,6 @@ class SingleStreamDecoder {
// CONSTRUCTION API
// --------------------------------------------------------------------------

enum class SeekMode { exact, approximate, custom_frame_mappings };

// Creates a SingleStreamDecoder from the video at videoFilePath.
explicit SingleStreamDecoder(
const std::string& videoFilePath,
@@ -60,6 +59,12 @@ class SingleStreamDecoder {
// Returns the metadata for the container.
ContainerMetadata getContainerMetadata() const;

// Returns the seek mode of this decoder.
SeekMode getSeekMode() const;

// Returns the active stream index. Returns -2 if no stream is active.
int getActiveStreamIndex() const;

// Returns the key frame indices as a tensor. The tensor is 1D and contains
// int64 values, where each value is the frame index for a key frame.
torch::Tensor getKeyFrameIndices();
@@ -312,10 +317,6 @@ class SingleStreamDecoder {
// index. Note that this index may be truncated for some files.
int getBestStreamIndex(AVMediaType mediaType);

std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
double getMinSeconds(const StreamMetadata& streamMetadata);
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);

// --------------------------------------------------------------------------
// VALIDATION UTILS
// --------------------------------------------------------------------------
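
For reference, a hedged sketch (also not part of the PR) of how a caller might combine the two new accessors with the relocated StreamMetadata logic; the helper name and include path are illustrative assumptions only.

#include "SingleStreamDecoder.h" // assumed include path for this sketch

#include <cstdint>
#include <optional>

namespace facebook::torchcodec {

// Hypothetical helper showing the intended call pattern after this change.
std::optional<int64_t> numFramesForActiveStream(
    const SingleStreamDecoder& decoder) {
  // getActiveStreamIndex() returns -2 when no stream has been added yet.
  int activeStreamIndex = decoder.getActiveStreamIndex();
  if (activeStreamIndex < 0) {
    return std::nullopt;
  }
  const ContainerMetadata containerMetadata = decoder.getContainerMetadata();
  const StreamMetadata& streamMetadata =
      containerMetadata.allStreamMetadata[activeStreamIndex];
  // The seek-mode-dependent fallback logic now lives on StreamMetadata
  // instead of in private SingleStreamDecoder helpers.
  return streamMetadata.getNumFrames(decoder.getSeekMode());
}

} // namespace facebook::torchcodec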