From 8bee8d713df781db517254df31288d268bee5aad Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Mon, 3 Nov 2025 22:06:12 -0800 Subject: [PATCH 01/13] refactor metadata fallback logic --- src/torchcodec/_core/CMakeLists.txt | 1 + src/torchcodec/_core/Metadata.cpp | 110 ++++++++++++++++ src/torchcodec/_core/Metadata.h | 9 ++ src/torchcodec/_core/SeekMode.h | 13 ++ src/torchcodec/_core/SingleStreamDecoder.cpp | 4 + src/torchcodec/_core/SingleStreamDecoder.h | 6 +- src/torchcodec/_core/_metadata.py | 125 ++++--------------- src/torchcodec/_core/custom_ops.cpp | 34 +++-- test/test_decoders.py | 5 + test/test_metadata.py | 125 +------------------ test/test_samplers.py | 2 + 11 files changed, 207 insertions(+), 227 deletions(-) create mode 100644 src/torchcodec/_core/Metadata.cpp create mode 100644 src/torchcodec/_core/SeekMode.h diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 6b4ccb5d4..825331840 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -96,6 +96,7 @@ function(make_torchcodec_libraries Encoder.cpp ValidationUtils.cpp Transform.cpp + Metadata.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp new file mode 100644 index 000000000..41e41bd5b --- /dev/null +++ b/src/torchcodec/_core/Metadata.cpp @@ -0,0 +1,110 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "Metadata.h" + +namespace facebook::torchcodec { + +std::optional StreamMetadata::getDurationSeconds( + SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + // In exact mode, use the scanned content value + if (endStreamPtsSecondsFromContent.has_value() && + beginStreamPtsSecondsFromContent.has_value()) { + return endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + } + return std::nullopt; + case SeekMode::approximate: + if (durationSecondsFromHeader.has_value()) { + return durationSecondsFromHeader.value(); + } + if (numFramesFromHeader.has_value() && averageFpsFromHeader.has_value() && + averageFpsFromHeader.value() != 0.0) { + return static_cast(numFramesFromHeader.value()) / + averageFpsFromHeader.value(); + } + return std::nullopt; + } + return std::nullopt; +} + +double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + if (beginStreamPtsSecondsFromContent.has_value()) { + return beginStreamPtsSecondsFromContent.value(); + } + return 0.0; + case SeekMode::approximate: + return 0.0; + } + return 0.0; +} + +std::optional StreamMetadata::getEndStreamSeconds( + SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + if (endStreamPtsSecondsFromContent.has_value()) { + return endStreamPtsSecondsFromContent.value(); + } + return getDurationSeconds(seekMode); + case SeekMode::approximate: + return getDurationSeconds(seekMode); + } + return std::nullopt; +} + +std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + if (numFramesFromContent.has_value()) { + return numFramesFromContent.value(); + } + return std::nullopt; + case SeekMode::approximate: { + if (numFramesFromHeader.has_value()) { + return 
numFramesFromHeader.value(); + } + if (averageFpsFromHeader.has_value() && + durationSecondsFromHeader.has_value()) { + return static_cast( + averageFpsFromHeader.value() * durationSecondsFromHeader.value()); + } + return std::nullopt; + } + } + return std::nullopt; +} + +std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + if (getNumFrames(seekMode).has_value() && + beginStreamPtsSecondsFromContent.has_value() && + endStreamPtsSecondsFromContent.has_value() && + (beginStreamPtsSecondsFromContent.value() != + endStreamPtsSecondsFromContent.value())) { + return static_cast( + getNumFrames(seekMode).value() / + (endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value())); + } + return averageFpsFromHeader; + case SeekMode::approximate: + return averageFpsFromHeader; + } + return std::nullopt; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h index ace6cf84c..57da8c989 100644 --- a/src/torchcodec/_core/Metadata.h +++ b/src/torchcodec/_core/Metadata.h @@ -10,6 +10,8 @@ #include #include +#include "SeekMode.h" + extern "C" { #include #include @@ -52,6 +54,13 @@ struct StreamMetadata { std::optional sampleRate; std::optional numChannels; std::optional sampleFormat; + + // Computed methods with fallback logic + std::optional getDurationSeconds(SeekMode seekMode) const; + double getBeginStreamSeconds(SeekMode seekMode) const; + std::optional getEndStreamSeconds(SeekMode seekMode) const; + std::optional getNumFrames(SeekMode seekMode) const; + std::optional getAverageFps(SeekMode seekMode) const; }; struct ContainerMetadata { diff --git a/src/torchcodec/_core/SeekMode.h b/src/torchcodec/_core/SeekMode.h new file mode 100644 index 000000000..fa30414ed --- /dev/null +++ b/src/torchcodec/_core/SeekMode.h @@ -0,0 +1,13 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +namespace facebook::torchcodec { + +enum class SeekMode { exact, approximate, custom_frame_mappings }; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 72cd7afac..6778331f9 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -367,6 +367,10 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { return containerMetadata_; } +SeekMode SingleStreamDecoder::getSeekMode() const { + return seekMode_; +} + torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { validateActiveStream(AVMEDIA_TYPE_VIDEO); validateScannedAllStreams("getKeyFrameIndices"); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 4b41811ff..dd1966d7b 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -16,6 +16,7 @@ #include "DeviceInterface.h" #include "FFMPEGCommon.h" #include "Frame.h" +#include "SeekMode.h" #include "StreamOptions.h" #include "Transform.h" @@ -30,8 +31,6 @@ class SingleStreamDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { exact, approximate, custom_frame_mappings }; - // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( const std::string& videoFilePath, @@ -60,6 +59,9 @@ class SingleStreamDecoder { // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; + // Returns the seek mode of this decoder. + SeekMode getSeekMode() const; + // Returns the key frame indices as a tensor. 
The tensor is 1D and contains // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 1f011f516..3f40da80f 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -38,6 +38,14 @@ class StreamMetadata: stream_index: int """Index of the stream that this metadata refers to (int).""" + # Computed fields (computed in C++ with fallback logic) + duration_seconds: Optional[float] + """Duration of the stream in seconds. Computed in C++ with fallback logic: + tries to calculate from content if scan was performed, otherwise falls back + to header values.""" + begin_stream_seconds: Optional[float] + """Beginning of the stream, in seconds. Computed in C++ with fallback logic.""" + def __repr__(self): s = self.__class__.__name__ + ":\n" for field in dataclasses.fields(self): @@ -45,7 +53,7 @@ def __repr__(self): return s -@dataclass +@dataclass(repr=False) class VideoStreamMetadata(StreamMetadata): """Metadata of a single video stream.""" @@ -87,103 +95,19 @@ class VideoStreamMetadata(StreamMetadata): is the ratio between the width and height of each pixel (``fractions.Fraction`` or None).""" - @property - def duration_seconds(self) -> Optional[float]: - """Duration of the stream in seconds. We try to calculate the duration - from the actual frames if a :term:`scan` was performed. Otherwise we - fall back to ``duration_seconds_from_header``. If that value is also None, - we instead calculate the duration from ``num_frames_from_header`` and - ``average_fps_from_header``. 
- """ - if ( - self.end_stream_seconds_from_content is not None - and self.begin_stream_seconds_from_content is not None - ): - return ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) - elif self.duration_seconds_from_header is not None: - return self.duration_seconds_from_header - elif ( - self.num_frames_from_header is not None - and self.average_fps_from_header is not None - ): - return self.num_frames_from_header / self.average_fps_from_header - else: - return None - - @property - def begin_stream_seconds(self) -> float: - """Beginning of the stream, in seconds (float). Conceptually, this - corresponds to the first frame's :term:`pts`. If - ``begin_stream_seconds_from_content`` is not None, then it is returned. - Otherwise, this value is 0. - """ - if self.begin_stream_seconds_from_content is None: - return 0 - else: - return self.begin_stream_seconds_from_content - - @property - def end_stream_seconds(self) -> Optional[float]: - """End of the stream, in seconds (float or None). - Conceptually, this corresponds to last_frame.pts + last_frame.duration. - If ``end_stream_seconds_from_content`` is not None, then that value is - returned. Otherwise, returns ``duration_seconds``. - """ - if self.end_stream_seconds_from_content is None: - return self.duration_seconds - else: - return self.end_stream_seconds_from_content - - @property - def num_frames(self) -> Optional[int]: - """Number of frames in the stream (int or None). - This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, - otherwise it corresponds to ``num_frames_from_header``. If that value is also - None, the number of frames is calculated from the duration and the average fps. 
- """ - if self.num_frames_from_content is not None: - return self.num_frames_from_content - elif self.num_frames_from_header is not None: - return self.num_frames_from_header - elif ( - self.average_fps_from_header is not None - and self.duration_seconds_from_header is not None - ): - return int(self.average_fps_from_header * self.duration_seconds_from_header) - else: - return None - - @property - def average_fps(self) -> Optional[float]: - """Average fps of the stream. If a :term:`scan` was perfomed, this is - computed from the number of frames and the duration of the stream. - Otherwise we fall back to ``average_fps_from_header``. - """ - if ( - self.end_stream_seconds_from_content is None - or self.begin_stream_seconds_from_content is None - or self.num_frames is None - # Should never happen, but prevents ZeroDivisionError: - or self.end_stream_seconds_from_content - == self.begin_stream_seconds_from_content - ): - return self.average_fps_from_header - return self.num_frames / ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) - - def __repr__(self): - s = super().__repr__() - s += f"{SPACES}duration_seconds: {self.duration_seconds}\n" - s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n" - s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n" - s += f"{SPACES}num_frames: {self.num_frames}\n" - s += f"{SPACES}average_fps: {self.average_fps}\n" - return s + # Computed fields (computed in C++ with fallback logic) + end_stream_seconds: Optional[float] + """End of the stream, in seconds (float or None). + Conceptually, this corresponds to last_frame.pts + last_frame.duration. + Computed in C++ with fallback logic.""" + num_frames: Optional[int] + """Number of frames in the stream (int or None). 
+ Computed in C++ with fallback logic: uses content if scan was performed, + otherwise falls back to header values or calculates from duration and fps.""" + average_fps: Optional[float] + """Average fps of the stream (float or None). + Computed in C++ with fallback logic: if scan was performed, computes from + num_frames and duration, otherwise uses header value.""" @dataclass @@ -260,10 +184,12 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index)) common_meta = dict( duration_seconds_from_header=stream_dict.get("durationSecondsFromHeader"), + duration_seconds=stream_dict.get("durationSeconds"), bit_rate=stream_dict.get("bitRate"), begin_stream_seconds_from_header=stream_dict.get( "beginStreamSecondsFromHeader" ), + begin_stream_seconds=stream_dict.get("beginStreamSeconds"), codec=stream_dict.get("codec"), stream_index=stream_index, ) @@ -276,6 +202,9 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: end_stream_seconds_from_content=stream_dict.get( "endStreamSecondsFromContent" ), + end_stream_seconds=stream_dict.get("endStreamSeconds"), + num_frames=stream_dict.get("numFrames"), + average_fps=stream_dict.get("averageFps"), width=stream_dict.get("width"), height=stream_dict.get("height"), num_frames_from_header=stream_dict.get("numFramesFromHeader"), diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index b78160f6c..14a9f899e 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -176,13 +176,13 @@ std::string mapToJson(const std::map& metadataMap) { return ss.str(); } -SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { +SeekMode seekModeFromString(std::string_view seekMode) { if (seekMode == "exact") { - return SingleStreamDecoder::SeekMode::exact; + return SeekMode::exact; } else if (seekMode == "approximate") { - return 
SingleStreamDecoder::SeekMode::approximate; + return SeekMode::approximate; } else if (seekMode == "custom_frame_mappings") { - return SingleStreamDecoder::SeekMode::custom_frame_mappings; + return SeekMode::custom_frame_mappings; } else { TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); } @@ -285,7 +285,7 @@ at::Tensor create_from_file( std::optional seek_mode = std::nullopt) { std::string filenameStr(filename); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -306,7 +306,7 @@ at::Tensor create_from_tensor( video_tensor.scalar_type() == torch::kUInt8, "video_tensor must be kUInt8"); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -329,7 +329,7 @@ at::Tensor _create_from_file_like( fileLikeContext != nullptr, "file_like_context must be a valid pointer"); std::unique_ptr avioContextHolder(fileLikeContext); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -796,6 +796,7 @@ std::string get_stream_json_metadata( } auto streamMetadata = allStreamMetadata[stream_index]; + auto seekMode = videoDecoder->getSeekMode(); std::map map; @@ -803,6 +804,10 @@ std::string get_stream_json_metadata( map["durationSecondsFromHeader"] = std::to_string(*streamMetadata.durationSecondsFromHeader); } + if (streamMetadata.getDurationSeconds(seekMode).has_value()) { + map["durationSeconds"] = + std::to_string(streamMetadata.getDurationSeconds(seekMode).value()); + } if (streamMetadata.bitRate.has_value()) { map["bitRate"] = std::to_string(*streamMetadata.bitRate); } @@ -814,6 +819,11 @@ std::string 
get_stream_json_metadata( map["numFramesFromHeader"] = std::to_string(*streamMetadata.numFramesFromHeader); } + if (streamMetadata.getNumFrames(seekMode).has_value()) { + map["numFrames"] = + std::to_string(streamMetadata.getNumFrames(seekMode).value()); + } + if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { map["beginStreamSecondsFromHeader"] = std::to_string(*streamMetadata.beginStreamSecondsFromHeader); @@ -822,10 +832,16 @@ std::string get_stream_json_metadata( map["beginStreamSecondsFromContent"] = std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); } + map["beginStreamSeconds"] = + std::to_string(streamMetadata.getBeginStreamSeconds(seekMode)); if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { map["endStreamSecondsFromContent"] = std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); } + if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) { + map["endStreamSeconds"] = + std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value()); + } if (streamMetadata.codecName.has_value()) { map["codec"] = quoteValue(streamMetadata.codecName.value()); } @@ -845,6 +861,10 @@ std::string get_stream_json_metadata( map["averageFpsFromHeader"] = std::to_string(*streamMetadata.averageFpsFromHeader); } + if (streamMetadata.getAverageFps(seekMode).has_value()) { + map["averageFps"] = + std::to_string(streamMetadata.getAverageFps(seekMode).value()); + } if (streamMetadata.sampleRate.has_value()) { map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); } diff --git a/test/test_decoders.py b/test/test_decoders.py index 5e5028da6..f3049d659 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -899,6 +899,11 @@ def test_get_frames_with_missing_num_frames_metadata( "mediaType": "video", "numFramesFromHeader": None, "numFramesFromContent": None, + "beginStreamSeconds": 0.0, + "durationSeconds": 13.013, + "endStreamSeconds": 13.013, + "numFrames": int(29.97003 * 13.013), # Calculated from fps * 
duration + "averageFps": 29.97003, } # Set the return value of the mock to be the mock_stream_dict mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) diff --git a/test/test_metadata.py b/test/test_metadata.py index 628b7a68d..a1db0c77b 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -147,123 +147,6 @@ def test_get_metadata_audio_file(metadata_getter): assert best_audio_stream_metadata.sample_format == "fltp" -@pytest.mark.parametrize( - "num_frames_from_header, num_frames_from_content, expected_num_frames", - [(10, 20, 20), (None, 10, 10), (10, None, 10)], -) -def test_num_frames_fallback( - num_frames_from_header, num_frames_from_content, expected_num_frames -): - """Check that num_frames_from_content always has priority when accessing `.num_frames`""" - metadata = VideoStreamMetadata( - duration_seconds_from_header=4, - bit_rate=123, - num_frames_from_header=num_frames_from_header, - num_frames_from_content=num_frames_from_content, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=0, - end_stream_seconds_from_content=4, - codec="whatever", - width=123, - height=321, - average_fps_from_header=30, - pixel_aspect_ratio=Fraction(1, 1), - stream_index=0, - ) - - assert metadata.num_frames == expected_num_frames - - -@pytest.mark.parametrize( - "average_fps_from_header, duration_seconds_from_header, expected_num_frames", - [(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)], -) -def test_calculate_num_frames_using_fps_and_duration( - average_fps_from_header, duration_seconds_from_header, expected_num_frames -): - """Check that if num_frames_from_content and num_frames_from_header are missing, - `.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header - """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=duration_seconds_from_header, - bit_rate=123, - num_frames_from_header=None, # None to test calculating num_frames - 
num_frames_from_content=None, # None to test calculating num_frames - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=0, - end_stream_seconds_from_content=4, - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=average_fps_from_header, - stream_index=0, - ) - - assert metadata.num_frames == expected_num_frames - - -@pytest.mark.parametrize( - "duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds", - [(60, 5, 20, 15), (60, 1, None, 60), (60, None, 1, 60), (None, 0, 10, 10)], -) -def test_duration_seconds_fallback( - duration_seconds_from_header, - begin_stream_seconds_from_content, - end_stream_seconds_from_content, - expected_duration_seconds, -): - """Check that using begin_stream_seconds_from_content and end_stream_seconds_from_content to calculate `.duration_seconds` - has priority. If either value is missing, duration_seconds_from_header is used. 
- """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=duration_seconds_from_header, - bit_rate=123, - num_frames_from_header=5, - num_frames_from_content=10, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=begin_stream_seconds_from_content, - end_stream_seconds_from_content=end_stream_seconds_from_content, - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=5, - stream_index=0, - ) - - assert metadata.duration_seconds == expected_duration_seconds - - -@pytest.mark.parametrize( - "num_frames_from_header, average_fps_from_header, expected_duration_seconds", - [(100, 10, 10), (100, None, None), (None, 10, None), (None, None, None)], -) -def test_calculate_duration_seconds_using_fps_and_num_frames( - num_frames_from_header, average_fps_from_header, expected_duration_seconds -): - """Check that duration_seconds is calculated using average_fps_from_header and num_frames_from_header - if duration_seconds_from_header is missing. - """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=None, # None to test calculating duration_seconds - bit_rate=123, - num_frames_from_header=num_frames_from_header, - num_frames_from_content=10, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=None, # None to test calculating duration_seconds - end_stream_seconds_from_content=None, # None to test calculating duration_seconds - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=average_fps_from_header, - stream_index=0, - ) - assert metadata.duration_seconds_from_header is None - assert metadata.duration_seconds == expected_duration_seconds - - def test_repr(): # Test for calls to print(), str(), etc. 
Useful to make sure we don't forget # to add additional @properties to __repr__ @@ -275,6 +158,8 @@ def test_repr(): bit_rate: 128783.0 codec: h264 stream_index: 3 + duration_seconds: 13.013 + begin_stream_seconds: 0.0 begin_stream_seconds_from_content: 0.0 end_stream_seconds_from_content: 13.013 width: 480 @@ -283,11 +168,9 @@ def test_repr(): num_frames_from_content: 390 average_fps_from_header: 29.97003 pixel_aspect_ratio: 1 - duration_seconds: 13.013 - begin_stream_seconds: 0.0 end_stream_seconds: 13.013 num_frames: 390 - average_fps: 29.97002997002997 + average_fps: 29.97003 """ ) ffmpeg_major_version = get_ffmpeg_major_version() @@ -303,6 +186,8 @@ def test_repr(): bit_rate: 64000.0 codec: mp3 stream_index: 0 + duration_seconds: 13.013 + begin_stream_seconds: 0.0 sample_rate: 8000 num_channels: 2 sample_format: fltp diff --git a/test/test_samplers.py b/test/test_samplers.py index 10c529062..24a57e2c6 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -595,6 +595,7 @@ def restore_metadata(): decoder.metadata.num_frames_from_header = ( None # Set to none to prevent fallback calculation ) + decoder.metadata.end_stream_seconds = None with pytest.raises( ValueError, match="Could not infer stream end from video metadata" ): @@ -603,6 +604,7 @@ def restore_metadata(): with restore_metadata(): decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.average_fps_from_header = None + decoder.metadata.average_fps = None with pytest.raises(ValueError, match="Could not infer average fps"): sampler(decoder) From 08c348ec8db3f136fa07fbe0ed8628525f26017d Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Tue, 4 Nov 2025 11:55:37 -0800 Subject: [PATCH 02/13] fix tests and move seekmode to metadata.h --- src/torchcodec/_core/Metadata.cpp | 23 +++++----------------- src/torchcodec/_core/Metadata.h | 4 ++-- src/torchcodec/_core/SeekMode.h | 13 ------------ src/torchcodec/_core/SingleStreamDecoder.h | 2 +- test/VideoDecoderTest.cpp | 4 ++-- 
test/test_metadata.py | 2 +- 6 files changed, 11 insertions(+), 37 deletions(-) delete mode 100644 src/torchcodec/_core/SeekMode.h diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index 41e41bd5b..c5462af90 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -14,12 +14,8 @@ std::optional StreamMetadata::getDurationSeconds( case SeekMode::custom_frame_mappings: case SeekMode::exact: // In exact mode, use the scanned content value - if (endStreamPtsSecondsFromContent.has_value() && - beginStreamPtsSecondsFromContent.has_value()) { - return endStreamPtsSecondsFromContent.value() - - beginStreamPtsSecondsFromContent.value(); - } - return std::nullopt; + return endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); case SeekMode::approximate: if (durationSecondsFromHeader.has_value()) { return durationSecondsFromHeader.value(); @@ -38,10 +34,7 @@ double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { switch (seekMode) { case SeekMode::custom_frame_mappings: case SeekMode::exact: - if (beginStreamPtsSecondsFromContent.has_value()) { - return beginStreamPtsSecondsFromContent.value(); - } - return 0.0; + return beginStreamPtsSecondsFromContent.value(); case SeekMode::approximate: return 0.0; } @@ -53,10 +46,7 @@ std::optional StreamMetadata::getEndStreamSeconds( switch (seekMode) { case SeekMode::custom_frame_mappings: case SeekMode::exact: - if (endStreamPtsSecondsFromContent.has_value()) { - return endStreamPtsSecondsFromContent.value(); - } - return getDurationSeconds(seekMode); + return endStreamPtsSecondsFromContent.value(); case SeekMode::approximate: return getDurationSeconds(seekMode); } @@ -67,10 +57,7 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { switch (seekMode) { case SeekMode::custom_frame_mappings: case SeekMode::exact: - if (numFramesFromContent.has_value()) { - return numFramesFromContent.value(); - } - return 
std::nullopt; + return numFramesFromContent.value(); case SeekMode::approximate: { if (numFramesFromHeader.has_value()) { return numFramesFromHeader.value(); diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h index 57da8c989..e138d5dc0 100644 --- a/src/torchcodec/_core/Metadata.h +++ b/src/torchcodec/_core/Metadata.h @@ -10,8 +10,6 @@ #include #include -#include "SeekMode.h" - extern "C" { #include #include @@ -20,6 +18,8 @@ extern "C" { namespace facebook::torchcodec { +enum class SeekMode { exact, approximate, custom_frame_mappings }; + struct StreamMetadata { // Common (video and audio) fields derived from the AVStream. int streamIndex; diff --git a/src/torchcodec/_core/SeekMode.h b/src/torchcodec/_core/SeekMode.h deleted file mode 100644 index fa30414ed..000000000 --- a/src/torchcodec/_core/SeekMode.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once - -namespace facebook::torchcodec { - -enum class SeekMode { exact, approximate, custom_frame_mappings }; - -} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index dd1966d7b..597750e0f 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -16,7 +16,7 @@ #include "DeviceInterface.h" #include "FFMPEGCommon.h" #include "Frame.h" -#include "SeekMode.h" +#include "Metadata.h" #include "StreamOptions.h" #include "Transform.h" diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 1481d3a2a..63abb3d76 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -64,10 +64,10 @@ class SingleStreamDecoderTest : public testing::TestWithParam { auto contextHolder = std::make_unique(tensor); return std::make_unique( - std::move(contextHolder), SingleStreamDecoder::SeekMode::approximate); + std::move(contextHolder), SeekMode::approximate); } else { return std::make_unique( - filepath, SingleStreamDecoder::SeekMode::approximate); + filepath, SeekMode::approximate); } } diff --git a/test/test_metadata.py b/test/test_metadata.py index a1db0c77b..a4f6da341 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -186,7 +186,7 @@ def test_repr(): bit_rate: 64000.0 codec: mp3 stream_index: 0 - duration_seconds: 13.013 + duration_seconds: {expected_duration_seconds_from_header} begin_stream_seconds: 0.0 sample_rate: 8000 num_channels: 2 From 868e23e51c629c67ce7c775b4830a7399f08bd29 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Tue, 4 Nov 2025 12:17:10 -0800 Subject: [PATCH 03/13] add checks --- src/torchcodec/_core/Metadata.cpp | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index c5462af90..41e41bd5b 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ 
b/src/torchcodec/_core/Metadata.cpp @@ -14,8 +14,12 @@ std::optional StreamMetadata::getDurationSeconds( case SeekMode::custom_frame_mappings: case SeekMode::exact: // In exact mode, use the scanned content value - return endStreamPtsSecondsFromContent.value() - - beginStreamPtsSecondsFromContent.value(); + if (endStreamPtsSecondsFromContent.has_value() && + beginStreamPtsSecondsFromContent.has_value()) { + return endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + } + return std::nullopt; case SeekMode::approximate: if (durationSecondsFromHeader.has_value()) { return durationSecondsFromHeader.value(); @@ -34,7 +38,10 @@ double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { switch (seekMode) { case SeekMode::custom_frame_mappings: case SeekMode::exact: - return beginStreamPtsSecondsFromContent.value(); + if (beginStreamPtsSecondsFromContent.has_value()) { + return beginStreamPtsSecondsFromContent.value(); + } + return 0.0; case SeekMode::approximate: return 0.0; } @@ -46,7 +53,10 @@ std::optional StreamMetadata::getEndStreamSeconds( switch (seekMode) { case SeekMode::custom_frame_mappings: case SeekMode::exact: - return endStreamPtsSecondsFromContent.value(); + if (endStreamPtsSecondsFromContent.has_value()) { + return endStreamPtsSecondsFromContent.value(); + } + return getDurationSeconds(seekMode); case SeekMode::approximate: return getDurationSeconds(seekMode); } @@ -57,7 +67,10 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { switch (seekMode) { case SeekMode::custom_frame_mappings: case SeekMode::exact: - return numFramesFromContent.value(); + if (numFramesFromContent.has_value()) { + return numFramesFromContent.value(); + } + return std::nullopt; case SeekMode::approximate: { if (numFramesFromHeader.has_value()) { return numFramesFromHeader.value(); From e58cf39bae3fdb700820eac4a067eb22229786e3 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Tue, 4 Nov 2025 15:25:45 -0800 Subject: 
[PATCH 04/13] add cpp tests --- test/CMakeLists.txt | 19 ++++ test/MetadataTest.cpp | 212 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 test/MetadataTest.cpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 988def933..f96c06223 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,3 +34,22 @@ target_link_libraries( include(GoogleTest) gtest_discover_tests(VideoDecoderTest) + + +add_executable( + MetadataTest + MetadataTest.cpp +) + +target_include_directories(MetadataTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) +target_include_directories(MetadataTest SYSTEM PRIVATE ${libav_include_dirs}) +target_include_directories(MetadataTest PRIVATE ../) + +target_link_libraries( + MetadataTest + ${libtorchcodec_library_name} + ${libtorchcodec_custom_ops_name} + GTest::gtest_main +) + +gtest_discover_tests(MetadataTest) diff --git a/test/MetadataTest.cpp b/test/MetadataTest.cpp new file mode 100644 index 000000000..dbd2a3497 --- /dev/null +++ b/test/MetadataTest.cpp @@ -0,0 +1,212 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "src/torchcodec/_core/Metadata.h" + +#include + +namespace facebook::torchcodec { + +// Test that num_frames_from_content always has priority when accessing +// getNumFrames() +TEST(MetadataTest, NumFramesFallbackPriority) { + // in exact mode, both header and content available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = 10; + metadata.numFramesFromContent = 20; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::exact), 20); + } + + // in exact mode, only content available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = 10; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::exact), 10); + } + + // in approximate mode, header should be used + { + StreamMetadata metadata; + metadata.numFramesFromHeader = 10; + metadata.numFramesFromContent = std::nullopt; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), 10); + } +} + +// Test that if num_frames_from_content and num_frames_from_header are missing, +// getNumFrames() is calculated using average_fps_from_header and +// duration_seconds_from_header in approximate mode +TEST(MetadataTest, CalculateNumFramesUsingFpsAndDuration) { + // both fps and duration available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = 60.0; + metadata.durationSecondsFromHeader = 10.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), 600); + } + + // fps available but duration missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = 60.0; + 
metadata.durationSecondsFromHeader = std::nullopt; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } + + // duration available but fps missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.durationSecondsFromHeader = 10.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } + + // both missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.durationSecondsFromHeader = std::nullopt; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } +} + +// Test that using begin_stream_seconds_from_content and +// end_stream_seconds_from_content to calculate getDurationSeconds() has +// priority. If either value is missing, duration_seconds_from_header is used. 
+TEST(MetadataTest, DurationSecondsFallback) { + // in exact mode, both begin and end content available, should calculate from + // them + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 5.0; + metadata.endStreamPtsSecondsFromContent = 20.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::exact).value(), 15.0, 1e-6); + } + + // in exact mode, begin content available but end missing, should fall back to + // header + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 1.0; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::exact), std::nullopt); + } + + // Test case 3: end content available but begin missing, should fall back + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = 1.0; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::exact), std::nullopt); + } + + // in exact mode, only content values, no header + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = 0.0; + metadata.endStreamPtsSecondsFromContent = 10.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::exact).value(), 10.0, 1e-6); + } + + // in approximate mode, header value takes priority (ignores content) + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 5.0; + metadata.endStreamPtsSecondsFromContent = 20.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::approximate).value(), 60.0, 1e-6); + } +} + +// Test that duration_seconds is calculated using average_fps_from_header and +// num_frames_from_header if duration_seconds_from_header is missing. 
+TEST(MetadataTest, CalculateDurationSecondsUsingFpsAndNumFrames) { + // in approximate mode, both num_frames and fps available + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = 100; + metadata.averageFpsFromHeader = 10.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::approximate).value(), 10.0, 1e-6); + } + + // in approximate mode, num_frames available but fps missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = 100; + metadata.averageFpsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } + + // in approximate mode, fps available but num_frames missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = std::nullopt; + metadata.averageFpsFromHeader = 10.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } + + // in approximate mode, both missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } +} + +} // namespace facebook::torchcodec From 35febfba420f5619f4096378fd046cf3debaf3b1 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 07:59:40 -0800 Subject: [PATCH 05/13] use 
streametadata methods in singlestreamdecoder --- src/torchcodec/_core/SingleStreamDecoder.cpp | 32 ++------------------ 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 6778331f9..1da481ce4 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1445,43 +1445,17 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { std::optional SingleStreamDecoder::getNumFrames( const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.numFramesFromContent.value(); - case SeekMode::approximate: { - return streamMetadata.numFramesFromHeader; - } - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } + return streamMetadata.getNumFrames(seekMode_); } double SingleStreamDecoder::getMinSeconds( const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.beginStreamPtsSecondsFromContent.value(); - case SeekMode::approximate: - return 0; - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } + return streamMetadata.getBeginStreamSeconds(seekMode_); } std::optional SingleStreamDecoder::getMaxSeconds( const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.endStreamPtsSecondsFromContent.value(); - case SeekMode::approximate: { - return streamMetadata.durationSecondsFromHeader; - } - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } + return streamMetadata.getEndStreamSeconds(seekMode_); } // -------------------------------------------------------------------------- From 20a7227d2548ce32ef2b2e672b2e36250055c2c3 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 10:21:05 -0800 Subject: [PATCH 06/13] address feedback --- 
src/torchcodec/_core/SingleStreamDecoder.cpp | 31 ++++++-------------- src/torchcodec/_core/SingleStreamDecoder.h | 4 --- src/torchcodec/_core/_metadata.py | 14 ++++----- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 1da481ce4..348c200a8 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -615,7 +615,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { // If the frameIndex is negative, we convert it to a positive index frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value(); @@ -709,7 +709,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( // Note that if we do not have the number of frames available in our // metadata, then we assume that the upper part of the range is valid. - std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { TORCH_CHECK( stop <= numFrames.value(), @@ -783,8 +783,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - double minSeconds = getMinSeconds(streamMetadata); - std::optional maxSeconds = getMaxSeconds(streamMetadata); + double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); + std::optional maxSeconds = + streamMetadata.getEndStreamSeconds(seekMode_); // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. 
The easiest way to @@ -861,7 +862,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( return frameBatchOutput; } - double minSeconds = getMinSeconds(streamMetadata); + double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); TORCH_CHECK( startSeconds >= minSeconds, "Start seconds is " + std::to_string(startSeconds) + @@ -870,7 +871,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // Note that if we can't determine the maximum seconds from the metadata, // then we assume upper range is valid. - std::optional maxSeconds = getMaxSeconds(streamMetadata); + std::optional maxSeconds = + streamMetadata.getEndStreamSeconds(seekMode_); if (maxSeconds.has_value()) { TORCH_CHECK( startSeconds < maxSeconds.value(), @@ -1443,21 +1445,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { // STREAM AND METADATA APIS // -------------------------------------------------------------------------- -std::optional SingleStreamDecoder::getNumFrames( - const StreamMetadata& streamMetadata) { - return streamMetadata.getNumFrames(seekMode_); -} - -double SingleStreamDecoder::getMinSeconds( - const StreamMetadata& streamMetadata) { - return streamMetadata.getBeginStreamSeconds(seekMode_); -} - -std::optional SingleStreamDecoder::getMaxSeconds( - const StreamMetadata& streamMetadata) { - return streamMetadata.getEndStreamSeconds(seekMode_); -} - // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- @@ -1507,7 +1494,7 @@ void SingleStreamDecoder::validateFrameIndex( // Note that if we do not have the number of frames available in our // metadata, then we assume that the frameIndex is valid. 
- std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { if (frameIndex >= numFrames.value()) { throw std::out_of_range( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 597750e0f..9ff504c1c 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -314,10 +314,6 @@ class SingleStreamDecoder { // index. Note that this index may be truncated for some files. int getBestStreamIndex(AVMediaType mediaType); - std::optional getNumFrames(const StreamMetadata& streamMetadata); - double getMinSeconds(const StreamMetadata& streamMetadata); - std::optional getMaxSeconds(const StreamMetadata& streamMetadata); - // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 3f40da80f..6ac9da44a 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -40,11 +40,10 @@ class StreamMetadata: # Computed fields (computed in C++ with fallback logic) duration_seconds: Optional[float] - """Duration of the stream in seconds. Computed in C++ with fallback logic: - tries to calculate from content if scan was performed, otherwise falls back - to header values.""" + """Duration of the stream in seconds. Tries to calculate from content + if :term:`scan` was performed, otherwise falls back to header values.""" begin_stream_seconds: Optional[float] - """Beginning of the stream, in seconds. 
Computed in C++ with fallback logic.""" + """Beginning of the stream, in seconds.""" def __repr__(self): s = self.__class__.__name__ + ":\n" @@ -98,15 +97,14 @@ class VideoStreamMetadata(StreamMetadata): # Computed fields (computed in C++ with fallback logic) end_stream_seconds: Optional[float] """End of the stream, in seconds (float or None). - Conceptually, this corresponds to last_frame.pts + last_frame.duration. - Computed in C++ with fallback logic.""" + Conceptually, this corresponds to last_frame.pts + last_frame.duration.""" num_frames: Optional[int] """Number of frames in the stream (int or None). - Computed in C++ with fallback logic: uses content if scan was performed, + Uses content if :term:`scan` was performed, otherwise falls back to header values or calculates from duration and fps.""" average_fps: Optional[float] """Average fps of the stream (float or None). - Computed in C++ with fallback logic: if scan was performed, computes from + if :term:`scan` was performed, computes from num_frames and duration, otherwise uses header value.""" From b677863ff6b6996374a8db422fb2ce810c9c3e08 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 11:13:01 -0800 Subject: [PATCH 07/13] address feedback --- src/torchcodec/_core/_metadata.py | 5 ++- test/test_decoders.py | 57 ------------------------------- 2 files changed, 4 insertions(+), 58 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 6ac9da44a..3ce29c4e6 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -52,7 +52,7 @@ def __repr__(self): return s -@dataclass(repr=False) +@dataclass class VideoStreamMetadata(StreamMetadata): """Metadata of a single video stream.""" @@ -107,6 +107,9 @@ class VideoStreamMetadata(StreamMetadata): if :term:`scan` was performed, computes from num_frames and duration, otherwise uses header value.""" + def __repr__(self): + return super().__repr__() + @dataclass class 
AudioStreamMetadata(StreamMetadata): diff --git a/test/test_decoders.py b/test/test_decoders.py index f3049d659..e9f3f701b 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -6,9 +6,7 @@ import contextlib import gc -import json from functools import partial -from unittest.mock import patch import numpy import pytest @@ -877,61 +875,6 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): ).to(device) assert_frames_equal(frames387_None.data, reference_frame387_389) - @pytest.mark.parametrize("device", all_supported_devices()) - @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) - @patch("torchcodec._core._metadata._get_stream_json_metadata") - def test_get_frames_with_missing_num_frames_metadata( - self, mock_get_stream_json_metadata, device, seek_mode - ): - # Create a mock stream_dict to test that initializing VideoDecoder without - # num_frames_from_header and num_frames_from_content calculates num_frames - # using the average_fps and duration_seconds metadata. 
- mock_stream_dict = { - "averageFpsFromHeader": 29.97003, - "beginStreamSecondsFromContent": 0.0, - "beginStreamSecondsFromHeader": 0.0, - "bitRate": 128783.0, - "codec": "h264", - "durationSecondsFromHeader": 13.013, - "endStreamSecondsFromContent": 13.013, - "width": 480, - "height": 270, - "mediaType": "video", - "numFramesFromHeader": None, - "numFramesFromContent": None, - "beginStreamSeconds": 0.0, - "durationSeconds": 13.013, - "endStreamSeconds": 13.013, - "numFrames": int(29.97003 * 13.013), # Calculated from fps * duration - "averageFps": 29.97003, - } - # Set the return value of the mock to be the mock_stream_dict - mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) - - decoder, device = make_video_decoder( - NASA_VIDEO.path, - stream_index=3, - device=device, - seek_mode=seek_mode, - ) - - assert decoder.metadata.num_frames_from_header is None - assert decoder.metadata.num_frames_from_content is None - assert decoder.metadata.duration_seconds is not None - assert decoder.metadata.average_fps is not None - assert decoder.metadata.num_frames == int( - decoder.metadata.duration_seconds * decoder.metadata.average_fps - ) - assert len(decoder) == 390 - - # Test get_frames_in_range Python logic which uses the num_frames metadata mocked earlier. - # The frame is read at the C++ level. 
- ref_frames9 = NASA_VIDEO.get_frame_data_by_range( - start=9, stop=10, stream_index=3 - ).to(device) - frames9 = decoder.get_frames_in_range(start=9, stop=10) - assert_frames_equal(ref_frames9, frames9.data) - @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) @pytest.mark.parametrize( "frame_getter", From 1478117f1d5938af6daf038888ec5236f09f4ecd Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 11:20:26 -0800 Subject: [PATCH 08/13] modify docstrings --- src/torchcodec/_core/_metadata.py | 32 ++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 3ce29c4e6..08bcf2b55 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -40,10 +40,18 @@ class StreamMetadata: # Computed fields (computed in C++ with fallback logic) duration_seconds: Optional[float] - """Duration of the stream in seconds. Tries to calculate from content - if :term:`scan` was performed, otherwise falls back to header values.""" + """Duration of the stream in seconds. We try to calculate the duration + from the actual frames if a :term:`scan` was performed. Otherwise we + fall back to ``duration_seconds_from_header``. If that value is also None, + we instead calculate the duration from ``num_frames_from_header`` and + ``average_fps_from_header``. + """ begin_stream_seconds: Optional[float] - """Beginning of the stream, in seconds.""" + """Beginning of the stream, in seconds (float). Conceptually, this + corresponds to the first frame's :term:`pts`. If a :term:`scan` was performed + and ``begin_stream_seconds_from_content`` is not None, then it is returned. + Otherwise, this value is 0. 
+ """ def __repr__(self): s = self.__class__.__name__ + ":\n" @@ -97,15 +105,21 @@ class VideoStreamMetadata(StreamMetadata): # Computed fields (computed in C++ with fallback logic) end_stream_seconds: Optional[float] """End of the stream, in seconds (float or None). - Conceptually, this corresponds to last_frame.pts + last_frame.duration.""" + Conceptually, this corresponds to last_frame.pts + last_frame.duration. + If :term:`scan` was performed and ``end_stream_seconds_from_content`` is not None, then that value is + returned. Otherwise, returns ``duration_seconds``. + """ num_frames: Optional[int] """Number of frames in the stream (int or None). - Uses content if :term:`scan` was performed, - otherwise falls back to header values or calculates from duration and fps.""" + This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, + otherwise it corresponds to ``num_frames_from_header``. If that value is also + None, the number of frames is calculated from the duration and the average fps. + """ average_fps: Optional[float] """Average fps of the stream (float or None). - if :term:`scan` was performed, computes from - num_frames and duration, otherwise uses header value.""" + """Average fps of the stream. If a :term:`scan` was performed, this is + computed from the number of frames and the duration of the stream. + Otherwise we fall back to ``average_fps_from_header``.
+ """ def __repr__(self): return super().__repr__() From 3494331549b04e14cb9033c4e0e8961940c1ca39 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 12:40:39 -0800 Subject: [PATCH 09/13] modified fallback logic --- src/torchcodec/_core/Metadata.cpp | 51 +++++++++++++++++++------------ 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index 41e41bd5b..1b890379a 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -5,15 +5,17 @@ // LICENSE file in the root directory of this source tree. #include "Metadata.h" +#include "torch/types.h" namespace facebook::torchcodec { std::optional StreamMetadata::getDurationSeconds( SeekMode seekMode) const { switch (seekMode) { - case SeekMode::custom_frame_mappings: case SeekMode::exact: - // In exact mode, use the scanned content value + return endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + case SeekMode::custom_frame_mappings: if (endStreamPtsSecondsFromContent.has_value() && beginStreamPtsSecondsFromContent.has_value()) { return endStreamPtsSecondsFromContent.value() - @@ -30,48 +32,51 @@ std::optional StreamMetadata::getDurationSeconds( averageFpsFromHeader.value(); } return std::nullopt; + default: + TORCH_CHECK(false, "Unknown SeekMode"); } - return std::nullopt; } double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { switch (seekMode) { - case SeekMode::custom_frame_mappings: case SeekMode::exact: + return beginStreamPtsSecondsFromContent.value(); + case SeekMode::custom_frame_mappings: + case SeekMode::approximate: if (beginStreamPtsSecondsFromContent.has_value()) { return beginStreamPtsSecondsFromContent.value(); } return 0.0; - case SeekMode::approximate: - return 0.0; + default: + TORCH_CHECK(false, "Unknown SeekMode"); } - return 0.0; } std::optional StreamMetadata::getEndStreamSeconds( SeekMode seekMode) const { switch (seekMode) { - 
case SeekMode::custom_frame_mappings: case SeekMode::exact: + return endStreamPtsSecondsFromContent.value(); + case SeekMode::custom_frame_mappings: + case SeekMode::approximate: if (endStreamPtsSecondsFromContent.has_value()) { return endStreamPtsSecondsFromContent.value(); } return getDurationSeconds(seekMode); - case SeekMode::approximate: - return getDurationSeconds(seekMode); + default: + TORCH_CHECK(false, "Unknown SeekMode"); } - return std::nullopt; } std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { switch (seekMode) { - case SeekMode::custom_frame_mappings: case SeekMode::exact: + return numFramesFromContent.value(); + case SeekMode::custom_frame_mappings: + case SeekMode::approximate: { if (numFramesFromContent.has_value()) { return numFramesFromContent.value(); } - return std::nullopt; - case SeekMode::approximate: { if (numFramesFromHeader.has_value()) { return numFramesFromHeader.value(); } @@ -82,14 +87,23 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { } return std::nullopt; } + default: + TORCH_CHECK(false, "Unknown SeekMode"); } - return std::nullopt; } std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { switch (seekMode) { - case SeekMode::custom_frame_mappings: case SeekMode::exact: + if (endStreamPtsSecondsFromContent.value() != + beginStreamPtsSecondsFromContent.value()) { + return static_cast( + getNumFrames(seekMode).value() / + (endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value())); + } + case SeekMode::custom_frame_mappings: + case SeekMode::approximate: if (getNumFrames(seekMode).has_value() && beginStreamPtsSecondsFromContent.has_value() && endStreamPtsSecondsFromContent.has_value() && @@ -101,10 +115,9 @@ std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { beginStreamPtsSecondsFromContent.value())); } return averageFpsFromHeader; - case SeekMode::approximate: - return averageFpsFromHeader; + default: + TORCH_CHECK(false, 
"Unknown SeekMode"); } - return std::nullopt; } } // namespace facebook::torchcodec From cdabcb073aea5ac99f357732e0eadca46e1a031b Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 12:54:40 -0800 Subject: [PATCH 10/13] fix fallthrough error --- src/torchcodec/_core/Metadata.cpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index 1b890379a..48096ca14 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -94,16 +94,8 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { switch (seekMode) { - case SeekMode::exact: - if (endStreamPtsSecondsFromContent.value() != - beginStreamPtsSecondsFromContent.value()) { - return static_cast( - getNumFrames(seekMode).value() / - (endStreamPtsSecondsFromContent.value() - - beginStreamPtsSecondsFromContent.value())); - } case SeekMode::custom_frame_mappings: - case SeekMode::approximate: + case SeekMode::exact: if (getNumFrames(seekMode).has_value() && beginStreamPtsSecondsFromContent.has_value() && endStreamPtsSecondsFromContent.has_value() && @@ -115,6 +107,8 @@ std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { beginStreamPtsSecondsFromContent.value())); } return averageFpsFromHeader; + case SeekMode::approximate: + return averageFpsFromHeader; default: TORCH_CHECK(false, "Unknown SeekMode"); } From 8bf490cd5e7999f611ced21290c46887df237a4d Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 5 Nov 2025 15:01:47 -0800 Subject: [PATCH 11/13] modify fallback logic and add torch_checks --- src/torchcodec/_core/Metadata.cpp | 26 ++- src/torchcodec/_core/SingleStreamDecoder.cpp | 4 + src/torchcodec/_core/SingleStreamDecoder.h | 3 + src/torchcodec/_core/custom_ops.cpp | 227 ++++++++++++------- test/MetadataTest.cpp | 21 -- 5 files changed, 171 insertions(+), 110 deletions(-) diff --git 
a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index 48096ca14..b0e9425bd 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -12,16 +12,14 @@ namespace facebook::torchcodec { std::optional StreamMetadata::getDurationSeconds( SeekMode seekMode) const { switch (seekMode) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: + TORCH_CHECK( + endStreamPtsSecondsFromContent.has_value() && + beginStreamPtsSecondsFromContent.has_value(), + "Missing beginStreamPtsSecondsFromContent or endStreamPtsSecondsFromContent"); return endStreamPtsSecondsFromContent.value() - beginStreamPtsSecondsFromContent.value(); - case SeekMode::custom_frame_mappings: - if (endStreamPtsSecondsFromContent.has_value() && - beginStreamPtsSecondsFromContent.has_value()) { - return endStreamPtsSecondsFromContent.value() - - beginStreamPtsSecondsFromContent.value(); - } - return std::nullopt; case SeekMode::approximate: if (durationSecondsFromHeader.has_value()) { return durationSecondsFromHeader.value(); @@ -39,9 +37,12 @@ std::optional StreamMetadata::getDurationSeconds( double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { switch (seekMode) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: + TORCH_CHECK( + beginStreamPtsSecondsFromContent.has_value(), + "Missing beginStreamPtsSecondsFromContent"); return beginStreamPtsSecondsFromContent.value(); - case SeekMode::custom_frame_mappings: case SeekMode::approximate: if (beginStreamPtsSecondsFromContent.has_value()) { return beginStreamPtsSecondsFromContent.value(); @@ -55,9 +56,12 @@ double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { std::optional StreamMetadata::getEndStreamSeconds( SeekMode seekMode) const { switch (seekMode) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: + TORCH_CHECK( + endStreamPtsSecondsFromContent.has_value(), + "Missing endStreamPtsSecondsFromContent"); return 
endStreamPtsSecondsFromContent.value(); - case SeekMode::custom_frame_mappings: case SeekMode::approximate: if (endStreamPtsSecondsFromContent.has_value()) { return endStreamPtsSecondsFromContent.value(); @@ -70,9 +74,11 @@ std::optional StreamMetadata::getEndStreamSeconds( std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { switch (seekMode) { + case SeekMode::custom_frame_mappings: case SeekMode::exact: + TORCH_CHECK( + numFramesFromContent.has_value(), "Missing numFramesFromContent"); return numFramesFromContent.value(); - case SeekMode::custom_frame_mappings: case SeekMode::approximate: { if (numFramesFromContent.has_value()) { return numFramesFromContent.value(); diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 348c200a8..e76ede771 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -371,6 +371,10 @@ SeekMode SingleStreamDecoder::getSeekMode() const { return seekMode_; } +int SingleStreamDecoder::getActiveStreamIndex() const { + return activeStreamIndex_; +} + torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { validateActiveStream(AVMEDIA_TYPE_VIDEO); validateScannedAllStreams("getKeyFrameIndices"); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 9ff504c1c..77ac548e1 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -62,6 +62,9 @@ class SingleStreamDecoder { // Returns the seek mode of this decoder. SeekMode getSeekMode() const; + // Returns the active stream index. Returns -2 if no stream is active. + int getActiveStreamIndex() const; + // Returns the key frame indices as a tensor. The tensor is 1D and contains // int64 values, where each value is the frame index for a key frame. 
torch::Tensor getKeyFrameIndices(); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 14a9f899e..436970152 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -797,91 +797,160 @@ std::string get_stream_json_metadata( auto streamMetadata = allStreamMetadata[stream_index]; auto seekMode = videoDecoder->getSeekMode(); + int activeStreamIndex = videoDecoder->getActiveStreamIndex(); std::map map; - if (streamMetadata.durationSecondsFromHeader.has_value()) { - map["durationSecondsFromHeader"] = - std::to_string(*streamMetadata.durationSecondsFromHeader); - } - if (streamMetadata.getDurationSeconds(seekMode).has_value()) { - map["durationSeconds"] = - std::to_string(streamMetadata.getDurationSeconds(seekMode).value()); - } - if (streamMetadata.bitRate.has_value()) { - map["bitRate"] = std::to_string(*streamMetadata.bitRate); - } - if (streamMetadata.numFramesFromContent.has_value()) { - map["numFramesFromContent"] = - std::to_string(*streamMetadata.numFramesFromContent); - } - if (streamMetadata.numFramesFromHeader.has_value()) { - map["numFramesFromHeader"] = - std::to_string(*streamMetadata.numFramesFromHeader); - } - if (streamMetadata.getNumFrames(seekMode).has_value()) { - map["numFrames"] = - std::to_string(streamMetadata.getNumFrames(seekMode).value()); - } + if ((seekMode == SeekMode::custom_frame_mappings && + stream_index == activeStreamIndex) || + (seekMode != SeekMode::custom_frame_mappings)) { + if (streamMetadata.durationSecondsFromHeader.has_value()) { + map["durationSecondsFromHeader"] = + std::to_string(*streamMetadata.durationSecondsFromHeader); + } + if (streamMetadata.getDurationSeconds(seekMode).has_value()) { + map["durationSeconds"] = + std::to_string(streamMetadata.getDurationSeconds(seekMode).value()); + } + if (streamMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*streamMetadata.bitRate); + } + if 
(streamMetadata.numFramesFromContent.has_value()) { + map["numFramesFromContent"] = + std::to_string(*streamMetadata.numFramesFromContent); + } + if (streamMetadata.numFramesFromHeader.has_value()) { + map["numFramesFromHeader"] = + std::to_string(*streamMetadata.numFramesFromHeader); + } + if (streamMetadata.getNumFrames(seekMode).has_value()) { + map["numFrames"] = + std::to_string(streamMetadata.getNumFrames(seekMode).value()); + } - if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { - map["beginStreamSecondsFromHeader"] = - std::to_string(*streamMetadata.beginStreamSecondsFromHeader); - } - if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { - map["beginStreamSecondsFromContent"] = - std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); - } - map["beginStreamSeconds"] = - std::to_string(streamMetadata.getBeginStreamSeconds(seekMode)); - if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { - map["endStreamSecondsFromContent"] = - std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); - } - if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) { - map["endStreamSeconds"] = - std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value()); - } - if (streamMetadata.codecName.has_value()) { - map["codec"] = quoteValue(streamMetadata.codecName.value()); - } - if (streamMetadata.width.has_value()) { - map["width"] = std::to_string(*streamMetadata.width); - } - if (streamMetadata.height.has_value()) { - map["height"] = std::to_string(*streamMetadata.height); - } - if (streamMetadata.sampleAspectRatio.has_value()) { - map["sampleAspectRatioNum"] = - std::to_string((*streamMetadata.sampleAspectRatio).num); - map["sampleAspectRatioDen"] = - std::to_string((*streamMetadata.sampleAspectRatio).den); - } - if (streamMetadata.averageFpsFromHeader.has_value()) { - map["averageFpsFromHeader"] = - std::to_string(*streamMetadata.averageFpsFromHeader); - } - if 
(streamMetadata.getAverageFps(seekMode).has_value()) { - map["averageFps"] = - std::to_string(streamMetadata.getAverageFps(seekMode).value()); - } - if (streamMetadata.sampleRate.has_value()) { - map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); - } - if (streamMetadata.numChannels.has_value()) { - map["numChannels"] = std::to_string(*streamMetadata.numChannels); - } - if (streamMetadata.sampleFormat.has_value()) { - map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); - } - if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { - map["mediaType"] = quoteValue("video"); - } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { - map["mediaType"] = quoteValue("audio"); + if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { + map["beginStreamSecondsFromHeader"] = + std::to_string(*streamMetadata.beginStreamSecondsFromHeader); + } + if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { + map["beginStreamSecondsFromContent"] = + std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); + } + map["beginStreamSeconds"] = + std::to_string(streamMetadata.getBeginStreamSeconds(seekMode)); + if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { + map["endStreamSecondsFromContent"] = + std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); + } + if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) { + map["endStreamSeconds"] = + std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value()); + } + if (streamMetadata.codecName.has_value()) { + map["codec"] = quoteValue(streamMetadata.codecName.value()); + } + if (streamMetadata.width.has_value()) { + map["width"] = std::to_string(*streamMetadata.width); + } + if (streamMetadata.height.has_value()) { + map["height"] = std::to_string(*streamMetadata.height); + } + if (streamMetadata.sampleAspectRatio.has_value()) { + map["sampleAspectRatioNum"] = + std::to_string((*streamMetadata.sampleAspectRatio).num); + 
map["sampleAspectRatioDen"] = + std::to_string((*streamMetadata.sampleAspectRatio).den); + } + if (streamMetadata.averageFpsFromHeader.has_value()) { + map["averageFpsFromHeader"] = + std::to_string(*streamMetadata.averageFpsFromHeader); + } + if (streamMetadata.getAverageFps(seekMode).has_value()) { + map["averageFps"] = + std::to_string(streamMetadata.getAverageFps(seekMode).value()); + } + if (streamMetadata.sampleRate.has_value()) { + map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); + } + if (streamMetadata.numChannels.has_value()) { + map["numChannels"] = std::to_string(*streamMetadata.numChannels); + } + if (streamMetadata.sampleFormat.has_value()) { + map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); + } + if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { + map["mediaType"] = quoteValue("video"); + } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { + map["mediaType"] = quoteValue("audio"); + } else { + map["mediaType"] = quoteValue("other"); + } + return mapToJson(map); } else { - map["mediaType"] = quoteValue("other"); + if (streamMetadata.durationSecondsFromHeader.has_value()) { + map["durationSecondsFromHeader"] = + std::to_string(*streamMetadata.durationSecondsFromHeader); + } + if (streamMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*streamMetadata.bitRate); + } + if (streamMetadata.numFramesFromContent.has_value()) { + map["numFramesFromContent"] = + std::to_string(*streamMetadata.numFramesFromContent); + } + if (streamMetadata.numFramesFromHeader.has_value()) { + map["numFramesFromHeader"] = + std::to_string(*streamMetadata.numFramesFromHeader); + } + if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { + map["beginStreamSecondsFromHeader"] = + std::to_string(*streamMetadata.beginStreamSecondsFromHeader); + } + if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { + map["beginStreamSecondsFromContent"] = + 
std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); + } + if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { + map["endStreamSecondsFromContent"] = + std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); + } + if (streamMetadata.codecName.has_value()) { + map["codec"] = quoteValue(streamMetadata.codecName.value()); + } + if (streamMetadata.width.has_value()) { + map["width"] = std::to_string(*streamMetadata.width); + } + if (streamMetadata.height.has_value()) { + map["height"] = std::to_string(*streamMetadata.height); + } + if (streamMetadata.sampleAspectRatio.has_value()) { + map["sampleAspectRatioNum"] = + std::to_string((*streamMetadata.sampleAspectRatio).num); + map["sampleAspectRatioDen"] = + std::to_string((*streamMetadata.sampleAspectRatio).den); + } + if (streamMetadata.averageFpsFromHeader.has_value()) { + map["averageFpsFromHeader"] = + std::to_string(*streamMetadata.averageFpsFromHeader); + } + if (streamMetadata.sampleRate.has_value()) { + map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); + } + if (streamMetadata.numChannels.has_value()) { + map["numChannels"] = std::to_string(*streamMetadata.numChannels); + } + if (streamMetadata.sampleFormat.has_value()) { + map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); + } + if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { + map["mediaType"] = quoteValue("video"); + } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { + map["mediaType"] = quoteValue("audio"); + } else { + map["mediaType"] = quoteValue("other"); + } + return mapToJson(map); } - return mapToJson(map); } // Returns version information about the various FFMPEG libraries that are diff --git a/test/MetadataTest.cpp b/test/MetadataTest.cpp index dbd2a3497..3a65b6dc5 100644 --- a/test/MetadataTest.cpp +++ b/test/MetadataTest.cpp @@ -112,27 +112,6 @@ TEST(MetadataTest, DurationSecondsFallback) { metadata.getDurationSeconds(SeekMode::exact).value(), 15.0, 1e-6); } - // in 
exact mode, begin content available but end missing, should fall back to - // header - { - StreamMetadata metadata; - metadata.durationSecondsFromHeader = 60.0; - metadata.beginStreamPtsSecondsFromContent = 1.0; - metadata.endStreamPtsSecondsFromContent = std::nullopt; - - EXPECT_EQ(metadata.getDurationSeconds(SeekMode::exact), std::nullopt); - } - - // Test case 3: end content available but begin missing, should fall back - { - StreamMetadata metadata; - metadata.durationSecondsFromHeader = 60.0; - metadata.beginStreamPtsSecondsFromContent = std::nullopt; - metadata.endStreamPtsSecondsFromContent = 1.0; - - EXPECT_EQ(metadata.getDurationSeconds(SeekMode::exact), std::nullopt); - } - // in exact mode, only content values, no header { StreamMetadata metadata; From e3b91ea2a3394a87f9b95831019567ce7c64681c Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Thu, 6 Nov 2025 13:32:03 -0800 Subject: [PATCH 12/13] address feedback --- src/torchcodec/_core/Metadata.cpp | 3 --- src/torchcodec/_core/custom_ops.cpp | 14 +++++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index b0e9425bd..bfb6b3d31 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -80,9 +80,6 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { numFramesFromContent.has_value(), "Missing numFramesFromContent"); return numFramesFromContent.value(); case SeekMode::approximate: { - if (numFramesFromContent.has_value()) { - return numFramesFromContent.value(); - } if (numFramesFromHeader.has_value()) { return numFramesFromHeader.value(); } diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 436970152..fbadc8997 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -801,9 +801,17 @@ std::string get_stream_json_metadata( std::map map; - if ((seekMode == SeekMode::custom_frame_mappings && - 
stream_index == activeStreamIndex) || - (seekMode != SeekMode::custom_frame_mappings)) { + // Check whether content-based metadata is available for this stream. + // In exact mode: content-based metadata exists for all streams. + // In approximate mode: content-based metadata does not exist for any stream. + // In custom_frame_mappings: content-based metadata exists only for the active + // stream. + // Our fallback logic assumes content-based metadata is available. + // It is available for decoding on the active stream, but would break + // when getting metadata from non-active streams. + if ((seekMode != SeekMode::custom_frame_mappings) || + (seekMode == SeekMode::custom_frame_mappings && + stream_index == activeStreamIndex)) { if (streamMetadata.durationSecondsFromHeader.has_value()) { map["durationSecondsFromHeader"] = std::to_string(*streamMetadata.durationSecondsFromHeader); From 71f486107c37e7d7bfca1927fa98049ed44d30e1 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Fri, 7 Nov 2025 11:01:46 -0800 Subject: [PATCH 13/13] address feedback --- src/torchcodec/_core/Metadata.cpp | 19 +-- src/torchcodec/_core/custom_ops.cpp | 193 ++++++++++------------------ 2 files changed, 75 insertions(+), 137 deletions(-) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index bfb6b3d31..58a115dcf 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -98,18 +98,19 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { switch (seekMode) { case SeekMode::custom_frame_mappings: - case SeekMode::exact: - if (getNumFrames(seekMode).has_value() && + case SeekMode::exact: { + auto numFrames = getNumFrames(seekMode); + if (numFrames.has_value() && beginStreamPtsSecondsFromContent.has_value() && - endStreamPtsSecondsFromContent.has_value() && - (beginStreamPtsSecondsFromContent.value() != - endStreamPtsSecondsFromContent.value())) { 
- return static_cast( - getNumFrames(seekMode).value() / - (endStreamPtsSecondsFromContent.value() - - beginStreamPtsSecondsFromContent.value())); + endStreamPtsSecondsFromContent.has_value()) { + double duration = endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + if (duration != 0.0) { + return static_cast(numFrames.value()) / duration; + } } return averageFpsFromHeader; + } case SeekMode::approximate: return averageFpsFromHeader; default: diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index fbadc8997..2044ac2d8 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -801,6 +801,69 @@ std::string get_stream_json_metadata( std::map map; + if (streamMetadata.durationSecondsFromHeader.has_value()) { + map["durationSecondsFromHeader"] = + std::to_string(*streamMetadata.durationSecondsFromHeader); + } + if (streamMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*streamMetadata.bitRate); + } + if (streamMetadata.numFramesFromContent.has_value()) { + map["numFramesFromContent"] = + std::to_string(*streamMetadata.numFramesFromContent); + } + if (streamMetadata.numFramesFromHeader.has_value()) { + map["numFramesFromHeader"] = + std::to_string(*streamMetadata.numFramesFromHeader); + } + if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { + map["beginStreamSecondsFromHeader"] = + std::to_string(*streamMetadata.beginStreamSecondsFromHeader); + } + if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { + map["beginStreamSecondsFromContent"] = + std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); + } + if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { + map["endStreamSecondsFromContent"] = + std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); + } + if (streamMetadata.codecName.has_value()) { + map["codec"] = quoteValue(streamMetadata.codecName.value()); + } + if 
(streamMetadata.width.has_value()) { + map["width"] = std::to_string(*streamMetadata.width); + } + if (streamMetadata.height.has_value()) { + map["height"] = std::to_string(*streamMetadata.height); + } + if (streamMetadata.sampleAspectRatio.has_value()) { + map["sampleAspectRatioNum"] = + std::to_string((*streamMetadata.sampleAspectRatio).num); + map["sampleAspectRatioDen"] = + std::to_string((*streamMetadata.sampleAspectRatio).den); + } + if (streamMetadata.averageFpsFromHeader.has_value()) { + map["averageFpsFromHeader"] = + std::to_string(*streamMetadata.averageFpsFromHeader); + } + if (streamMetadata.sampleRate.has_value()) { + map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); + } + if (streamMetadata.numChannels.has_value()) { + map["numChannels"] = std::to_string(*streamMetadata.numChannels); + } + if (streamMetadata.sampleFormat.has_value()) { + map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); + } + if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { + map["mediaType"] = quoteValue("video"); + } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { + map["mediaType"] = quoteValue("audio"); + } else { + map["mediaType"] = quoteValue("other"); + } + // Check whether content-based metadata is available for this stream. // In exact mode: content-based metadata exists for all streams. // In approximate mode: content-based metadata does not exist for any stream. 
@@ -812,153 +875,27 @@ std::string get_stream_json_metadata( if ((seekMode != SeekMode::custom_frame_mappings) || (seekMode == SeekMode::custom_frame_mappings && stream_index == activeStreamIndex)) { - if (streamMetadata.durationSecondsFromHeader.has_value()) { - map["durationSecondsFromHeader"] = - std::to_string(*streamMetadata.durationSecondsFromHeader); - } if (streamMetadata.getDurationSeconds(seekMode).has_value()) { map["durationSeconds"] = std::to_string(streamMetadata.getDurationSeconds(seekMode).value()); } - if (streamMetadata.bitRate.has_value()) { - map["bitRate"] = std::to_string(*streamMetadata.bitRate); - } - if (streamMetadata.numFramesFromContent.has_value()) { - map["numFramesFromContent"] = - std::to_string(*streamMetadata.numFramesFromContent); - } - if (streamMetadata.numFramesFromHeader.has_value()) { - map["numFramesFromHeader"] = - std::to_string(*streamMetadata.numFramesFromHeader); - } if (streamMetadata.getNumFrames(seekMode).has_value()) { map["numFrames"] = std::to_string(streamMetadata.getNumFrames(seekMode).value()); } - - if (streamMetadata.beginStreamSecondsFromHeader.has_value()) { - map["beginStreamSecondsFromHeader"] = - std::to_string(*streamMetadata.beginStreamSecondsFromHeader); - } - if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { - map["beginStreamSecondsFromContent"] = - std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); - } map["beginStreamSeconds"] = std::to_string(streamMetadata.getBeginStreamSeconds(seekMode)); - if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { - map["endStreamSecondsFromContent"] = - std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); - } if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) { map["endStreamSeconds"] = std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value()); } - if (streamMetadata.codecName.has_value()) { - map["codec"] = quoteValue(streamMetadata.codecName.value()); - } - if 
(streamMetadata.width.has_value()) { - map["width"] = std::to_string(*streamMetadata.width); - } - if (streamMetadata.height.has_value()) { - map["height"] = std::to_string(*streamMetadata.height); - } - if (streamMetadata.sampleAspectRatio.has_value()) { - map["sampleAspectRatioNum"] = - std::to_string((*streamMetadata.sampleAspectRatio).num); - map["sampleAspectRatioDen"] = - std::to_string((*streamMetadata.sampleAspectRatio).den); - } - if (streamMetadata.averageFpsFromHeader.has_value()) { - map["averageFpsFromHeader"] = - std::to_string(*streamMetadata.averageFpsFromHeader); - } if (streamMetadata.getAverageFps(seekMode).has_value()) { map["averageFps"] = std::to_string(streamMetadata.getAverageFps(seekMode).value()); } - if (streamMetadata.sampleRate.has_value()) { - map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); - } - if (streamMetadata.numChannels.has_value()) { - map["numChannels"] = std::to_string(*streamMetadata.numChannels); - } - if (streamMetadata.sampleFormat.has_value()) { - map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); - } - if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { - map["mediaType"] = quoteValue("video"); - } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { - map["mediaType"] = quoteValue("audio"); - } else { - map["mediaType"] = quoteValue("other"); - } - return mapToJson(map); - } else { - if (streamMetadata.durationSecondsFromHeader.has_value()) { - map["durationSecondsFromHeader"] = - std::to_string(*streamMetadata.durationSecondsFromHeader); - } - if (streamMetadata.bitRate.has_value()) { - map["bitRate"] = std::to_string(*streamMetadata.bitRate); - } - if (streamMetadata.numFramesFromContent.has_value()) { - map["numFramesFromContent"] = - std::to_string(*streamMetadata.numFramesFromContent); - } - if (streamMetadata.numFramesFromHeader.has_value()) { - map["numFramesFromHeader"] = - std::to_string(*streamMetadata.numFramesFromHeader); - } - if 
(streamMetadata.beginStreamSecondsFromHeader.has_value()) { - map["beginStreamSecondsFromHeader"] = - std::to_string(*streamMetadata.beginStreamSecondsFromHeader); - } - if (streamMetadata.beginStreamPtsSecondsFromContent.has_value()) { - map["beginStreamSecondsFromContent"] = - std::to_string(*streamMetadata.beginStreamPtsSecondsFromContent); - } - if (streamMetadata.endStreamPtsSecondsFromContent.has_value()) { - map["endStreamSecondsFromContent"] = - std::to_string(*streamMetadata.endStreamPtsSecondsFromContent); - } - if (streamMetadata.codecName.has_value()) { - map["codec"] = quoteValue(streamMetadata.codecName.value()); - } - if (streamMetadata.width.has_value()) { - map["width"] = std::to_string(*streamMetadata.width); - } - if (streamMetadata.height.has_value()) { - map["height"] = std::to_string(*streamMetadata.height); - } - if (streamMetadata.sampleAspectRatio.has_value()) { - map["sampleAspectRatioNum"] = - std::to_string((*streamMetadata.sampleAspectRatio).num); - map["sampleAspectRatioDen"] = - std::to_string((*streamMetadata.sampleAspectRatio).den); - } - if (streamMetadata.averageFpsFromHeader.has_value()) { - map["averageFpsFromHeader"] = - std::to_string(*streamMetadata.averageFpsFromHeader); - } - if (streamMetadata.sampleRate.has_value()) { - map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); - } - if (streamMetadata.numChannels.has_value()) { - map["numChannels"] = std::to_string(*streamMetadata.numChannels); - } - if (streamMetadata.sampleFormat.has_value()) { - map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); - } - if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { - map["mediaType"] = quoteValue("video"); - } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { - map["mediaType"] = quoteValue("audio"); - } else { - map["mediaType"] = quoteValue("other"); - } - return mapToJson(map); } + + return mapToJson(map); } // Returns version information about the various FFMPEG libraries that are