diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index d06c47922..f3f887ab3 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -785,8 +785,14 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
 
   for (int64_t i = 0; i < timestamps.numel(); ++i) {
     auto frameSeconds = timestampsAccessor[i];
+    // Absolute tolerance (1000 * machine epsilon, about 2.2e-13) applied to
+    // both bounds below so that timestamps which miss a bound only by
+    // floating-point rounding are still accepted.
+    // NOTE(review): being absolute, the tolerance falls below half a ULP for
+    // bounds of roughly 2048 s and above and then has no effect.
+    constexpr double eps = std::numeric_limits<double>::epsilon() * 1000;
     TORCH_CHECK(
-        frameSeconds >= minSeconds,
+        frameSeconds >= (minSeconds - eps),
         "frame pts is " + std::to_string(frameSeconds) +
             "; must be greater than or equal to " +
             std::to_string(minSeconds) + ".");
@@ -795,7 +801,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
     // metadata, then we assume the frame's pts is valid.
     if (maxSeconds.has_value()) {
       TORCH_CHECK(
-          frameSeconds < maxSeconds.value(),
+          frameSeconds < (maxSeconds.value() + eps),
           "frame pts is " + std::to_string(frameSeconds) +
               "; must be less than " + std::to_string(maxSeconds.value()) +
               ".");
diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp
index 1481d3a2a..13fca13e9 100644
--- a/test/VideoDecoderTest.cpp
+++ b/test/VideoDecoderTest.cpp
@@ -441,6 +441,200 @@ TEST_P(SingleStreamDecoderTest, GetAudioMetadata) {
   EXPECT_NEAR(*audioStream.durationSecondsFromHeader, 13.25, 1e-1);
 }
 
+// A timestamp exactly equal to beginStreamPtsSecondsFromContent is accepted.
+TEST_P(SingleStreamDecoderTest, FloatingPointPrecisionExactTimestampsWork) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+
+  std::unique_ptr<SingleStreamDecoder> ourDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          path, SingleStreamDecoder::SeekMode::exact);
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  std::vector<Transform*> transforms;
+  ourDecoder->addVideoStream(-1, transforms);
+
+  ContainerMetadata metadata = ourDecoder->getContainerMetadata();
+  const auto& videoStream = metadata.allStreamMetadata[3];
+
+  ASSERT_TRUE(videoStream.beginStreamPtsSecondsFromContent.has_value());
+  double minSeconds = videoStream.beginStreamPtsSecondsFromContent.value();
+
+  EXPECT_NO_THROW({
+    auto timestamps = torch::tensor({minSeconds}, torch::dtype(torch::kDouble));
+    auto output = ourDecoder->getFramesPlayedAt(timestamps);
+    EXPECT_EQ(output.data.size(0), 1);
+  });
+}
+
+// A timestamp 100 * epsilon below the lower bound is still accepted.
+TEST_P(SingleStreamDecoderTest, FloatingPointPrecisionSmallEpsilonErrorsWork) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+
+  std::unique_ptr<SingleStreamDecoder> ourDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          path, SingleStreamDecoder::SeekMode::exact);
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  std::vector<Transform*> transforms;
+  ourDecoder->addVideoStream(-1, transforms);
+
+  ContainerMetadata metadata = ourDecoder->getContainerMetadata();
+  const auto& videoStream = metadata.allStreamMetadata[3];
+
+  ASSERT_TRUE(videoStream.beginStreamPtsSecondsFromContent.has_value());
+  double minSeconds = videoStream.beginStreamPtsSecondsFromContent.value();
+
+  // Simulate small floating-point precision errors from video processing
+  constexpr double eps = std::numeric_limits<double>::epsilon();
+  double almostMinSeconds = minSeconds - eps * 100;
+
+  EXPECT_NO_THROW({
+    auto timestamps =
+        torch::tensor({almostMinSeconds}, torch::dtype(torch::kDouble));
+    auto output = ourDecoder->getFramesPlayedAt(timestamps);
+    EXPECT_EQ(output.data.size(0), 1);
+  });
+}
+
+// An offset of 500 * epsilon is still inside the 1000 * epsilon tolerance.
+TEST_P(SingleStreamDecoderTest, FloatingPointPrecisionLargerEpsilonErrorsWork) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+
+  std::unique_ptr<SingleStreamDecoder> ourDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          path, SingleStreamDecoder::SeekMode::exact);
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  std::vector<Transform*> transforms;
+  ourDecoder->addVideoStream(-1, transforms);
+
+  ContainerMetadata metadata = ourDecoder->getContainerMetadata();
+  const auto& videoStream = metadata.allStreamMetadata[3];
+
+  ASSERT_TRUE(videoStream.beginStreamPtsSecondsFromContent.has_value());
+  double minSeconds = videoStream.beginStreamPtsSecondsFromContent.value();
+
+  constexpr double eps = std::numeric_limits<double>::epsilon();
+  double precisionErrorSeconds = minSeconds - eps * 500;
+
+  EXPECT_NO_THROW({
+    auto timestamps =
+        torch::tensor({precisionErrorSeconds}, torch::dtype(torch::kDouble));
+    auto output = ourDecoder->getFramesPlayedAt(timestamps);
+    EXPECT_EQ(output.data.size(0), 1);
+  });
+}
+
+// An offset far larger than the tolerance (0.1 s) must still be rejected.
+TEST_P(
+    SingleStreamDecoderTest,
+    FloatingPointPrecisionInvalidTimestampsStillFail) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+
+  std::unique_ptr<SingleStreamDecoder> ourDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          path, SingleStreamDecoder::SeekMode::exact);
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  std::vector<Transform*> transforms;
+  ourDecoder->addVideoStream(-1, transforms);
+
+  ContainerMetadata metadata = ourDecoder->getContainerMetadata();
+  const auto& videoStream = metadata.allStreamMetadata[3];
+
+  ASSERT_TRUE(videoStream.beginStreamPtsSecondsFromContent.has_value());
+  double minSeconds = videoStream.beginStreamPtsSecondsFromContent.value();
+
+  // Ensure genuinely invalid timestamps still fail appropriately
+  double definitelyInvalidTimestamp = minSeconds - 0.1;
+
+  EXPECT_THROW(
+      {
+        auto timestamps = torch::tensor(
+            {definitelyInvalidTimestamp}, torch::dtype(torch::kDouble));
+        ourDecoder->getFramesPlayedAt(timestamps);
+      },
+      c10::Error);
+}
+
+// A batch mixing exact, slightly-below and slightly-above timestamps is
+// accepted as a whole.
+TEST_P(
+    SingleStreamDecoderTest,
+    FloatingPointPrecisionBatchTimestampsWithEpsilonErrorsWork) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+
+  std::unique_ptr<SingleStreamDecoder> ourDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          path, SingleStreamDecoder::SeekMode::exact);
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  std::vector<Transform*> transforms;
+  ourDecoder->addVideoStream(-1, transforms);
+
+  ContainerMetadata metadata = ourDecoder->getContainerMetadata();
+  const auto& videoStream = metadata.allStreamMetadata[3];
+
+  ASSERT_TRUE(videoStream.beginStreamPtsSecondsFromContent.has_value());
+  double minSeconds = videoStream.beginStreamPtsSecondsFromContent.value();
+
+  constexpr double eps = std::numeric_limits<double>::epsilon();
+  auto mixedTimestamps = torch::tensor(
+      {minSeconds, minSeconds - eps * 10, minSeconds + eps * 50},
+      torch::dtype(torch::kDouble));
+
+  EXPECT_NO_THROW({
+    auto output = ourDecoder->getFramesPlayedAt(mixedTimestamps);
+    EXPECT_EQ(output.data.size(0), 3);
+  });
+}
+
+// Range-based counterpart of the getFramesPlayedAt tolerance tests.
+TEST_P(SingleStreamDecoderTest, HandleFloatingPointPrecisionInRangeValidation) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+
+  // Set exact seek mode to enable precise timestamp validation
+  std::unique_ptr<SingleStreamDecoder> ourDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          path, SingleStreamDecoder::SeekMode::exact);
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  std::vector<Transform*> transforms;
+  ourDecoder->addVideoStream(-1, transforms);
+
+  // Get the metadata to understand the valid timestamp range
+  ContainerMetadata metadata = ourDecoder->getContainerMetadata();
+  const auto& videoStream = metadata.allStreamMetadata[3]; // Video stream index
+
+  ASSERT_TRUE(videoStream.beginStreamPtsSecondsFromContent.has_value());
+  double minSeconds = videoStream.beginStreamPtsSecondsFromContent.value();
+
+  // Test case 1: Range starting exactly at minSeconds - should work
+  EXPECT_NO_THROW({
+    auto output =
+        ourDecoder->getFramesPlayedInRange(minSeconds, minSeconds + 1.0);
+    EXPECT_GT(output.data.size(0), 0);
+  });
+
+  // Test case 2: Range with floating-point precision error at start
+  constexpr double eps = std::numeric_limits<double>::epsilon();
+  double startWithPrecisionError = minSeconds - eps * 100;
+
+  // NOTE(review): the production change in this patch only touches
+  // getFramesPlayedAt; confirm getFramesPlayedInRange already tolerates this
+  // offset, otherwise this expectation fails.
+  EXPECT_NO_THROW({
+    auto output = ourDecoder->getFramesPlayedInRange(
+        startWithPrecisionError, minSeconds + 1.0);
+    EXPECT_GT(output.data.size(0), 0);
+  });
+
+  // Test case 3: Test that genuinely invalid range still fails appropriately
+  double definitelyInvalidStart = minSeconds - 0.1;
+
+  // This should still throw an error (not a precision issue)
+  EXPECT_THROW(
+      {
+        ourDecoder->getFramesPlayedInRange(
+            definitelyInvalidStart, minSeconds + 1.0);
+      },
+      c10::Error);
+}
+
 INSTANTIATE_TEST_SUITE_P(
     FromFileAndMemory,
     SingleStreamDecoderTest,