diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index bb0988a13..2df6db1bc 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -46,8 +46,7 @@ void CpuDeviceInterface::initializeVideo(
   // We calculate this value during initilization but we don't refer to it until
   // getColorConversionLibrary() is called. Calculating this value during
   // initialization saves us from having to save all of the transforms.
-  areTransformsSwScaleCompatible_ = transforms.empty() ||
-      (transforms.size() == 1 && transforms[0]->isResize());
+  areTransformsSwScaleCompatible_ = transforms.empty();
 
   // Note that we do not expose this capability in the public API, only through
   // the core API.
@@ -57,16 +56,6 @@ void CpuDeviceInterface::initializeVideo(
   userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
       ColorConversionLibrary::SWSCALE;
 
-  // We can only use swscale when we have a single resize transform. Note that
-  // we actually decide on whether or not to actually use swscale at the last
-  // possible moment, when we actually convert the frame. This is because we
-  // need to know the actual frame dimensions.
-  if (transforms.size() == 1 && transforms[0]->isResize()) {
-    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
-    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
-    swsFlags_ = resize->getSwsFlags();
-  }
-
   // If we have any transforms, replace filters_ with the filter strings from
   // the transforms. As noted above, we decide between swscale and filtergraph
   // when we actually decode a frame.
@@ -80,7 +69,18 @@ void CpuDeviceInterface::initializeVideo(
     first = false;
   }
   if (!transforms.empty()) {
-    filters_ = filters.str();
+    // Note [Transform and Format Conversion Order]
+    // We have to ensure that all user filters happen AFTER the explicit format
+    // conversion. That is, we want the filters to be applied in RGB24, not the
+    // pixel format of the input frame.
+    //
+    // The output frame will always be in RGB24, as we specify the sink node
+    // with AV_PIX_FMT_RGB24. Filtergraph will automatically insert a format
+    // conversion to ensure the output frame matches the pixel format
+    // specified in the sink. But by default, it will insert it after the user
+    // filters. We need an explicit format conversion to get the behavior we
+    // want.
+    filters_ = "format=rgb24," + filters.str();
   }
 
   initialized_ = true;
@@ -218,6 +218,11 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
+  TORCH_CHECK(
+      avFrame->height == outputDims.height &&
+          avFrame->width == outputDims.width,
+      "Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
+
   // We need to compare the current frame context with our previous frame
   // context. If they are different, then we need to re-create our colorspace
   // conversion objects. We create our colorspace conversion objects late so
@@ -234,7 +239,16 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
 
   if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
     swsContext_ = createSwsContext(
-        swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
+        swsFrameContext,
+        avFrame->colorspace,
+
+        // See [Transform and Format Conversion Order] for more on the output
+        // pixel format.
+        /*outputFormat=*/AV_PIX_FMT_RGB24,
+
+        // We don't set any flags because we don't yet use sws_scale() for
+        // resizing.
+ /*swsFlags=*/0); prevSwsFrameContext_ = swsFrameContext; } @@ -256,17 +270,17 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale( torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame, const FrameDims& outputDims) { - enum AVPixelFormat frameFormat = + enum AVPixelFormat avFrameFormat = static_cast(avFrame->format); FiltersContext filtersContext( avFrame->width, avFrame->height, - frameFormat, + avFrameFormat, avFrame->sample_aspect_ratio, outputDims.width, outputDims.height, - AV_PIX_FMT_RGB24, + /*outputFormat=*/AV_PIX_FMT_RGB24, filters_, timeBase_); diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index f7c57045a..da6a179b8 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -87,21 +87,20 @@ class CpuDeviceInterface : public DeviceInterface { UniqueSwsContext swsContext_; SwsFrameContext prevSwsFrameContext_; - // The filter we supply to filterGraph_, if it is used. The default is the - // copy filter, which just copies the input to the output. Computationally, it - // should be a no-op. If we get no user-provided transforms, we will use the - // copy filter. Otherwise, we will construct the string from the transforms. + // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline + // of what FFmpeg calls "filters" to apply to decoded frames before returning + // them. In the PyTorch ecosystem, we call these "transforms". During + // initialization, we convert the user-supplied transforms into this string of + // filters. // - // Note that even if we only use the copy filter, we still get the desired - // colorspace conversion. We construct the filtergraph with its output sink - // set to RGB24. + // Note that if there are no user-supplied transforms, then the default filter + // we use is the copy filter, which is just an identity: it emits the output + // frame unchanged. We supply such a filter because we can't supply just the + // empty-string; we must supply SOME filter. + // + // See also [Tranform and Format Conversion Order] for more on filters. std::string filters_ = "copy"; - // The flags we supply to swsContext_, if it used. The flags control the - // resizing algorithm. We default to bilinear. Users can override this with a - // ResizeTransform. - int swsFlags_ = SWS_BILINEAR; - // Values set during initialization and referred to in // getColorConversionLibrary(). bool areTransformsSwScaleCompatible_; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index b9663d8d2..ffb813f4c 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -399,68 +399,65 @@ SwrContext* createSwrContext( return swrContext; } -AVFilterContext* createBuffersinkFilter( +AVFilterContext* createAVFilterContextWithOptions( AVFilterGraph* filterGraph, - enum AVPixelFormat outputFormat) { - const AVFilter* buffersink = avfilter_get_by_name("buffersink"); - TORCH_CHECK(buffersink != nullptr, "Failed to get buffersink filter."); - - AVFilterContext* sinkContext = nullptr; - int status; + const AVFilter* buffer, + const enum AVPixelFormat outputFormat) { + AVFilterContext* avFilterContext = nullptr; const char* filterName = "out"; - enum AVPixelFormat pix_fmts[] = {outputFormat, AV_PIX_FMT_NONE}; + enum AVPixelFormat pixFmts[] = {outputFormat, AV_PIX_FMT_NONE}; // av_opt_set_int_list was replaced by av_opt_set_array() in FFmpeg 8. 
#if LIBAVUTIL_VERSION_MAJOR >= 60 // FFmpeg >= 8 // Output options like pixel_formats must be set before filter init - sinkContext = - avfilter_graph_alloc_filter(filterGraph, buffersink, filterName); + avFilterContext = + avfilter_graph_alloc_filter(filterGraph, buffer, filterName); TORCH_CHECK( - sinkContext != nullptr, "Failed to allocate buffersink filter context."); + avFilterContext != nullptr, "Failed to allocate buffer filter context."); // When setting pix_fmts, only the first element is used, so nb_elems = 1 // AV_PIX_FMT_NONE acts as a terminator for the array in av_opt_set_int_list - status = av_opt_set_array( - sinkContext, + int status = av_opt_set_array( + avFilterContext, "pixel_formats", AV_OPT_SEARCH_CHILDREN, 0, // start_elem 1, // nb_elems AV_OPT_TYPE_PIXEL_FMT, - pix_fmts); + pixFmts); TORCH_CHECK( status >= 0, - "Failed to set pixel format for buffersink filter: ", + "Failed to set pixel format for buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); - status = avfilter_init_str(sinkContext, nullptr); + status = avfilter_init_str(avFilterContext, nullptr); TORCH_CHECK( status >= 0, - "Failed to initialize buffersink filter: ", + "Failed to initialize buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); #else // FFmpeg <= 7 // For older FFmpeg versions, create filter and then set options - status = avfilter_graph_create_filter( - &sinkContext, buffersink, filterName, nullptr, nullptr, filterGraph); + int status = avfilter_graph_create_filter( + &avFilterContext, buffer, filterName, nullptr, nullptr, filterGraph); TORCH_CHECK( status >= 0, - "Failed to create buffersink filter: ", + "Failed to create buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); status = av_opt_set_int_list( - sinkContext, + avFilterContext, "pix_fmts", - pix_fmts, + pixFmts, AV_PIX_FMT_NONE, AV_OPT_SEARCH_CHILDREN); TORCH_CHECK( status >= 0, - "Failed to set pixel formats for buffersink filter: ", + "Failed to set pixel formats for buffer filter: ", getFFMPEGErrorStringFromErrorCode(status)); #endif - return sinkContext; + return avFilterContext; } UniqueAVFrame convertAudioAVFrameSamples( diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 2d58abfb2..3f5673454 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -247,9 +247,10 @@ int64_t computeSafeDuration( const AVRational& frameRate, const AVRational& timeBase); -AVFilterContext* createBuffersinkFilter( +AVFilterContext* createAVFilterContextWithOptions( AVFilterGraph* filterGraph, - enum AVPixelFormat outputFormat); + const AVFilter* buffer, + const enum AVPixelFormat outputFormat); struct SwsFrameContext { int inputWidth = 0; @@ -274,7 +275,7 @@ struct SwsFrameContext { UniqueSwsContext createSwsContext( const SwsFrameContext& swsFrameContext, AVColorSpace colorspace, - AVPixelFormat outputFormat = AV_PIX_FMT_RGB24, - int swsFlags = SWS_BILINEAR); + AVPixelFormat outputFormat, + int swsFlags); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index 605b814a8..564939e85 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -63,8 +63,8 @@ FilterGraph::FilterGraph( filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value(); } - const AVFilter* buffersrc = avfilter_get_by_name("buffer"); - + // Configure the source context. 
+ const AVFilter* bufferSrc = avfilter_get_by_name("buffer"); UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc()); TORCH_CHECK(srcParams, "Failed to allocate buffersrc params"); @@ -78,7 +78,7 @@ FilterGraph::FilterGraph( } sourceContext_ = - avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in"); + avfilter_graph_alloc_filter(filterGraph_.get(), bufferSrc, "in"); TORCH_CHECK(sourceContext_, "Failed to allocate filter graph"); int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get()); @@ -93,23 +93,31 @@ FilterGraph::FilterGraph( "Failed to create filter graph : ", getFFMPEGErrorStringFromErrorCode(status)); - sinkContext_ = - createBuffersinkFilter(filterGraph_.get(), filtersContext.outputFormat); + // Configure the sink context. + const AVFilter* bufferSink = avfilter_get_by_name("buffersink"); + TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter."); + + sinkContext_ = createAVFilterContextWithOptions( + filterGraph_.get(), bufferSink, filtersContext.outputFormat); TORCH_CHECK( sinkContext_ != nullptr, "Failed to create and configure buffersink"); + // Create the filtergraph nodes based on the source and sink contexts. UniqueAVFilterInOut outputs(avfilter_inout_alloc()); - UniqueAVFilterInOut inputs(avfilter_inout_alloc()); - outputs->name = av_strdup("in"); outputs->filter_ctx = sourceContext_; outputs->pad_idx = 0; outputs->next = nullptr; + + UniqueAVFilterInOut inputs(avfilter_inout_alloc()); inputs->name = av_strdup("out"); inputs->filter_ctx = sinkContext_; inputs->pad_idx = 0; inputs->next = nullptr; + // Create the filtergraph specified by the filtergraph string in the context + // of the inputs and outputs. Note the dance we have to do with release and + // resetting the output and input nodes because FFmpeg modifies them in place. AVFilterInOut* outputsTmp = outputs.release(); AVFilterInOut* inputsTmp = inputs.release(); status = avfilter_graph_parse_ptr( @@ -126,6 +134,7 @@ FilterGraph::FilterGraph( getFFMPEGErrorStringFromErrorCode(status), ", provided filters: " + filtersContext.filtergraphStr); + // Check filtergraph validity and configure links and formats. 
status = avfilter_graph_config(filterGraph_.get(), nullptr); TORCH_CHECK( status >= 0, diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index 6083986e1..77cffe636 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -25,38 +25,18 @@ std::string toFilterGraphInterpolation( } } -int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { - switch (mode) { - case ResizeTransform::InterpolationMode::BILINEAR: - return SWS_BILINEAR; - default: - TORCH_CHECK( - false, - "Unknown interpolation mode: " + - std::to_string(static_cast(mode))); - } -} - } // namespace std::string ResizeTransform::getFilterGraphCpu() const { return "scale=" + std::to_string(outputDims_.width) + ":" + std::to_string(outputDims_.height) + - ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_); + ":flags=" + toFilterGraphInterpolation(interpolationMode_); } std::optional ResizeTransform::getOutputFrameDims() const { return outputDims_; } -bool ResizeTransform::isResize() const { - return true; -} - -int ResizeTransform::getSwsFlags() const { - return toSwsInterpolation(interpolationMode_); -} - CropTransform::CropTransform(const FrameDims& dims, int x, int y) : outputDims_(dims), x_(x), y_(y) { TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_); diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 28d8c28a2..7c7a3e2a5 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -29,12 +29,6 @@ class Transform { return std::nullopt; } - // The ResizeTransform is special, because it is the only transform that - // swscale can handle. - virtual bool isResize() const { - return false; - } - // The validity of some transforms depends on the characteristics of the // AVStream they're being applied to. For example, some transforms will // specify coordinates inside a frame, we need to validate that those are @@ -58,9 +52,6 @@ class ResizeTransform : public Transform { std::string getFilterGraphCpu() const override; std::optional getOutputFrameDims() const override; - bool isResize() const override; - - int getSwsFlags() const; private: FrameDims outputDims_; diff --git a/test/generate_reference_resources.py b/test/generate_reference_resources.py index 953fb996e..3821c9299 100644 --- a/test/generate_reference_resources.py +++ b/test/generate_reference_resources.py @@ -51,16 +51,16 @@ def generate_frame_by_index( ) output_bmp = f"{base_path}.bmp" - # Note that we have an exlicit format conversion to rgb24 in our filtergraph specification, - # which always happens BEFORE any of the filters that we receive as input. We do this to - # ensure that the color conversion happens BEFORE the filters, matching the behavior of the - # torchcodec filtergraph implementation. - # - # Not doing this would result in the color conversion happening AFTER the filters, which - # would result in different color values for the same frame. - filtergraph = f"select='eq(n\\,{frame_index})',format=rgb24" + # Note that we have an exlicit format conversion to rgb24 in our filtergraph + # specification, and we always place the user-supplied filters AFTER the + # format conversion. We do this to ensure that the filters are applied in + # RGB24 colorspace, which matches TorchCodec's behavior. 
+ select = f"select='eq(n\\,{frame_index})'" + format = "format=rgb24" if filters is not None: - filtergraph = filtergraph + f",{filters}" + filtergraph = ",".join([select, format, filters]) + else: + filtergraph = ",".join([select, format]) cmd = [ "ffmpeg", @@ -99,7 +99,7 @@ def generate_frame_by_timestamp( convert_image_to_tensor(output_path) -def generate_nasa_13013_references(): +def generate_nasa_13013_references_by_index(): # Note: The naming scheme used here must match the naming scheme used to load # tensors in ./utils.py. streams = [0, 3] @@ -108,6 +108,8 @@ def generate_nasa_13013_references(): for frame in frames: generate_frame_by_index(NASA_VIDEO, frame_index=frame, stream_index=stream) + +def generate_nasa_13013_references_by_timestamp(): # Extract individual frames at specific timestamps, including the last frame of the video. seek_timestamp = [6.0, 6.1, 10.0, 12.979633] timestamp_name = [f"{seek_timestamp:06f}" for seek_timestamp in seek_timestamp] @@ -115,6 +117,8 @@ def generate_nasa_13013_references(): output_bmp = f"{NASA_VIDEO.path}.time{name}.bmp" generate_frame_by_timestamp(NASA_VIDEO.path, timestamp, output_bmp) + +def generate_nasa_13013_references_crop(): # Extract frames with specific filters. We have tests that assume these exact filters. frames = [0, 15, 200, 389] crop_filter = "crop=300:200:50:35:exact=1" @@ -124,6 +128,24 @@ def generate_nasa_13013_references(): ) +def generate_nasa_13013_references_resize(): + frames = [17, 230, 389] + # Note that the resize algorithm passed to flags is exposed to users, + # but bilinear is the default we use. + resize_filter = "scale=240:135:flags=bilinear" + for frame in frames: + generate_frame_by_index( + NASA_VIDEO, frame_index=frame, stream_index=3, filters=resize_filter + ) + + +def generate_nasa_13013_references(): + generate_nasa_13013_references_by_index() + generate_nasa_13013_references_by_timestamp() + generate_nasa_13013_references_crop() + generate_nasa_13013_references_resize() + + def generate_h265_video_references(): # This video was generated by running the following: # conda install -c conda-forge x265 diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt new file mode 100644 index 000000000..5da3e81fe Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt new file mode 100644 index 000000000..5094e44da Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt new file mode 100644 index 000000000..a15622389 Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt differ diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 8d1ba5e53..370849726 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -20,67 +20,69 @@ create_from_file, get_frame_at_index, get_json_metadata, - get_next_frame, ) from torchvision.transforms import v2 -from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda +from .utils import ( + 
assert_frames_equal, + assert_tensor_close_on_at_least, + AV1_VIDEO, + H265_VIDEO, + NASA_VIDEO, + needs_cuda, + TEST_SRC_2_720P, +) torch._dynamo.config.capture_dynamic_output_shape_ops = True -class TestVideoDecoderTransformOps: - # We choose arbitrary values for width and height scaling to get better - # test coverage. Some pairs upscale the image while others downscale it. - @pytest.mark.parametrize( - "width_scaling_factor,height_scaling_factor", - ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), - ) - @pytest.mark.parametrize("input_video", [NASA_VIDEO]) - def test_color_conversion_library_with_scaling( - self, input_video, width_scaling_factor, height_scaling_factor - ): - decoder = create_from_file(str(input_video.path)) +class TestCoreVideoDecoderTransformOps: + def get_num_frames_core_ops(self, video): + decoder = create_from_file(str(video.path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) - assert metadata_dict["width"] == input_video.width - assert metadata_dict["height"] == input_video.height + num_frames = metadata_dict["numFramesFromHeader"] + assert num_frames is not None + return num_frames - target_height = int(input_video.height * height_scaling_factor) - target_width = int(input_video.width * width_scaling_factor) - if width_scaling_factor != 1.0: - assert target_width != input_video.width - if height_scaling_factor != 1.0: - assert target_height != input_video.height + @pytest.mark.parametrize("video", [NASA_VIDEO, H265_VIDEO, AV1_VIDEO]) + def test_color_conversion_library(self, video): + num_frames = self.get_num_frames_core_ops(video) - filtergraph_decoder = create_from_file(str(input_video.path)) + filtergraph_decoder = create_from_file(str(video.path)) _add_video_stream( filtergraph_decoder, - transform_specs=f"resize, {target_height}, {target_width}", color_conversion_library="filtergraph", ) - filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - swscale_decoder = create_from_file(str(input_video.path)) + swscale_decoder = create_from_file(str(video.path)) _add_video_stream( swscale_decoder, - transform_specs=f"resize, {target_height}, {target_width}", color_conversion_library="swscale", ) - swscale_frame0, _, _ = get_next_frame(swscale_decoder) - assert_frames_equal(filtergraph_frame0, swscale_frame0) - assert filtergraph_frame0.shape == (3, target_height, target_width) - @pytest.mark.parametrize( - "width_scaling_factor,height_scaling_factor", - ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), - ) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + filtergraph_frame, *_ = get_frame_at_index( + filtergraph_decoder, frame_index=frame_index + ) + swscale_frame, *_ = get_frame_at_index( + swscale_decoder, frame_index=frame_index + ) + + assert_frames_equal(filtergraph_frame, swscale_frame) + @pytest.mark.parametrize("width", [30, 32, 300]) @pytest.mark.parametrize("height", [128]) def test_color_conversion_library_with_generated_videos( - self, tmp_path, width, height, width_scaling_factor, height_scaling_factor + self, tmp_path, width, height ): # We consider filtergraph to be the reference color conversion library. # However the video decoder sometimes uses swscale as that is faster. 
@@ -129,27 +131,22 @@ def test_color_conversion_library_with_generated_videos( assert metadata_dict["width"] == width assert metadata_dict["height"] == height - target_height = int(height * height_scaling_factor) - target_width = int(width * width_scaling_factor) - if width_scaling_factor != 1.0: - assert target_width != width - if height_scaling_factor != 1.0: - assert target_height != height + num_frames = metadata_dict["numFramesFromHeader"] + assert num_frames is not None and num_frames == 1 filtergraph_decoder = create_from_file(str(video_path)) _add_video_stream( filtergraph_decoder, - transform_specs=f"resize, {target_height}, {target_width}", color_conversion_library="filtergraph", ) - filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) auto_decoder = create_from_file(str(video_path)) add_video_stream( auto_decoder, - transform_specs=f"resize, {target_height}, {target_width}", ) - auto_frame0, _, _ = get_next_frame(auto_decoder) + + filtergraph_frame0, *_ = get_frame_at_index(filtergraph_decoder, frame_index=0) + auto_frame0, *_ = get_frame_at_index(auto_decoder, frame_index=0) assert_frames_equal(filtergraph_frame0, auto_frame0) @needs_cuda @@ -175,6 +172,90 @@ def test_transform_fails(self): ): add_video_stream(decoder, transform_specs="invalid, 1, 2") + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)), + ) + @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + def test_resize_torchvision( + self, video, height_scaling_factor, width_scaling_factor + ): + num_frames = self.get_num_frames_core_ops(video) + + height = int(video.get_height() * height_scaling_factor) + width = int(video.get_width() * width_scaling_factor) + resize_spec = f"resize, {height}, {width}" + + decoder_resize = create_from_file(str(video.path)) + add_video_stream(decoder_resize, transform_specs=resize_spec) + + decoder_full = create_from_file(str(video.path)) + add_video_stream(decoder_full) + + for frame_index in [ + 0, + int(num_frames * 0.1), + int(num_frames * 0.2), + int(num_frames * 0.3), + int(num_frames * 0.4), + int(num_frames * 0.5), + int(num_frames * 0.75), + int(num_frames * 0.90), + num_frames - 1, + ]: + expected_shape = (video.get_num_color_channels(), height, width) + frame_resize, *_ = get_frame_at_index( + decoder_resize, frame_index=frame_index + ) + + frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index) + frame_tv = v2.functional.resize(frame_full, size=(height, width)) + frame_tv_no_antialias = v2.functional.resize( + frame_full, size=(height, width), antialias=False + ) + + assert frame_resize.shape == expected_shape + assert frame_tv.shape == expected_shape + assert frame_tv_no_antialias.shape == expected_shape + + assert_tensor_close_on_at_least( + frame_resize, frame_tv, percentage=99.8, atol=1 + ) + torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6) + + if height_scaling_factor < 1 or width_scaling_factor < 1: + # Antialias only relevant when down-scaling! 
+ with pytest.raises(AssertionError, match="Expected at least"): + assert_tensor_close_on_at_least( + frame_resize, frame_tv_no_antialias, percentage=99, atol=1 + ) + with pytest.raises(AssertionError, match="Tensor-likes are not close"): + torch.testing.assert_close( + frame_resize, frame_tv_no_antialias, rtol=0, atol=6 + ) + + def test_resize_ffmpeg(self): + height = 135 + width = 240 + expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width) + resize_spec = f"resize, {height}, {width}" + resize_filtergraph = f"scale={width}:{height}:flags=bilinear" + + decoder_resize = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder_resize, transform_specs=resize_spec) + + for frame_index in [17, 230, 389]: + frame_resize, *_ = get_frame_at_index( + decoder_resize, frame_index=frame_index + ) + frame_ref = NASA_VIDEO.get_frame_data_by_index( + frame_index, filters=resize_filtergraph + ) + + assert frame_resize.shape == expected_shape + assert frame_ref.shape == expected_shape + assert_frames_equal(frame_resize, frame_ref) + def test_resize_transform_fails(self): decoder = create_from_file(str(NASA_VIDEO.path)) with pytest.raises( @@ -224,7 +305,7 @@ def test_crop_transform(self): add_video_stream(decoder_full) for frame_index in [0, 15, 200, 389]: - frame, *_ = get_frame_at_index(decoder_crop, frame_index=frame_index) + frame_crop, *_ = get_frame_at_index(decoder_crop, frame_index=frame_index) frame_ref = NASA_VIDEO.get_frame_data_by_index( frame_index, filters=crop_filtergraph ) @@ -234,12 +315,12 @@ def test_crop_transform(self): frame_full, top=y, left=x, height=height, width=width ) - assert frame.shape == expected_shape + assert frame_crop.shape == expected_shape assert frame_ref.shape == expected_shape assert frame_tv.shape == expected_shape - assert_frames_equal(frame, frame_tv) - assert_frames_equal(frame, frame_ref) + assert_frames_equal(frame_crop, frame_ref) + assert_frames_equal(frame_crop, frame_tv) def test_crop_transform_fails(self): diff --git a/test/utils.py b/test/utils.py index cbd6a5bf4..43f29cf5a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -430,7 +430,7 @@ def empty_chw_tensor(self) -> torch.Tensor: [0, self.num_color_channels, self.height, self.width], dtype=torch.uint8 ) - def get_width(self, *, stream_index: Optional[int]) -> int: + def get_width(self, *, stream_index: Optional[int] = None) -> int: if stream_index is None: stream_index = self.default_stream_index
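The ordering guarantee described in Note [Transform and Format Conversion Order] can also be seen outside the decoder. The sketch below is illustrative only and not part of the patch: it assembles a filtergraph string the same way CpuDeviceInterface::initializeVideo now does, with the explicit rgb24 conversion placed before the user-supplied filters (the filter strings are the ones used by the tests).

# Illustrative sketch, not part of the patch: mirrors how the CPU interface
# now builds its filtergraph string.
user_filters = ["crop=300:200:50:35:exact=1", "scale=240:135:flags=bilinear"]

# Placing format=rgb24 first makes every user filter operate on RGB24 data.
# Relying on the RGB24 sink alone would let filtergraph insert the conversion
# AFTER the user filters, i.e. the filters would run in the source pixel
# format (e.g. yuv420p) and produce slightly different pixel values.
filtergraph = ",".join(["format=rgb24", *user_filters])
print(filtergraph)
# format=rgb24,crop=300:200:50:35:exact=1,scale=240:135:flags=bilinear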
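A minimal usage sketch of the new resize transform through the core ops. It assumes the entry points exercised by the tests (create_from_file, add_video_stream with transform_specs, get_frame_at_index) are importable from torchcodec._core and that a local copy of the NASA test video is available; the dimensions, frame index, and tolerance mirror test_resize_ffmpeg and test_resize_torchvision.

import torch
from torchvision.transforms import v2
# Assumed import path for the core ops, matching the test module's usage.
from torchcodec._core import add_video_stream, create_from_file, get_frame_at_index

height, width = 135, 240
path = "test/resources/nasa_13013.mp4"  # assumed local path

# Decode with the resize applied inside the decoder (filtergraph scale filter,
# bilinear by default).
decoder_resize = create_from_file(path)
add_video_stream(decoder_resize, transform_specs=f"resize, {height}, {width}")
frame_resize, *_ = get_frame_at_index(decoder_resize, frame_index=17)
assert frame_resize.shape[-2:] == (height, width)

# Decode at full resolution and resize with torchvision for comparison; the
# tests treat antialiased bilinear resize as a close, but not bit-exact,
# reference.
decoder_full = create_from_file(path)
add_video_stream(decoder_full)
frame_full, *_ = get_frame_at_index(decoder_full, frame_index=17)
frame_tv = v2.functional.resize(frame_full, size=(height, width))
torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)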
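The new tests also rely on assert_tensor_close_on_at_least from test/utils.py, which is not shown in this diff. Based only on its call sites and on the "Expected at least" message the tests match against, a plausible reading of what it checks is sketched below; the real helper may be implemented differently.

import torch

def assert_tensor_close_on_at_least_sketch(actual, expected, *, percentage, atol):
    # Assumed semantics: at least `percentage` percent of elements must be
    # within `atol` of the reference, elementwise.
    close = (actual.float() - expected.float()).abs() <= atol
    pct = close.float().mean().item() * 100
    assert pct >= percentage, (
        f"Expected at least {percentage}% of values to be within atol={atol}, "
        f"got {pct:.2f}%"
    )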