48 changes: 31 additions & 17 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -46,8 +46,7 @@ void CpuDeviceInterface::initializeVideo(
// We calculate this value during initialization but we don't refer to it until
// getColorConversionLibrary() is called. Calculating this value during
// initialization spares us from having to store all of the transforms.
areTransformsSwScaleCompatible_ = transforms.empty() ||
(transforms.size() == 1 && transforms[0]->isResize());
areTransformsSwScaleCompatible_ = transforms.empty();

// Note that we do not expose this capability in the public API, only through
// the core API.
@@ -57,16 +56,6 @@ void CpuDeviceInterface::initializeVideo(
userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
ColorConversionLibrary::SWSCALE;

// We can only use swscale when we have a single resize transform. Note that
// we actually decide on whether or not to actually use swscale at the last
// possible moment, when we actually convert the frame. This is because we
// need to know the actual frame dimensions.
if (transforms.size() == 1 && transforms[0]->isResize()) {
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
swsFlags_ = resize->getSwsFlags();
}

// If we have any transforms, replace filters_ with the filter strings from
// the transforms. As noted above, we decide between swscale and filtergraph
// when we actually decode a frame.
@@ -80,7 +69,18 @@ void CpuDeviceInterface::initializeVideo(
first = false;
}
if (!transforms.empty()) {
filters_ = filters.str();
// Note [Transform and Format Conversion Order]
// We have to ensure that all user filters happen AFTER the explicit format
// conversion. That is, we want the filters to be applied in RGB24, not the
// pixel format of the input frame.
//
// The output frame will always be in RGB24, as we specify the sink node with
// AV_PIX_FMT_RGB24. Filtergraph will automatically insert a format
// conversion to ensure the output frame matches the pixel format
// specified in the sink. But by default, it will insert it after the user
// filters. We need an explicit format conversion to get the behavior we
// want.
filters_ = "format=rgb24," + filters.str();
}

initialized_ = true;
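To make the ordering concrete, here is a minimal sketch (not part of the diff; buildFilterString is a hypothetical helper) of how the filter string is assembled so that user filters always run after the RGB24 conversion:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper mirroring the loop above: join the per-transform
// filter strings, then prepend an explicit RGB24 format conversion.
std::string buildFilterString(const std::vector<std::string>& transformFilters) {
  std::ostringstream filters;
  bool first = true;
  for (const auto& filter : transformFilters) {
    if (!first) {
      filters << ",";
    }
    filters << filter;
    first = false;
  }
  // See Note [Transform and Format Conversion Order]: the explicit format
  // filter guarantees user filters operate on RGB24 frames.
  return "format=rgb24," + filters.str();
}

// buildFilterString({"crop=100:100:0:0"}) yields "format=rgb24,crop=100:100:0:0"
```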
@@ -218,6 +218,11 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

TORCH_CHECK(
avFrame->height == outputDims.height &&
avFrame->width == outputDims.width,
"Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");

// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
@@ -234,7 +239,16 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(

if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
swsContext_ = createSwsContext(
swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
swsFrameContext,
avFrame->colorspace,

// See [Transform and Format Conversion Order] for more on the output
// pixel format.
/*outputFormat=*/AV_PIX_FMT_RGB24,

// We don't set any flags because we don't yet use sws_scale() for
// resizing.
/*swsFlags=*/0);
prevSwsFrameContext_ = swsFrameContext;
}

@@ -256,17 +270,17 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame,
const FrameDims& outputDims) {
enum AVPixelFormat frameFormat =
enum AVPixelFormat avFrameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

FiltersContext filtersContext(
avFrame->width,
avFrame->height,
frameFormat,
avFrameFormat,
avFrame->sample_aspect_ratio,
outputDims.width,
outputDims.height,
AV_PIX_FMT_RGB24,
/*outputFormat=*/AV_PIX_FMT_RGB24,
filters_,
timeBase_);

23 changes: 11 additions & 12 deletions src/torchcodec/_core/CpuDeviceInterface.h
@@ -87,21 +87,20 @@ class CpuDeviceInterface : public DeviceInterface {
UniqueSwsContext swsContext_;
SwsFrameContext prevSwsFrameContext_;

// The filter we supply to filterGraph_, if it is used. The default is the
// copy filter, which just copies the input to the output. Computationally, it
// should be a no-op. If we get no user-provided transforms, we will use the
// copy filter. Otherwise, we will construct the string from the transforms.
// We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
// of what FFmpeg calls "filters" to apply to decoded frames before returning
// them. In the PyTorch ecosystem, we call these "transforms". During
// initialization, we convert the user-supplied transforms into this string of
// filters.
//
// Note that even if we only use the copy filter, we still get the desired
// colorspace conversion. We construct the filtergraph with its output sink
// set to RGB24.
// Note that if there are no user-supplied transforms, then the default filter
// we use is the copy filter, which is just an identity: it passes frames
// through unchanged. We supply such a filter because we can't supply the
// empty string; we must supply SOME filter.
//
// See also [Transform and Format Conversion Order] for more on filters.
std::string filters_ = "copy";

// The flags we supply to swsContext_, if it used. The flags control the
// resizing algorithm. We default to bilinear. Users can override this with a
// ResizeTransform.
int swsFlags_ = SWS_BILINEAR;

// Values set during initialization and referred to in
// getColorConversionLibrary().
bool areTransformsSwScaleCompatible_;
45 changes: 21 additions & 24 deletions src/torchcodec/_core/FFMPEGCommon.cpp
@@ -399,68 +399,65 @@ SwrContext* createSwrContext(
return swrContext;
}

AVFilterContext* createBuffersinkFilter(
AVFilterContext* createAVFilterContextWithOptions(
AVFilterGraph* filterGraph,
enum AVPixelFormat outputFormat) {
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
TORCH_CHECK(buffersink != nullptr, "Failed to get buffersink filter.");

AVFilterContext* sinkContext = nullptr;
int status;
const AVFilter* buffer,
const enum AVPixelFormat outputFormat) {
AVFilterContext* avFilterContext = nullptr;
const char* filterName = "out";

enum AVPixelFormat pix_fmts[] = {outputFormat, AV_PIX_FMT_NONE};
enum AVPixelFormat pixFmts[] = {outputFormat, AV_PIX_FMT_NONE};

// av_opt_set_int_list was replaced by av_opt_set_array() in FFmpeg 8.
#if LIBAVUTIL_VERSION_MAJOR >= 60 // FFmpeg >= 8
// Output options like pixel_formats must be set before filter init
sinkContext =
avfilter_graph_alloc_filter(filterGraph, buffersink, filterName);
avFilterContext =
avfilter_graph_alloc_filter(filterGraph, buffer, filterName);
TORCH_CHECK(
sinkContext != nullptr, "Failed to allocate buffersink filter context.");
avFilterContext != nullptr, "Failed to allocate buffer filter context.");

// When setting pix_fmts, only the first element is used, so nb_elems = 1
// AV_PIX_FMT_NONE acts as a terminator for the array in av_opt_set_int_list
status = av_opt_set_array(
sinkContext,
int status = av_opt_set_array(
avFilterContext,
"pixel_formats",
AV_OPT_SEARCH_CHILDREN,
0, // start_elem
1, // nb_elems
AV_OPT_TYPE_PIXEL_FMT,
pix_fmts);
pixFmts);
TORCH_CHECK(
status >= 0,
"Failed to set pixel format for buffersink filter: ",
"Failed to set pixel format for buffer filter: ",
getFFMPEGErrorStringFromErrorCode(status));

status = avfilter_init_str(sinkContext, nullptr);
status = avfilter_init_str(avFilterContext, nullptr);
TORCH_CHECK(
status >= 0,
"Failed to initialize buffersink filter: ",
"Failed to initialize buffer filter: ",
getFFMPEGErrorStringFromErrorCode(status));
#else // FFmpeg <= 7
// For older FFmpeg versions, create filter and then set options
status = avfilter_graph_create_filter(
&sinkContext, buffersink, filterName, nullptr, nullptr, filterGraph);
int status = avfilter_graph_create_filter(
&avFilterContext, buffer, filterName, nullptr, nullptr, filterGraph);
TORCH_CHECK(
status >= 0,
"Failed to create buffersink filter: ",
"Failed to create buffer filter: ",
getFFMPEGErrorStringFromErrorCode(status));

status = av_opt_set_int_list(
sinkContext,
avFilterContext,
"pix_fmts",
pix_fmts,
pixFmts,
AV_PIX_FMT_NONE,
AV_OPT_SEARCH_CHILDREN);
TORCH_CHECK(
status >= 0,
"Failed to set pixel formats for buffersink filter: ",
"Failed to set pixel formats for buffer filter: ",
getFFMPEGErrorStringFromErrorCode(status));
#endif

return sinkContext;
return avFilterContext;
}
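A hedged usage sketch of the generalized helper (mirroring the buffersink call site in FilterGraph.cpp below; filterGraph is assumed to be a valid AVFilterGraph*):

```cpp
// Look up the buffersink filter, then allocate and configure its context
// with the desired output pixel format, independent of FFmpeg version.
const AVFilter* bufferSink = avfilter_get_by_name("buffersink");
TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter.");

AVFilterContext* sinkContext = createAVFilterContextWithOptions(
    filterGraph, bufferSink, /*outputFormat=*/AV_PIX_FMT_RGB24);
```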

UniqueAVFrame convertAudioAVFrameSamples(
9 changes: 5 additions & 4 deletions src/torchcodec/_core/FFMPEGCommon.h
@@ -247,9 +247,10 @@ int64_t computeSafeDuration(
const AVRational& frameRate,
const AVRational& timeBase);

AVFilterContext* createBuffersinkFilter(
AVFilterContext* createAVFilterContextWithOptions(
AVFilterGraph* filterGraph,
enum AVPixelFormat outputFormat);
const AVFilter* buffer,
const enum AVPixelFormat outputFormat);

struct SwsFrameContext {
int inputWidth = 0;
@@ -274,7 +275,7 @@ struct SwsFrameContext {
UniqueSwsContext createSwsContext(
const SwsFrameContext& swsFrameContext,
AVColorSpace colorspace,
AVPixelFormat outputFormat = AV_PIX_FMT_RGB24,
int swsFlags = SWS_BILINEAR);
AVPixelFormat outputFormat,
int swsFlags);

} // namespace facebook::torchcodec
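With the default arguments removed, every createSwsContext() call site must now pass the output format and swscale flags explicitly. A minimal sketch of a call, matching the arguments used in CpuDeviceInterface.cpp above:

```cpp
UniqueSwsContext swsContext = createSwsContext(
    swsFrameContext,
    avFrame->colorspace,
    /*outputFormat=*/AV_PIX_FMT_RGB24,
    /*swsFlags=*/0); // No flags: sws_scale() is not yet used for resizing.
```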
23 changes: 16 additions & 7 deletions src/torchcodec/_core/FilterGraph.cpp
@@ -63,8 +63,8 @@ FilterGraph::FilterGraph(
filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
}

const AVFilter* buffersrc = avfilter_get_by_name("buffer");

// Configure the source context.
const AVFilter* bufferSrc = avfilter_get_by_name("buffer");
UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");

@@ -78,7 +78,7 @@ FilterGraph::FilterGraph(
}

sourceContext_ =
avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
avfilter_graph_alloc_filter(filterGraph_.get(), bufferSrc, "in");
TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");

int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
@@ -93,23 +93,31 @@ FilterGraph::FilterGraph(
"Failed to create filter graph : ",
getFFMPEGErrorStringFromErrorCode(status));

sinkContext_ =
createBuffersinkFilter(filterGraph_.get(), filtersContext.outputFormat);
// Configure the sink context.
const AVFilter* bufferSink = avfilter_get_by_name("buffersink");
TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter.");

sinkContext_ = createAVFilterContextWithOptions(
filterGraph_.get(), bufferSink, filtersContext.outputFormat);
TORCH_CHECK(
sinkContext_ != nullptr, "Failed to create and configure buffersink");

// Create the filtergraph nodes based on the source and sink contexts.
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
UniqueAVFilterInOut inputs(avfilter_inout_alloc());

outputs->name = av_strdup("in");
outputs->filter_ctx = sourceContext_;
outputs->pad_idx = 0;
outputs->next = nullptr;

UniqueAVFilterInOut inputs(avfilter_inout_alloc());
inputs->name = av_strdup("out");
inputs->filter_ctx = sinkContext_;
inputs->pad_idx = 0;
inputs->next = nullptr;

// Create the filtergraph specified by the filtergraph string in the context
// of the inputs and outputs. Note the dance we have to do with releasing and
// resetting the output and input nodes because FFmpeg modifies them in place.
AVFilterInOut* outputsTmp = outputs.release();
AVFilterInOut* inputsTmp = inputs.release();
status = avfilter_graph_parse_ptr(
@@ -126,6 +134,7 @@ FilterGraph::FilterGraph(
getFFMPEGErrorStringFromErrorCode(status),
", provided filters: " + filtersContext.filtergraphStr);

// Check filtergraph validity and configure links and formats.
status = avfilter_graph_config(filterGraph_.get(), nullptr);
TORCH_CHECK(
status >= 0,
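The hunk above is truncated; here is a sketch of the complete release-and-reset dance around avfilter_graph_parse_ptr(), inferred from the surrounding context rather than the verbatim file contents:

```cpp
// FFmpeg walks and rewrites the inout lists in place, so we release
// ownership before the call and reclaim whatever remains afterwards.
AVFilterInOut* outputsTmp = outputs.release();
AVFilterInOut* inputsTmp = inputs.release();
status = avfilter_graph_parse_ptr(
    filterGraph_.get(),
    filtersContext.filtergraphStr.c_str(),
    &inputsTmp,
    &outputsTmp,
    nullptr);
outputs.reset(outputsTmp);
inputs.reset(inputsTmp);
```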
22 changes: 1 addition & 21 deletions src/torchcodec/_core/Transform.cpp
@@ -25,38 +25,18 @@ std::string toFilterGraphInterpolation(
}
}

int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
switch (mode) {
case ResizeTransform::InterpolationMode::BILINEAR:
return SWS_BILINEAR;
default:
TORCH_CHECK(
false,
"Unknown interpolation mode: " +
std::to_string(static_cast<int>(mode)));
}
}

} // namespace

std::string ResizeTransform::getFilterGraphCpu() const {
return "scale=" + std::to_string(outputDims_.width) + ":" +
std::to_string(outputDims_.height) +
":sws_flags=" + toFilterGraphInterpolation(interpolationMode_);
":flags=" + toFilterGraphInterpolation(interpolationMode_);
Review comment (Contributor Author):

From the FFmpeg docs:

> Libavfilter will automatically insert scale filters where format conversion is required. It is possible to specify swscale flags for those automatically inserted scalers by prepending sws_flags=flags; to the filtergraph description.

Whereas flags is a parameter specific to the scale filter. The two end up being semantically equivalent, but it's clearer to use the scale option here.
}
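As the review comment above notes, the per-filter flags option and the graph-wide sws_flags prefix are semantically equivalent. A small illustrative sketch of the two spellings (values are hypothetical):

```cpp
#include <string>

// Per-filter option on scale (the form used by getFilterGraphCpu() above):
const std::string perFilterFlags = "scale=640:480:flags=bilinear";

// Graph-wide prefix that configures automatically inserted scale filters:
const std::string graphWideFlags = "sws_flags=bilinear;scale=640:480";
```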

std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
return outputDims_;
}

bool ResizeTransform::isResize() const {
return true;
}

int ResizeTransform::getSwsFlags() const {
return toSwsInterpolation(interpolationMode_);
}

CropTransform::CropTransform(const FrameDims& dims, int x, int y)
: outputDims_(dims), x_(x), y_(y) {
TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
9 changes: 0 additions & 9 deletions src/torchcodec/_core/Transform.h
@@ -29,12 +29,6 @@ class Transform {
return std::nullopt;
}

// The ResizeTransform is special, because it is the only transform that
// swscale can handle.
virtual bool isResize() const {
return false;
}

// The validity of some transforms depends on the characteristics of the
// AVStream they're being applied to. For example, some transforms will
// specify coordinates inside a frame, we need to validate that those are
Expand All @@ -58,9 +52,6 @@ class ResizeTransform : public Transform {

std::string getFilterGraphCpu() const override;
std::optional<FrameDims> getOutputFrameDims() const override;
bool isResize() const override;

int getSwsFlags() const;

private:
FrameDims outputDims_;