diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 6b4ccb5d4..825331840 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -96,6 +96,7 @@ function(make_torchcodec_libraries Encoder.cpp ValidationUtils.cpp Transform.cpp + Metadata.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp new file mode 100644 index 000000000..58a115dcf --- /dev/null +++ b/src/torchcodec/_core/Metadata.cpp @@ -0,0 +1,121 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "Metadata.h" +#include "torch/types.h" + +namespace facebook::torchcodec { + +std::optional StreamMetadata::getDurationSeconds( + SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + TORCH_CHECK( + endStreamPtsSecondsFromContent.has_value() && + beginStreamPtsSecondsFromContent.has_value(), + "Missing beginStreamPtsSecondsFromContent or endStreamPtsSecondsFromContent"); + return endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + case SeekMode::approximate: + if (durationSecondsFromHeader.has_value()) { + return durationSecondsFromHeader.value(); + } + if (numFramesFromHeader.has_value() && averageFpsFromHeader.has_value() && + averageFpsFromHeader.value() != 0.0) { + return static_cast(numFramesFromHeader.value()) / + averageFpsFromHeader.value(); + } + return std::nullopt; + default: + TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + TORCH_CHECK( + beginStreamPtsSecondsFromContent.has_value(), + "Missing beginStreamPtsSecondsFromContent"); + return beginStreamPtsSecondsFromContent.value(); + case SeekMode::approximate: + if (beginStreamPtsSecondsFromContent.has_value()) { + return beginStreamPtsSecondsFromContent.value(); + } + return 0.0; + default: + TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getEndStreamSeconds( + SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + TORCH_CHECK( + endStreamPtsSecondsFromContent.has_value(), + "Missing endStreamPtsSecondsFromContent"); + return endStreamPtsSecondsFromContent.value(); + case SeekMode::approximate: + if (endStreamPtsSecondsFromContent.has_value()) { + return endStreamPtsSecondsFromContent.value(); + } + return getDurationSeconds(seekMode); + default: + TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: + TORCH_CHECK( + numFramesFromContent.has_value(), "Missing numFramesFromContent"); + return numFramesFromContent.value(); + case SeekMode::approximate: { + if (numFramesFromHeader.has_value()) { + return numFramesFromHeader.value(); + } + if (averageFpsFromHeader.has_value() && + durationSecondsFromHeader.has_value()) { + return static_cast( + averageFpsFromHeader.value() * durationSecondsFromHeader.value()); + } + return std::nullopt; + } + default: + TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +std::optional StreamMetadata::getAverageFps(SeekMode seekMode) const { + switch (seekMode) { + case SeekMode::custom_frame_mappings: + case SeekMode::exact: { + auto numFrames = getNumFrames(seekMode); + if (numFrames.has_value() && + beginStreamPtsSecondsFromContent.has_value() && + endStreamPtsSecondsFromContent.has_value()) { + double duration = endStreamPtsSecondsFromContent.value() - + beginStreamPtsSecondsFromContent.value(); + if (duration != 0.0) { + return static_cast(numFrames.value()) / duration; + } + } + return averageFpsFromHeader; + } + case SeekMode::approximate: + return averageFpsFromHeader; + default: + TORCH_CHECK(false, "Unknown SeekMode"); + } +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h index ace6cf84c..e138d5dc0 100644 --- a/src/torchcodec/_core/Metadata.h +++ b/src/torchcodec/_core/Metadata.h @@ -18,6 +18,8 @@ extern "C" { namespace facebook::torchcodec { +enum class SeekMode { exact, approximate, custom_frame_mappings }; + struct StreamMetadata { // Common (video and audio) fields derived from the AVStream. int streamIndex; @@ -52,6 +54,13 @@ struct StreamMetadata { std::optional sampleRate; std::optional numChannels; std::optional sampleFormat; + + // Computed methods with fallback logic + std::optional getDurationSeconds(SeekMode seekMode) const; + double getBeginStreamSeconds(SeekMode seekMode) const; + std::optional getEndStreamSeconds(SeekMode seekMode) const; + std::optional getNumFrames(SeekMode seekMode) const; + std::optional getAverageFps(SeekMode seekMode) const; }; struct ContainerMetadata { diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 72cd7afac..e76ede771 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -367,6 +367,14 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { return containerMetadata_; } +SeekMode SingleStreamDecoder::getSeekMode() const { + return seekMode_; +} + +int SingleStreamDecoder::getActiveStreamIndex() const { + return activeStreamIndex_; +} + torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { validateActiveStream(AVMEDIA_TYPE_VIDEO); validateScannedAllStreams("getKeyFrameIndices"); @@ -611,7 +619,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { // If the frameIndex is negative, we convert it to a positive index frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value(); @@ -705,7 +713,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( // Note that if we do not have the number of frames available in our // metadata, then we assume that the upper part of the range is valid. - std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { TORCH_CHECK( stop <= numFrames.value(), @@ -779,8 +787,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - double minSeconds = getMinSeconds(streamMetadata); - std::optional maxSeconds = getMaxSeconds(streamMetadata); + double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); + std::optional maxSeconds = + streamMetadata.getEndStreamSeconds(seekMode_); // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to @@ -857,7 +866,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( return frameBatchOutput; } - double minSeconds = getMinSeconds(streamMetadata); + double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_); TORCH_CHECK( startSeconds >= minSeconds, "Start seconds is " + std::to_string(startSeconds) + @@ -866,7 +875,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // Note that if we can't determine the maximum seconds from the metadata, // then we assume upper range is valid. - std::optional maxSeconds = getMaxSeconds(streamMetadata); + std::optional maxSeconds = + streamMetadata.getEndStreamSeconds(seekMode_); if (maxSeconds.has_value()) { TORCH_CHECK( startSeconds < maxSeconds.value(), @@ -1439,47 +1449,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { // STREAM AND METADATA APIS // -------------------------------------------------------------------------- -std::optional SingleStreamDecoder::getNumFrames( - const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.numFramesFromContent.value(); - case SeekMode::approximate: { - return streamMetadata.numFramesFromHeader; - } - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } -} - -double SingleStreamDecoder::getMinSeconds( - const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.beginStreamPtsSecondsFromContent.value(); - case SeekMode::approximate: - return 0; - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } -} - -std::optional SingleStreamDecoder::getMaxSeconds( - const StreamMetadata& streamMetadata) { - switch (seekMode_) { - case SeekMode::custom_frame_mappings: - case SeekMode::exact: - return streamMetadata.endStreamPtsSecondsFromContent.value(); - case SeekMode::approximate: { - return streamMetadata.durationSecondsFromHeader; - } - default: - TORCH_CHECK(false, "Unknown SeekMode"); - } -} - // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- @@ -1529,7 +1498,7 @@ void SingleStreamDecoder::validateFrameIndex( // Note that if we do not have the number of frames available in our // metadata, then we assume that the frameIndex is valid. - std::optional numFrames = getNumFrames(streamMetadata); + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); if (numFrames.has_value()) { if (frameIndex >= numFrames.value()) { throw std::out_of_range( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 4b41811ff..77ac548e1 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -16,6 +16,7 @@ #include "DeviceInterface.h" #include "FFMPEGCommon.h" #include "Frame.h" +#include "Metadata.h" #include "StreamOptions.h" #include "Transform.h" @@ -30,8 +31,6 @@ class SingleStreamDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- - enum class SeekMode { exact, approximate, custom_frame_mappings }; - // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( const std::string& videoFilePath, @@ -60,6 +59,12 @@ class SingleStreamDecoder { // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; + // Returns the seek mode of this decoder. + SeekMode getSeekMode() const; + + // Returns the active stream index. Returns -2 if no stream is active. + int getActiveStreamIndex() const; + // Returns the key frame indices as a tensor. The tensor is 1D and contains // int64 values, where each value is the frame index for a key frame. torch::Tensor getKeyFrameIndices(); @@ -312,10 +317,6 @@ class SingleStreamDecoder { // index. Note that this index may be truncated for some files. int getBestStreamIndex(AVMediaType mediaType); - std::optional getNumFrames(const StreamMetadata& streamMetadata); - double getMinSeconds(const StreamMetadata& streamMetadata); - std::optional getMaxSeconds(const StreamMetadata& streamMetadata); - // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 1f011f516..08bcf2b55 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -38,6 +38,21 @@ class StreamMetadata: stream_index: int """Index of the stream that this metadata refers to (int).""" + # Computed fields (computed in C++ with fallback logic) + duration_seconds: Optional[float] + """Duration of the stream in seconds. We try to calculate the duration + from the actual frames if a :term:`scan` was performed. Otherwise we + fall back to ``duration_seconds_from_header``. If that value is also None, + we instead calculate the duration from ``num_frames_from_header`` and + ``average_fps_from_header``. + """ + begin_stream_seconds: Optional[float] + """Beginning of the stream, in seconds (float). Conceptually, this + corresponds to the first frame's :term:`pts`. If a :term:`scan` was performed + and ``begin_stream_seconds_from_content`` is not None, then it is returned. + Otherwise, this value is 0. + """ + def __repr__(self): s = self.__class__.__name__ + ":\n" for field in dataclasses.fields(self): @@ -87,103 +102,27 @@ class VideoStreamMetadata(StreamMetadata): is the ratio between the width and height of each pixel (``fractions.Fraction`` or None).""" - @property - def duration_seconds(self) -> Optional[float]: - """Duration of the stream in seconds. We try to calculate the duration - from the actual frames if a :term:`scan` was performed. Otherwise we - fall back to ``duration_seconds_from_header``. If that value is also None, - we instead calculate the duration from ``num_frames_from_header`` and - ``average_fps_from_header``. - """ - if ( - self.end_stream_seconds_from_content is not None - and self.begin_stream_seconds_from_content is not None - ): - return ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) - elif self.duration_seconds_from_header is not None: - return self.duration_seconds_from_header - elif ( - self.num_frames_from_header is not None - and self.average_fps_from_header is not None - ): - return self.num_frames_from_header / self.average_fps_from_header - else: - return None - - @property - def begin_stream_seconds(self) -> float: - """Beginning of the stream, in seconds (float). Conceptually, this - corresponds to the first frame's :term:`pts`. If - ``begin_stream_seconds_from_content`` is not None, then it is returned. - Otherwise, this value is 0. - """ - if self.begin_stream_seconds_from_content is None: - return 0 - else: - return self.begin_stream_seconds_from_content - - @property - def end_stream_seconds(self) -> Optional[float]: - """End of the stream, in seconds (float or None). - Conceptually, this corresponds to last_frame.pts + last_frame.duration. - If ``end_stream_seconds_from_content`` is not None, then that value is - returned. Otherwise, returns ``duration_seconds``. - """ - if self.end_stream_seconds_from_content is None: - return self.duration_seconds - else: - return self.end_stream_seconds_from_content - - @property - def num_frames(self) -> Optional[int]: - """Number of frames in the stream (int or None). - This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, - otherwise it corresponds to ``num_frames_from_header``. If that value is also - None, the number of frames is calculated from the duration and the average fps. - """ - if self.num_frames_from_content is not None: - return self.num_frames_from_content - elif self.num_frames_from_header is not None: - return self.num_frames_from_header - elif ( - self.average_fps_from_header is not None - and self.duration_seconds_from_header is not None - ): - return int(self.average_fps_from_header * self.duration_seconds_from_header) - else: - return None - - @property - def average_fps(self) -> Optional[float]: - """Average fps of the stream. If a :term:`scan` was perfomed, this is - computed from the number of frames and the duration of the stream. - Otherwise we fall back to ``average_fps_from_header``. - """ - if ( - self.end_stream_seconds_from_content is None - or self.begin_stream_seconds_from_content is None - or self.num_frames is None - # Should never happen, but prevents ZeroDivisionError: - or self.end_stream_seconds_from_content - == self.begin_stream_seconds_from_content - ): - return self.average_fps_from_header - return self.num_frames / ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) + # Computed fields (computed in C++ with fallback logic) + end_stream_seconds: Optional[float] + """End of the stream, in seconds (float or None). + Conceptually, this corresponds to last_frame.pts + last_frame.duration. + If :term:`scan` was performed and``end_stream_seconds_from_content`` is not None, then that value is + returned. Otherwise, returns ``duration_seconds``. + """ + num_frames: Optional[int] + """Number of frames in the stream (int or None). + This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, + otherwise it corresponds to ``num_frames_from_header``. If that value is also + None, the number of frames is calculated from the duration and the average fps. + """ + average_fps: Optional[float] + """Average fps of the stream. If a :term:`scan` was perfomed, this is + computed from the number of frames and the duration of the stream. + Otherwise we fall back to ``average_fps_from_header``. + """ def __repr__(self): - s = super().__repr__() - s += f"{SPACES}duration_seconds: {self.duration_seconds}\n" - s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n" - s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n" - s += f"{SPACES}num_frames: {self.num_frames}\n" - s += f"{SPACES}average_fps: {self.average_fps}\n" - return s + return super().__repr__() @dataclass @@ -260,10 +199,12 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index)) common_meta = dict( duration_seconds_from_header=stream_dict.get("durationSecondsFromHeader"), + duration_seconds=stream_dict.get("durationSeconds"), bit_rate=stream_dict.get("bitRate"), begin_stream_seconds_from_header=stream_dict.get( "beginStreamSecondsFromHeader" ), + begin_stream_seconds=stream_dict.get("beginStreamSeconds"), codec=stream_dict.get("codec"), stream_index=stream_index, ) @@ -276,6 +217,9 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: end_stream_seconds_from_content=stream_dict.get( "endStreamSecondsFromContent" ), + end_stream_seconds=stream_dict.get("endStreamSeconds"), + num_frames=stream_dict.get("numFrames"), + average_fps=stream_dict.get("averageFps"), width=stream_dict.get("width"), height=stream_dict.get("height"), num_frames_from_header=stream_dict.get("numFramesFromHeader"), diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index b78160f6c..2044ac2d8 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -176,13 +176,13 @@ std::string mapToJson(const std::map& metadataMap) { return ss.str(); } -SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { +SeekMode seekModeFromString(std::string_view seekMode) { if (seekMode == "exact") { - return SingleStreamDecoder::SeekMode::exact; + return SeekMode::exact; } else if (seekMode == "approximate") { - return SingleStreamDecoder::SeekMode::approximate; + return SeekMode::approximate; } else if (seekMode == "custom_frame_mappings") { - return SingleStreamDecoder::SeekMode::custom_frame_mappings; + return SeekMode::custom_frame_mappings; } else { TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); } @@ -285,7 +285,7 @@ at::Tensor create_from_file( std::optional seek_mode = std::nullopt) { std::string filenameStr(filename); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -306,7 +306,7 @@ at::Tensor create_from_tensor( video_tensor.scalar_type() == torch::kUInt8, "video_tensor must be kUInt8"); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -329,7 +329,7 @@ at::Tensor _create_from_file_like( fileLikeContext != nullptr, "file_like_context must be a valid pointer"); std::unique_ptr avioContextHolder(fileLikeContext); - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SeekMode realSeek = SeekMode::exact; if (seek_mode.has_value()) { realSeek = seekModeFromString(seek_mode.value()); } @@ -796,6 +796,8 @@ std::string get_stream_json_metadata( } auto streamMetadata = allStreamMetadata[stream_index]; + auto seekMode = videoDecoder->getSeekMode(); + int activeStreamIndex = videoDecoder->getActiveStreamIndex(); std::map map; @@ -861,6 +863,38 @@ std::string get_stream_json_metadata( } else { map["mediaType"] = quoteValue("other"); } + + // Check whether content-based metadata is available for this stream. + // In exact mode: content-based metadata exists for all streams. + // In approximate mode: content-based metadata does not exist for any stream. + // In custom_frame_mappings: content-based metadata exists only for the active + // stream. + // Our fallback logic assumes content-based metadata is available. + // It is available for decoding on the active stream, but would break + // when getting metadata from non-active streams. + if ((seekMode != SeekMode::custom_frame_mappings) || + (seekMode == SeekMode::custom_frame_mappings && + stream_index == activeStreamIndex)) { + if (streamMetadata.getDurationSeconds(seekMode).has_value()) { + map["durationSeconds"] = + std::to_string(streamMetadata.getDurationSeconds(seekMode).value()); + } + if (streamMetadata.getNumFrames(seekMode).has_value()) { + map["numFrames"] = + std::to_string(streamMetadata.getNumFrames(seekMode).value()); + } + map["beginStreamSeconds"] = + std::to_string(streamMetadata.getBeginStreamSeconds(seekMode)); + if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) { + map["endStreamSeconds"] = + std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value()); + } + if (streamMetadata.getAverageFps(seekMode).has_value()) { + map["averageFps"] = + std::to_string(streamMetadata.getAverageFps(seekMode).value()); + } + } + return mapToJson(map); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 988def933..f96c06223 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,3 +34,22 @@ target_link_libraries( include(GoogleTest) gtest_discover_tests(VideoDecoderTest) + + +add_executable( + MetadataTest + MetadataTest.cpp +) + +target_include_directories(MetadataTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) +target_include_directories(MetadataTest SYSTEM PRIVATE ${libav_include_dirs}) +target_include_directories(MetadataTest PRIVATE ../) + +target_link_libraries( + MetadataTest + ${libtorchcodec_library_name} + ${libtorchcodec_custom_ops_name} + GTest::gtest_main +) + +gtest_discover_tests(MetadataTest) diff --git a/test/MetadataTest.cpp b/test/MetadataTest.cpp new file mode 100644 index 000000000..3a65b6dc5 --- /dev/null +++ b/test/MetadataTest.cpp @@ -0,0 +1,191 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "src/torchcodec/_core/Metadata.h" + +#include + +namespace facebook::torchcodec { + +// Test that num_frames_from_content always has priority when accessing +// getNumFrames() +TEST(MetadataTest, NumFramesFallbackPriority) { + // in exact mode, both header and content available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = 10; + metadata.numFramesFromContent = 20; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::exact), 20); + } + + // in exact mode, only content available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = 10; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::exact), 10); + } + + // in approximate mode, header should be used + { + StreamMetadata metadata; + metadata.numFramesFromHeader = 10; + metadata.numFramesFromContent = std::nullopt; + metadata.durationSecondsFromHeader = 4.0; + metadata.averageFpsFromHeader = 30.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), 10); + } +} + +// Test that if num_frames_from_content and num_frames_from_header are missing, +// getNumFrames() is calculated using average_fps_from_header and +// duration_seconds_from_header in approximate mode +TEST(MetadataTest, CalculateNumFramesUsingFpsAndDuration) { + // both fps and duration available + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = 60.0; + metadata.durationSecondsFromHeader = 10.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), 600); + } + + // fps available but duration missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = 60.0; + metadata.durationSecondsFromHeader = std::nullopt; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } + + // duration available but fps missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.durationSecondsFromHeader = 10.0; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } + + // both missing + { + StreamMetadata metadata; + metadata.numFramesFromHeader = std::nullopt; + metadata.numFramesFromContent = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.durationSecondsFromHeader = std::nullopt; + + EXPECT_EQ(metadata.getNumFrames(SeekMode::approximate), std::nullopt); + } +} + +// Test that using begin_stream_seconds_from_content and +// end_stream_seconds_from_content to calculate getDurationSeconds() has +// priority. If either value is missing, duration_seconds_from_header is used. +TEST(MetadataTest, DurationSecondsFallback) { + // in exact mode, both begin and end content available, should calculate from + // them + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 5.0; + metadata.endStreamPtsSecondsFromContent = 20.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::exact).value(), 15.0, 1e-6); + } + + // in exact mode, only content values, no header + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = 0.0; + metadata.endStreamPtsSecondsFromContent = 10.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::exact).value(), 10.0, 1e-6); + } + + // in approximate mode, header value takes priority (ignores content) + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = 60.0; + metadata.beginStreamPtsSecondsFromContent = 5.0; + metadata.endStreamPtsSecondsFromContent = 20.0; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::approximate).value(), 60.0, 1e-6); + } +} + +// Test that duration_seconds is calculated using average_fps_from_header and +// num_frames_from_header if duration_seconds_from_header is missing. +TEST(MetadataTest, CalculateDurationSecondsUsingFpsAndNumFrames) { + // in approximate mode, both num_frames and fps available + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = 100; + metadata.averageFpsFromHeader = 10.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_NEAR( + metadata.getDurationSeconds(SeekMode::approximate).value(), 10.0, 1e-6); + } + + // in approximate mode, num_frames available but fps missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = 100; + metadata.averageFpsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } + + // in approximate mode, fps available but num_frames missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = std::nullopt; + metadata.averageFpsFromHeader = 10.0; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } + + // in approximate mode, both missing + { + StreamMetadata metadata; + metadata.durationSecondsFromHeader = std::nullopt; + metadata.numFramesFromHeader = std::nullopt; + metadata.averageFpsFromHeader = std::nullopt; + metadata.beginStreamPtsSecondsFromContent = std::nullopt; + metadata.endStreamPtsSecondsFromContent = std::nullopt; + + EXPECT_EQ(metadata.getDurationSeconds(SeekMode::approximate), std::nullopt); + } +} + +} // namespace facebook::torchcodec diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 1481d3a2a..63abb3d76 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -64,10 +64,10 @@ class SingleStreamDecoderTest : public testing::TestWithParam { auto contextHolder = std::make_unique(tensor); return std::make_unique( - std::move(contextHolder), SingleStreamDecoder::SeekMode::approximate); + std::move(contextHolder), SeekMode::approximate); } else { return std::make_unique( - filepath, SingleStreamDecoder::SeekMode::approximate); + filepath, SeekMode::approximate); } } diff --git a/test/test_decoders.py b/test/test_decoders.py index 5e5028da6..e9f3f701b 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -6,9 +6,7 @@ import contextlib import gc -import json from functools import partial -from unittest.mock import patch import numpy import pytest @@ -877,56 +875,6 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): ).to(device) assert_frames_equal(frames387_None.data, reference_frame387_389) - @pytest.mark.parametrize("device", all_supported_devices()) - @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) - @patch("torchcodec._core._metadata._get_stream_json_metadata") - def test_get_frames_with_missing_num_frames_metadata( - self, mock_get_stream_json_metadata, device, seek_mode - ): - # Create a mock stream_dict to test that initializing VideoDecoder without - # num_frames_from_header and num_frames_from_content calculates num_frames - # using the average_fps and duration_seconds metadata. - mock_stream_dict = { - "averageFpsFromHeader": 29.97003, - "beginStreamSecondsFromContent": 0.0, - "beginStreamSecondsFromHeader": 0.0, - "bitRate": 128783.0, - "codec": "h264", - "durationSecondsFromHeader": 13.013, - "endStreamSecondsFromContent": 13.013, - "width": 480, - "height": 270, - "mediaType": "video", - "numFramesFromHeader": None, - "numFramesFromContent": None, - } - # Set the return value of the mock to be the mock_stream_dict - mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) - - decoder, device = make_video_decoder( - NASA_VIDEO.path, - stream_index=3, - device=device, - seek_mode=seek_mode, - ) - - assert decoder.metadata.num_frames_from_header is None - assert decoder.metadata.num_frames_from_content is None - assert decoder.metadata.duration_seconds is not None - assert decoder.metadata.average_fps is not None - assert decoder.metadata.num_frames == int( - decoder.metadata.duration_seconds * decoder.metadata.average_fps - ) - assert len(decoder) == 390 - - # Test get_frames_in_range Python logic which uses the num_frames metadata mocked earlier. - # The frame is read at the C++ level. - ref_frames9 = NASA_VIDEO.get_frame_data_by_range( - start=9, stop=10, stream_index=3 - ).to(device) - frames9 = decoder.get_frames_in_range(start=9, stop=10) - assert_frames_equal(ref_frames9, frames9.data) - @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) @pytest.mark.parametrize( "frame_getter", diff --git a/test/test_metadata.py b/test/test_metadata.py index 628b7a68d..a4f6da341 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -147,123 +147,6 @@ def test_get_metadata_audio_file(metadata_getter): assert best_audio_stream_metadata.sample_format == "fltp" -@pytest.mark.parametrize( - "num_frames_from_header, num_frames_from_content, expected_num_frames", - [(10, 20, 20), (None, 10, 10), (10, None, 10)], -) -def test_num_frames_fallback( - num_frames_from_header, num_frames_from_content, expected_num_frames -): - """Check that num_frames_from_content always has priority when accessing `.num_frames`""" - metadata = VideoStreamMetadata( - duration_seconds_from_header=4, - bit_rate=123, - num_frames_from_header=num_frames_from_header, - num_frames_from_content=num_frames_from_content, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=0, - end_stream_seconds_from_content=4, - codec="whatever", - width=123, - height=321, - average_fps_from_header=30, - pixel_aspect_ratio=Fraction(1, 1), - stream_index=0, - ) - - assert metadata.num_frames == expected_num_frames - - -@pytest.mark.parametrize( - "average_fps_from_header, duration_seconds_from_header, expected_num_frames", - [(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)], -) -def test_calculate_num_frames_using_fps_and_duration( - average_fps_from_header, duration_seconds_from_header, expected_num_frames -): - """Check that if num_frames_from_content and num_frames_from_header are missing, - `.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header - """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=duration_seconds_from_header, - bit_rate=123, - num_frames_from_header=None, # None to test calculating num_frames - num_frames_from_content=None, # None to test calculating num_frames - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=0, - end_stream_seconds_from_content=4, - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=average_fps_from_header, - stream_index=0, - ) - - assert metadata.num_frames == expected_num_frames - - -@pytest.mark.parametrize( - "duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds", - [(60, 5, 20, 15), (60, 1, None, 60), (60, None, 1, 60), (None, 0, 10, 10)], -) -def test_duration_seconds_fallback( - duration_seconds_from_header, - begin_stream_seconds_from_content, - end_stream_seconds_from_content, - expected_duration_seconds, -): - """Check that using begin_stream_seconds_from_content and end_stream_seconds_from_content to calculate `.duration_seconds` - has priority. If either value is missing, duration_seconds_from_header is used. - """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=duration_seconds_from_header, - bit_rate=123, - num_frames_from_header=5, - num_frames_from_content=10, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=begin_stream_seconds_from_content, - end_stream_seconds_from_content=end_stream_seconds_from_content, - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=5, - stream_index=0, - ) - - assert metadata.duration_seconds == expected_duration_seconds - - -@pytest.mark.parametrize( - "num_frames_from_header, average_fps_from_header, expected_duration_seconds", - [(100, 10, 10), (100, None, None), (None, 10, None), (None, None, None)], -) -def test_calculate_duration_seconds_using_fps_and_num_frames( - num_frames_from_header, average_fps_from_header, expected_duration_seconds -): - """Check that duration_seconds is calculated using average_fps_from_header and num_frames_from_header - if duration_seconds_from_header is missing. - """ - metadata = VideoStreamMetadata( - duration_seconds_from_header=None, # None to test calculating duration_seconds - bit_rate=123, - num_frames_from_header=num_frames_from_header, - num_frames_from_content=10, - begin_stream_seconds_from_header=0, - begin_stream_seconds_from_content=None, # None to test calculating duration_seconds - end_stream_seconds_from_content=None, # None to test calculating duration_seconds - codec="whatever", - width=123, - height=321, - pixel_aspect_ratio=Fraction(10, 11), - average_fps_from_header=average_fps_from_header, - stream_index=0, - ) - assert metadata.duration_seconds_from_header is None - assert metadata.duration_seconds == expected_duration_seconds - - def test_repr(): # Test for calls to print(), str(), etc. Useful to make sure we don't forget # to add additional @properties to __repr__ @@ -275,6 +158,8 @@ def test_repr(): bit_rate: 128783.0 codec: h264 stream_index: 3 + duration_seconds: 13.013 + begin_stream_seconds: 0.0 begin_stream_seconds_from_content: 0.0 end_stream_seconds_from_content: 13.013 width: 480 @@ -283,11 +168,9 @@ def test_repr(): num_frames_from_content: 390 average_fps_from_header: 29.97003 pixel_aspect_ratio: 1 - duration_seconds: 13.013 - begin_stream_seconds: 0.0 end_stream_seconds: 13.013 num_frames: 390 - average_fps: 29.97002997002997 + average_fps: 29.97003 """ ) ffmpeg_major_version = get_ffmpeg_major_version() @@ -303,6 +186,8 @@ def test_repr(): bit_rate: 64000.0 codec: mp3 stream_index: 0 + duration_seconds: {expected_duration_seconds_from_header} + begin_stream_seconds: 0.0 sample_rate: 8000 num_channels: 2 sample_format: fltp diff --git a/test/test_samplers.py b/test/test_samplers.py index 10c529062..24a57e2c6 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -595,6 +595,7 @@ def restore_metadata(): decoder.metadata.num_frames_from_header = ( None # Set to none to prevent fallback calculation ) + decoder.metadata.end_stream_seconds = None with pytest.raises( ValueError, match="Could not infer stream end from video metadata" ): @@ -603,6 +604,7 @@ def restore_metadata(): with restore_metadata(): decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.average_fps_from_header = None + decoder.metadata.average_fps = None with pytest.raises(ValueError, match="Could not infer average fps"): sampler(decoder)