diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
index 587456f34..c1cbd1afc 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
@@ -674,149 +674,13 @@ void BetaCudaDeviceInterface::flush() {
   std::swap(readyFrames_, emptyQueue);
 }
 
-UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
-    UniqueAVFrame& cpuFrame) {
-  // This is called in the context of the CPU fallback: the frame was decoded on
-  // the CPU, and in this function we convert that frame into NV12 format and
-  // send it to the GPU.
-  // We do that in 2 steps:
-  // - First we convert the input CPU frame into an intermediate NV12 CPU frame
-  //   using sws_scale.
-  // - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU. This
-  //   is what we return.
-
-  TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null");
-
-  int width = cpuFrame->width;
-  int height = cpuFrame->height;
-
-  // Intermediate NV12 CPU frame. It's not on the GPU yet.
-  UniqueAVFrame nv12CpuFrame(av_frame_alloc());
-  TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");
-
-  nv12CpuFrame->format = AV_PIX_FMT_NV12;
-  nv12CpuFrame->width = width;
-  nv12CpuFrame->height = height;
-
-  int ret = av_frame_get_buffer(nv12CpuFrame.get(), 0);
-  TORCH_CHECK(
-      ret >= 0,
-      "Failed to allocate NV12 CPU frame buffer: ",
-      getFFMPEGErrorStringFromErrorCode(ret));
-
-  SwsFrameContext swsFrameContext(
-      width,
-      height,
-      static_cast<AVPixelFormat>(cpuFrame->format),
-      width,
-      height);
-
-  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    swsContext_ = createSwsContext(
-        swsFrameContext, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
-    prevSwsFrameContext_ = swsFrameContext;
-  }
-
-  int convertedHeight = sws_scale(
-      swsContext_.get(),
-      cpuFrame->data,
-      cpuFrame->linesize,
-      0,
-      height,
-      nv12CpuFrame->data,
-      nv12CpuFrame->linesize);
-  TORCH_CHECK(
-      convertedHeight == height, "sws_scale failed for CPU->NV12 conversion");
-
-  int ySize = width * height;
-  TORCH_CHECK(
-      ySize % 2 == 0,
-      "Y plane size must be even. Please report on TorchCodec repo.");
-  int uvSize = ySize / 2; // NV12: UV plane is half the size of Y plane
-  size_t totalSize = static_cast<size_t>(ySize + uvSize);
-
-  uint8_t* cudaBuffer = nullptr;
-  cudaError_t err =
-      cudaMalloc(reinterpret_cast<void**>(&cudaBuffer), totalSize);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "Failed to allocate CUDA memory: ",
-      cudaGetErrorString(err));
-
-  UniqueAVFrame gpuFrame(av_frame_alloc());
-  TORCH_CHECK(gpuFrame != nullptr, "Failed to allocate GPU AVFrame");
-
-  gpuFrame->format = AV_PIX_FMT_CUDA;
-  gpuFrame->width = width;
-  gpuFrame->height = height;
-  gpuFrame->data[0] = cudaBuffer;
-  gpuFrame->data[1] = cudaBuffer + ySize;
-  gpuFrame->linesize[0] = width;
-  gpuFrame->linesize[1] = width;
-
-  // Note that we use cudaMemcpy2D here instead of cudaMemcpy because the
-  // linesizes (strides) may be different than the widths for the input CPU
-  // frame. That's precisely what cudaMemcpy2D is for.
-  err = cudaMemcpy2D(
-      gpuFrame->data[0],
-      gpuFrame->linesize[0],
-      nv12CpuFrame->data[0],
-      nv12CpuFrame->linesize[0],
-      width,
-      height,
-      cudaMemcpyHostToDevice);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "Failed to copy Y plane to GPU: ",
-      cudaGetErrorString(err));
-
-  TORCH_CHECK(
-      height % 2 == 0,
-      "height must be even. Please report on TorchCodec repo.");
-  err = cudaMemcpy2D(
-      gpuFrame->data[1],
-      gpuFrame->linesize[1],
-      nv12CpuFrame->data[1],
-      nv12CpuFrame->linesize[1],
-      width,
-      height / 2,
-      cudaMemcpyHostToDevice);
-  TORCH_CHECK(
-      err == cudaSuccess,
-      "Failed to copy UV plane to GPU: ",
-      cudaGetErrorString(err));
-
-  ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get());
-  TORCH_CHECK(
-      ret >= 0,
-      "Failed to copy frame properties: ",
-      getFFMPEGErrorStringFromErrorCode(ret));
-
-  // We're almost done, but we need to make sure the CUDA memory is freed
-  // properly. Usually, AVFrame data is freed when av_frame_free() is called
-  // (upon UniqueAVFrame destruction), but since we allocated the CUDA memory
-  // ourselves, FFmpeg doesn't know how to free it. The recommended way to deal
-  // with this is to associate the opaque_ref field of the AVFrame with a `free`
-  // callback that will then be called by av_frame_free().
-  gpuFrame->opaque_ref = av_buffer_create(
-      nullptr, // data - we don't need any
-      0, // data size
-      cudaBufferFreeCallback, // callback triggered by av_frame_free()
-      cudaBuffer, // parameter to callback
-      0); // flags
-  TORCH_CHECK(
-      gpuFrame->opaque_ref != nullptr,
-      "Failed to create GPU memory cleanup reference");
-
-  return gpuFrame;
-}
-
 void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   UniqueAVFrame gpuFrame =
-      cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame);
+      cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame, swsCtx_, device_)
+                   : std::move(avFrame);
 
   // TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA
   // ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24().
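
Note on the opaque_ref cleanup trick removed above (and re-added in CUDACommon.cpp below): av_frame_free() only knows how to release buffers FFmpeg allocated itself, so an externally allocated buffer must be tied to the frame through a custom free callback. Here is a minimal standalone sketch of the same pattern, outside torchcodec; `makeFrameOwningCudaBuffer` and `freeCudaBuffer` are hypothetical names, not part of this PR.

```cpp
extern "C" {
#include <libavutil/buffer.h>
#include <libavutil/frame.h>
}
#include <cuda_runtime.h>

// Called by av_frame_free() when the frame's last reference is dropped.
static void freeCudaBuffer(void* opaque, uint8_t* /*data*/) {
  cudaFree(opaque);
}

AVFrame* makeFrameOwningCudaBuffer(size_t bytes) {
  uint8_t* devPtr = nullptr;
  if (cudaMalloc(reinterpret_cast<void**>(&devPtr), bytes) != cudaSuccess) {
    return nullptr;
  }
  AVFrame* frame = av_frame_alloc();
  if (frame == nullptr) {
    cudaFree(devPtr);
    return nullptr;
  }
  frame->data[0] = devPtr;
  // The AVBufferRef carries no data of its own: it exists purely so that
  // freeCudaBuffer(devPtr) runs exactly once, when the frame is freed.
  frame->opaque_ref = av_buffer_create(nullptr, 0, freeCudaBuffer, devPtr, 0);
  if (frame->opaque_ref == nullptr) {
    cudaFree(devPtr);
    av_frame_free(&frame);
    return nullptr;
  }
  return frame;
}
```
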
Please report on TorchCodec repo."); - err = cudaMemcpy2D( - gpuFrame->data[1], - gpuFrame->linesize[1], - nv12CpuFrame->data[1], - nv12CpuFrame->linesize[1], - width, - height / 2, - cudaMemcpyHostToDevice); - TORCH_CHECK( - err == cudaSuccess, - "Failed to copy UV plane to GPU: ", - cudaGetErrorString(err)); - - ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get()); - TORCH_CHECK( - ret >= 0, - "Failed to copy frame properties: ", - getFFMPEGErrorStringFromErrorCode(ret)); - - // We're almost done, but we need to make sure the CUDA memory is freed - // properly. Usually, AVFrame data is freed when av_frame_free() is called - // (upon UniqueAVFrame destruction), but since we allocated the CUDA memory - // ourselves, FFmpeg doesn't know how to free it. The recommended way to deal - // with this is to associate the opaque_ref field of the AVFrame with a `free` - // callback that will then be called by av_frame_free(). - gpuFrame->opaque_ref = av_buffer_create( - nullptr, // data - we don't need any - 0, // data size - cudaBufferFreeCallback, // callback triggered by av_frame_free() - cudaBuffer, // parameter to callback - 0); // flags - TORCH_CHECK( - gpuFrame->opaque_ref != nullptr, - "Failed to create GPU memory cleanup reference"); - - return gpuFrame; -} - void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { UniqueAVFrame gpuFrame = - cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame); + cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame, swsCtx_, device_) + : std::move(avFrame); // TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA // ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24(). diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index 0b0e7e6c6..906c920be 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -20,6 +20,7 @@ #include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/NVDECCache.h" +#include "src/torchcodec/_core/SwsContext.h" #include #include @@ -81,8 +82,6 @@ class BetaCudaDeviceInterface : public DeviceInterface { unsigned int pitch, const CUVIDPARSERDISPINFO& dispInfo); - UniqueAVFrame transferCpuFrameToGpuNV12(UniqueAVFrame& cpuFrame); - CUvideoparser videoParser_ = nullptr; UniqueCUvideodecoder decoder_; CUVIDEOFORMAT videoFormat_ = {}; @@ -101,8 +100,10 @@ class BetaCudaDeviceInterface : public DeviceInterface { std::unique_ptr cpuFallback_; bool nvcuvidAvailable_ = false; - UniqueSwsContext swsContext_; - SwsFrameContext prevSwsFrameContext_; + + // Swscale context cache for GPU transfer during CPU fallback. + // Used to convert CPU frames to NV12 before transferring to GPU. 
diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp
index 4532e3c76..9fcfe5bd9 100644
--- a/src/torchcodec/_core/CUDACommon.cpp
+++ b/src/torchcodec/_core/CUDACommon.cpp
@@ -327,4 +327,137 @@ int getDeviceIndex(const torch::Device& device) {
   return deviceIndex;
 }
 
+// Callback for freeing CUDA memory associated with AVFrame
+void cudaBufferFreeCallback(void* opaque, [[maybe_unused]] uint8_t* data) {
+  cudaFree(opaque);
+}
+
+UniqueAVFrame transferCpuFrameToGpuNV12(
+    UniqueAVFrame& cpuFrame,
+    SwsScaler& swsCtx,
+    [[maybe_unused]] const torch::Device& device) {
+  // This function converts a CPU frame to NV12 format and transfers it to GPU.
+  // We do that in 2 steps:
+  // - First we convert the input CPU frame into an intermediate NV12 CPU frame
+  //   using sws_scale.
+  // - Then we allocate GPU memory and copy the NV12 CPU frame to the GPU. This
+  //   is what we return.
+
+  TORCH_CHECK(cpuFrame != nullptr, "CPU frame cannot be null");
+
+  int width = cpuFrame->width;
+  int height = cpuFrame->height;
+
+  // Intermediate NV12 CPU frame. It's not on the GPU yet.
+  UniqueAVFrame nv12CpuFrame(av_frame_alloc());
+  TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");
+
+  nv12CpuFrame->format = AV_PIX_FMT_NV12;
+  nv12CpuFrame->width = width;
+  nv12CpuFrame->height = height;
+
+  int ret = av_frame_get_buffer(nv12CpuFrame.get(), 0);
+  TORCH_CHECK(
+      ret >= 0,
+      "Failed to allocate NV12 CPU frame buffer: ",
+      getFFMPEGErrorStringFromErrorCode(ret));
+
+  FrameDims outputDims(height, width);
+  auto swsContext = swsCtx.getOrCreateContext(
+      cpuFrame, outputDims, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
+
+  int convertedHeight = sws_scale(
+      swsContext,
+      cpuFrame->data,
+      cpuFrame->linesize,
+      0,
+      height,
+      nv12CpuFrame->data,
+      nv12CpuFrame->linesize);
+  TORCH_CHECK(
+      convertedHeight == height, "sws_scale failed for CPU->NV12 conversion");
+
+  int ySize = width * height;
+  TORCH_CHECK(
+      ySize % 2 == 0,
+      "Y plane size must be even. Please report on TorchCodec repo.");
+  int uvSize = ySize / 2; // NV12: UV plane is half the size of Y plane
+  size_t totalSize = static_cast<size_t>(ySize + uvSize);
+
+  uint8_t* cudaBuffer = nullptr;
+  cudaError_t err =
+      cudaMalloc(reinterpret_cast<void**>(&cudaBuffer), totalSize);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "Failed to allocate CUDA memory: ",
+      cudaGetErrorString(err));
+
+  UniqueAVFrame gpuFrame(av_frame_alloc());
+  TORCH_CHECK(gpuFrame != nullptr, "Failed to allocate GPU AVFrame");
+
+  gpuFrame->format = AV_PIX_FMT_CUDA;
+  gpuFrame->width = width;
+  gpuFrame->height = height;
+  gpuFrame->data[0] = cudaBuffer;
+  gpuFrame->data[1] = cudaBuffer + ySize;
+  gpuFrame->linesize[0] = width;
+  gpuFrame->linesize[1] = width;
+
+  // Note that we use cudaMemcpy2D here instead of cudaMemcpy because the
+  // linesizes (strides) may be different than the widths for the input CPU
+  // frame. That's precisely what cudaMemcpy2D is for.
+  err = cudaMemcpy2D(
+      gpuFrame->data[0],
+      gpuFrame->linesize[0],
+      nv12CpuFrame->data[0],
+      nv12CpuFrame->linesize[0],
+      width,
+      height,
+      cudaMemcpyHostToDevice);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "Failed to copy Y plane to GPU: ",
+      cudaGetErrorString(err));
+
+  TORCH_CHECK(
+      height % 2 == 0,
+      "height must be even. Please report on TorchCodec repo.");
+  err = cudaMemcpy2D(
+      gpuFrame->data[1],
+      gpuFrame->linesize[1],
+      nv12CpuFrame->data[1],
+      nv12CpuFrame->linesize[1],
+      width,
+      height / 2,
+      cudaMemcpyHostToDevice);
+  TORCH_CHECK(
+      err == cudaSuccess,
+      "Failed to copy UV plane to GPU: ",
+      cudaGetErrorString(err));
+
+  ret = av_frame_copy_props(gpuFrame.get(), cpuFrame.get());
+  TORCH_CHECK(
+      ret >= 0,
+      "Failed to copy frame properties: ",
+      getFFMPEGErrorStringFromErrorCode(ret));
+
+  // We're almost done, but we need to make sure the CUDA memory is freed
+  // properly. Usually, AVFrame data is freed when av_frame_free() is called
+  // (upon UniqueAVFrame destruction), but since we allocated the CUDA memory
+  // ourselves, FFmpeg doesn't know how to free it. The recommended way to deal
+  // with this is to associate the opaque_ref field of the AVFrame with a `free`
+  // callback that will then be called by av_frame_free().
+  gpuFrame->opaque_ref = av_buffer_create(
+      nullptr, // data - we don't need any
+      0, // data size
+      cudaBufferFreeCallback, // callback triggered by av_frame_free()
+      cudaBuffer, // parameter to callback
+      0); // flags
+  TORCH_CHECK(
+      gpuFrame->opaque_ref != nullptr,
+      "Failed to create GPU memory cleanup reference");
+
+  return gpuFrame;
+}
+
 } // namespace facebook::torchcodec
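
Two details of the helper above deserve a note. First, the NV12 math: a full-resolution Y plane of width*height bytes is followed by an interleaved UV plane at half vertical resolution, hence uvSize = ySize / 2 and the UV copy running for height / 2 rows. Second, cudaMemcpy2D copies `width` useful bytes per row while advancing the source and destination pointers by their own pitches, which is exactly what AVFrame's padded linesizes require. A standalone toy illustrating the pitched copy (hypothetical sizes, not torchcodec code):

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Toy illustration of the pitched copy: source rows are padded to srcPitch
// bytes (as AVFrame linesizes often are), but only `width` bytes per row are
// meaningful. cudaMemcpy2D steps the source by srcPitch and the destination
// by dstPitch after each row, copying exactly width * height useful bytes.
int main() {
  const int width = 100; // useful bytes per row
  const int height = 64;
  const size_t srcPitch = 128; // padded host stride, like an AVFrame linesize

  std::vector<uint8_t> host(srcPitch * height, 0xAB);

  uint8_t* dev = nullptr;
  if (cudaMalloc(reinterpret_cast<void**>(&dev), width * height) !=
      cudaSuccess) {
    return 1;
  }
  // The destination is tightly packed, so dstPitch == width.
  cudaError_t err = cudaMemcpy2D(
      dev, width, host.data(), srcPitch, width, height, cudaMemcpyHostToDevice);
  cudaFree(dev);
  return err == cudaSuccess ? 0 : 1;
}
```
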
diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h
index 588f60e49..e898f6cfe 100644
--- a/src/torchcodec/_core/CUDACommon.h
+++ b/src/torchcodec/_core/CUDACommon.h
@@ -13,6 +13,7 @@
 
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 #include "src/torchcodec/_core/Frame.h"
+#include "src/torchcodec/_core/SwsContext.h"
 
 extern "C" {
 #include 
@@ -48,4 +49,11 @@ void validatePreAllocatedTensorShape(
 
 int getDeviceIndex(const torch::Device& device);
 
+// Convert CPU frame to NV12 and transfer to GPU for GPU-accelerated color
+// conversion. Used during CPU fallback to move color conversion to GPU.
+UniqueAVFrame transferCpuFrameToGpuNV12(
+    UniqueAVFrame& cpuFrame,
+    SwsScaler& swsCtx,
+    const torch::Device& device);
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index bb0988a13..9a13ba5cb 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -215,35 +215,17 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
     const UniqueAVFrame& avFrame,
     torch::Tensor& outputTensor,
     const FrameDims& outputDims) {
-  enum AVPixelFormat frameFormat =
-      static_cast<AVPixelFormat>(avFrame->format);
-
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  SwsFrameContext swsFrameContext(
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      outputDims.width,
-      outputDims.height);
-
-  if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
-    swsContext_ = createSwsContext(
-        swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
-    prevSwsFrameContext_ = swsFrameContext;
-  }
+  // Get or create swscale context. The SwsScaler class manages caching
+  // and recreation logic internally based on frame properties.
+  auto swsContext = swsCtx_.getOrCreateContext(
+      avFrame, outputDims, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
 
   uint8_t* pointers[4] = {
       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
   int expectedOutputWidth = outputTensor.sizes()[1];
   int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
 
   int resultHeight = sws_scale(
-      swsContext_.get(),
+      swsContext,
       avFrame->data,
       avFrame->linesize,
       0,
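
The destination setup in convertAVFrameToTensorUsingSwScale() — one packed plane, linesize of width * 3 — is the standard way to make sws_scale() write RGB24 into caller-owned memory. A self-contained sketch of the same setup; `scaleToRgb24` is a hypothetical helper and assumes `ctx` was created for the matching input and output shapes:

```cpp
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}
#include <cstdint>
#include <vector>

// Hypothetical helper: writes packed RGB24 into a caller-owned buffer. The
// only destination plane is plane 0, and its linesize is dstW * 3 bytes
// (3 bytes per pixel, no row padding).
bool scaleToRgb24(
    SwsContext* ctx,
    const AVFrame* srcFrame,
    std::vector<uint8_t>& dst,
    int dstW,
    int dstH) {
  dst.resize(static_cast<size_t>(dstW) * dstH * 3);
  uint8_t* dstPlanes[4] = {dst.data(), nullptr, nullptr, nullptr};
  int dstLinesizes[4] = {dstW * 3, 0, 0, 0};
  int outH = sws_scale(
      ctx,
      srcFrame->data,
      srcFrame->linesize,
      0,
      srcFrame->height,
      dstPlanes,
      dstLinesizes);
  return outH == dstH;
}
```
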
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index f7c57045a..6a98971fa 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -9,6 +9,7 @@
 #include "src/torchcodec/_core/DeviceInterface.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 #include "src/torchcodec/_core/FilterGraph.h"
+#include "src/torchcodec/_core/SwsContext.h"
 
 namespace facebook::torchcodec {
 
@@ -66,26 +67,18 @@ class CpuDeviceInterface : public DeviceInterface {
   // resolutions.
   std::optional<FrameDims> resizedOutputDims_;
 
-  // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
-  // be non-null. Which one we use is determined dynamically in
+  // Color-conversion objects. Only one of filterGraph_ and swsCtx_ should
+  // be actively used. Which one we use is determined dynamically in
   // getColorConversionLibrary() each time we decode a frame.
   //
-  // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
-  // reuse them across frames. However, it is possbile that subsequent frames
+  // Creating both filterGraph_ and swsCtx_ is relatively expensive, so we
+  // reuse them across frames. However, it is possible that subsequent frames
   // are different enough (change in dimensions) that we can't reuse the color
-  // conversion object. We store the relevant frame context from the frame used
-  // to create the object last time. We always compare the current frame's info
-  // against the previous one to determine if we need to recreate the color
-  // conversion object.
-  //
-  // TODO: The names of these fields is confusing, as the actual color
-  //       conversion object for Sws has "context" in the name, and we use
-  //       "context" for the structs we store to know if we need to recreate a
-  //       color conversion object. We should clean that up.
+  // conversion object. These objects internally track the frame properties
+  // needed to determine if they need to be recreated.
   std::unique_ptr<FilterGraph> filterGraph_;
   FiltersContext prevFiltersContext_;
-  UniqueSwsContext swsContext_;
-  SwsFrameContext prevSwsFrameContext_;
+  SwsScaler swsCtx_;
 
   // The filter we supply to filterGraph_, if it is used. The default is the
   // copy filter, which just copies the input to the output. Computationally, it
diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index be45050e6..cae53be09 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -257,34 +257,26 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
   // whatever reason. Typically that happens if the video's encoder isn't
   // supported by NVDEC.
   //
-  // In both cases, we have a frame on the CPU. We send the frame back to the
-  // CUDA device when we're done.
-
-  enum AVPixelFormat frameFormat =
-      static_cast<AVPixelFormat>(avFrame->format);
-
-  FrameOutput cpuFrameOutput;
-  if (frameFormat == AV_PIX_FMT_RGB24) {
-    // Reason 1 above. The frame is already in RGB24, we just need to convert
-    // it to a tensor.
-    cpuFrameOutput.data = rgbAVFrameToTensor(avFrame);
-  } else {
-    // Reason 2 above. We need to do a full conversion which requires an
-    // actual CPU device.
-    cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
-  }
+  // In both cases, we have a frame on the CPU. Instead of doing color
+  // conversion on CPU and then transferring to GPU, we transfer to GPU first
+  // and do GPU-accelerated color conversion with NPP.
 
-  // Finally, we need to send the frame back to the GPU. Note that the
-  // pre-allocated tensor is on the GPU, so we can't send that to the CPU
-  // device interface. We copy it over here.
-  if (preAllocatedOutputTensor.has_value()) {
-    preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
-    frameOutput.data = preAllocatedOutputTensor.value();
-  } else {
-    frameOutput.data = cpuFrameOutput.data.to(device_);
-  }
+  // Transfer CPU frame to GPU as NV12 for GPU-accelerated color conversion.
+  avFrame = transferCpuFrameToGpuNV12(avFrame, swsCtx_, device_);
 
   usingCPUFallback_ = true;
+
+  // Now the frame is on GPU in NV12 format. Do GPU-accelerated color
+  // conversion with NPP.
+  TORCH_CHECK(
+      avFrame->format == AV_PIX_FMT_CUDA,
+      "Expected CUDA format frame after GPU transfer");
+
+  at::cuda::CUDAStream nvdecStream =
+      at::cuda::getCurrentCUDAStream(device_.index());
+
+  frameOutput.data = convertNV12FrameToRGB(
+      avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
   return;
 }
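
For review convenience, here is the new fallback path condensed into one place, as a hypothetical free function. It assumes `nppCtx_` is an NPP stream context (NppStreamContext); in the PR this logic lives inline in convertAVFrameToFrameOutput(), and transferCpuFrameToGpuNV12 and convertNV12FrameToRGB are the helpers named in this diff. Old flow: CPU decode, CPU swscale to RGB24, copy RGB tensor to GPU. New flow: CPU decode, CPU swscale to NV12, copy two planes to GPU, NPP NV12-to-RGB conversion on the current CUDA stream.

```cpp
#include "src/torchcodec/_core/CUDACommon.h"

#include <ATen/cuda/CUDAContext.h>
#include <optional>

namespace facebook::torchcodec {

// Hypothetical condensed view of the fallback path above.
torch::Tensor fallbackToRgbOnGpu(
    UniqueAVFrame& avFrame,
    SwsScaler& swsCtx,
    NppStreamContext& nppCtx,
    const torch::Device& device) {
  // CPU frame -> NV12 on the GPU (sws_scale + two cudaMemcpy2D calls).
  avFrame = transferCpuFrameToGpuNV12(avFrame, swsCtx, device);

  // NV12 -> RGB on the GPU, enqueued on the current stream for this device.
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device.index());
  return convertNV12FrameToRGB(
      avFrame, device, nppCtx, stream, /*preAllocatedOutputTensor=*/std::nullopt);
}

} // namespace facebook::torchcodec
```
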
diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h
index 9f171ee3c..e7f49f6b1 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.h
+++ b/src/torchcodec/_core/CudaDeviceInterface.h
@@ -9,6 +9,7 @@
 #include "src/torchcodec/_core/CUDACommon.h"
 #include "src/torchcodec/_core/DeviceInterface.h"
 #include "src/torchcodec/_core/FilterGraph.h"
+#include "src/torchcodec/_core/SwsContext.h"
 
 namespace facebook::torchcodec {
 
@@ -63,6 +64,10 @@ class CudaDeviceInterface : public DeviceInterface {
   std::unique_ptr<FiltersContext> nv12ConversionContext_;
   std::unique_ptr<FilterGraph> nv12Conversion_;
 
+  // Swscale context cache for GPU transfer during CPU fallback.
+  // Used to convert CPU frames to NV12 before transferring to GPU.
+  SwsScaler swsCtx_;
+
   bool usingCPUFallback_ = false;
 };
diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp
index b9663d8d2..c7f035ab9 100644
--- a/src/torchcodec/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/_core/FFMPEGCommon.cpp
@@ -5,6 +5,7 @@
 // LICENSE file in the root directory of this source tree.
 
 #include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/SwsContext.h"
 
 #include 
#include "src/torchcodec/_core/FFMPEGCommon.h" +#include "src/torchcodec/_core/SwsContext.h" #include @@ -605,28 +606,6 @@ int64_t computeSafeDuration( } } -SwsFrameContext::SwsFrameContext( - int inputWidth, - int inputHeight, - AVPixelFormat inputFormat, - int outputWidth, - int outputHeight) - : inputWidth(inputWidth), - inputHeight(inputHeight), - inputFormat(inputFormat), - outputWidth(outputWidth), - outputHeight(outputHeight) {} - -bool SwsFrameContext::operator==(const SwsFrameContext& other) const { - return inputWidth == other.inputWidth && inputHeight == other.inputHeight && - inputFormat == other.inputFormat && outputWidth == other.outputWidth && - outputHeight == other.outputHeight; -} - -bool SwsFrameContext::operator!=(const SwsFrameContext& other) const { - return !(*this == other); -} - UniqueSwsContext createSwsContext( const SwsFrameContext& swsFrameContext, AVColorSpace colorspace, diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 2d58abfb2..edeef234e 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -251,24 +251,8 @@ AVFilterContext* createBuffersinkFilter( AVFilterGraph* filterGraph, enum AVPixelFormat outputFormat); -struct SwsFrameContext { - int inputWidth = 0; - int inputHeight = 0; - AVPixelFormat inputFormat = AV_PIX_FMT_NONE; - int outputWidth = 0; - int outputHeight = 0; - - SwsFrameContext() = default; - SwsFrameContext( - int inputWidth, - int inputHeight, - AVPixelFormat inputFormat, - int outputWidth, - int outputHeight); - - bool operator==(const SwsFrameContext& other) const; - bool operator!=(const SwsFrameContext& other) const; -}; +// Forward declare SwsFrameContext from SwsContext.h +struct SwsFrameContext; // Utility functions for swscale context management UniqueSwsContext createSwsContext( diff --git a/src/torchcodec/_core/SwsContext.cpp b/src/torchcodec/_core/SwsContext.cpp new file mode 100644 index 000000000..25b9c0f22 --- /dev/null +++ b/src/torchcodec/_core/SwsContext.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
diff --git a/src/torchcodec/_core/SwsContext.h b/src/torchcodec/_core/SwsContext.h
new file mode 100644
index 000000000..4c146ae46
--- /dev/null
+++ b/src/torchcodec/_core/SwsContext.h
@@ -0,0 +1,60 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+extern "C" {
+#include <libswscale/swscale.h>
+}
+
+#include "src/torchcodec/_core/Frame.h"
+
+namespace facebook::torchcodec {
+
+// Context describing frame properties needed for swscale conversion.
+// Used to detect when the swscale context needs to be recreated.
+struct SwsFrameContext {
+  int inputWidth;
+  int inputHeight;
+  AVPixelFormat inputFormat;
+  int outputWidth;
+  int outputHeight;
+
+  SwsFrameContext(
+      int inputWidth,
+      int inputHeight,
+      AVPixelFormat inputFormat,
+      int outputWidth,
+      int outputHeight);
+
+  bool operator==(const SwsFrameContext& other) const;
+  bool operator!=(const SwsFrameContext& other) const;
+};
+
+// Manages swscale context creation and caching across multiple frame
+// conversions. Reuses the context as long as frame properties remain the same.
+class SwsScaler {
+ public:
+  SwsScaler() = default;
+  ~SwsScaler() = default;
+
+  // Get or create a swscale context for the given frame and output dimensions.
+  // Reuses the cached context if frame properties haven't changed.
+  // Returns a raw pointer to the internal swscale context. The pointer remains
+  // valid until the next call to getOrCreateContext() or until this SwsScaler
+  // is destroyed, whichever comes first.
+  SwsContext* getOrCreateContext(
+      const UniqueAVFrame& avFrame,
+      const FrameDims& outputDims,
+      AVColorSpace colorspace,
+      AVPixelFormat outputFormat,
+      int swsFlags = SWS_BILINEAR);
+
+ private:
+  UniqueSwsContext swsContext_;
+  SwsFrameContext prevFrameContext_ =
+      SwsFrameContext(0, 0, AV_PIX_FMT_NONE, 0, 0);
+};
+
+} // namespace facebook::torchcodec
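
Finally, the caching contract that SwsScaler documents, as a test-style sketch. Note that the cache key (SwsFrameContext) covers input dimensions, input pixel format, and output dimensions only; colorspace, output format, and flags are not part of the key, so each call site keeps its own SwsScaler per conversion target, as the PR does. `cacheBehavesAsDocumented` is hypothetical, and frameB is assumed to differ from frameA in dimensions or pixel format.

```cpp
#include "src/torchcodec/_core/SwsContext.h"

namespace facebook::torchcodec {

// Hypothetical check of the caching contract: identical frame properties
// return the same context pointer; a change in any SwsFrameContext field
// triggers recreation.
bool cacheBehavesAsDocumented(UniqueAVFrame& frameA, UniqueAVFrame& frameB) {
  SwsScaler scaler;
  FrameDims dims(frameA->height, frameA->width);

  SwsContext* first = scaler.getOrCreateContext(
      frameA, dims, frameA->colorspace, AV_PIX_FMT_RGB24, SWS_BILINEAR);
  // Same frame, same output dims: cache hit, same pointer.
  SwsContext* second = scaler.getOrCreateContext(
      frameA, dims, frameA->colorspace, AV_PIX_FMT_RGB24, SWS_BILINEAR);
  // Different input properties: the old context is dropped and recreated;
  // `first` must no longer be used after this call.
  SwsContext* third = scaler.getOrCreateContext(
      frameB, dims, frameB->colorspace, AV_PIX_FMT_RGB24, SWS_BILINEAR);

  return first == second && third != nullptr;
}

} // namespace facebook::torchcodec
```
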