Skip to content

Commit afd5aba

Browse files
authored
Proper resize tests; remove swscale resize (#1013)
1 parent c552b60 commit afd5aba

13 files changed

+248
-154
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ void CpuDeviceInterface::initializeVideo(
4747
// We calculate this value during initialization but we don't refer to it until
4848
// getColorConversionLibrary() is called. Calculating this value during
4949
// initialization saves us from having to save all of the transforms.
50-
areTransformsSwScaleCompatible_ = transforms.empty() ||
51-
(transforms.size() == 1 && transforms[0]->isResize());
50+
areTransformsSwScaleCompatible_ = transforms.empty();
5251

5352
// Note that we do not expose this capability in the public API, only through
5453
// the core API.
@@ -58,16 +57,6 @@ void CpuDeviceInterface::initializeVideo(
5857
userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
5958
ColorConversionLibrary::SWSCALE;
6059

61-
// We can only use swscale when we have a single resize transform. Note that
62-
// we actually decide on whether or not to actually use swscale at the last
63-
// possible moment, when we actually convert the frame. This is because we
64-
// need to know the actual frame dimensions.
65-
if (transforms.size() == 1 && transforms[0]->isResize()) {
66-
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
67-
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
68-
swsFlags_ = resize->getSwsFlags();
69-
}
70-
7160
// If we have any transforms, replace filters_ with the filter strings from
7261
// the transforms. As noted above, we decide between swscale and filtergraph
7362
// when we actually decode a frame.
@@ -81,7 +70,18 @@ void CpuDeviceInterface::initializeVideo(
8170
first = false;
8271
}
8372
if (!transforms.empty()) {
84-
filters_ = filters.str();
73+
// Note [Transform and Format Conversion Order]
74+
// We have to ensure that all user filters happen AFTER the explicit format
75+
// conversion. That is, we want the filters to be applied in RGB24, not the
76+
// pixel format of the input frame.
77+
//
78+
// The output frame will always be in RGB24, as we specify the sink node with
79+
// AV_PIX_FMT_RGB24. Filtergraph will automatically insert a format
80+
// conversion to ensure the output frame matches the pixel format
81+
// specified in the sink. But by default, it will insert it after the user
82+
// filters. We need an explicit format conversion to get the behavior we
83+
// want.
84+
filters_ = "format=rgb24," + filters.str();
8585
}
8686

8787
initialized_ = true;
@@ -238,6 +238,11 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
238238
enum AVPixelFormat frameFormat =
239239
static_cast<enum AVPixelFormat>(avFrame->format);
240240

241+
TORCH_CHECK(
242+
avFrame->height == outputDims.height &&
243+
avFrame->width == outputDims.width,
244+
"Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
245+
241246
// We need to compare the current frame context with our previous frame
242247
// context. If they are different, then we need to re-create our colorspace
243248
// conversion objects. We create our colorspace conversion objects late so
@@ -254,7 +259,16 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
254259

255260
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
256261
swsContext_ = createSwsContext(
257-
swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
262+
swsFrameContext,
263+
avFrame->colorspace,
264+
265+
// See [Transform and Format Conversion Order] for more on the output
266+
// pixel format.
267+
/*outputFormat=*/AV_PIX_FMT_RGB24,
268+
269+
// We don't set any flags because we don't yet use sws_scale() for
270+
// resizing.
271+
/*swsFlags=*/0);
258272
prevSwsFrameContext_ = swsFrameContext;
259273
}
260274

@@ -276,17 +290,17 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
276290
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
277291
const UniqueAVFrame& avFrame,
278292
const FrameDims& outputDims) {
279-
enum AVPixelFormat frameFormat =
293+
enum AVPixelFormat avFrameFormat =
280294
static_cast<enum AVPixelFormat>(avFrame->format);
281295

282296
FiltersContext filtersContext(
283297
avFrame->width,
284298
avFrame->height,
285-
frameFormat,
299+
avFrameFormat,
286300
avFrame->sample_aspect_ratio,
287301
outputDims.width,
288302
outputDims.height,
289-
AV_PIX_FMT_RGB24,
303+
/*outputFormat=*/AV_PIX_FMT_RGB24,
290304
filters_,
291305
timeBase_);
292306

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,20 @@ class CpuDeviceInterface : public DeviceInterface {
100100
UniqueSwsContext swsContext_;
101101
SwsFrameContext prevSwsFrameContext_;
102102

103-
// The filter we supply to filterGraph_, if it is used. The default is the
104-
// copy filter, which just copies the input to the output. Computationally, it
105-
// should be a no-op. If we get no user-provided transforms, we will use the
106-
// copy filter. Otherwise, we will construct the string from the transforms.
103+
// We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
104+
// of what FFmpeg calls "filters" to apply to decoded frames before returning
105+
// them. In the PyTorch ecosystem, we call these "transforms". During
106+
// initialization, we convert the user-supplied transforms into this string of
107+
// filters.
107108
//
108-
// Note that even if we only use the copy filter, we still get the desired
109-
// colorspace conversion. We construct the filtergraph with its output sink
110-
// set to RGB24.
109+
// Note that if there are no user-supplied transforms, then the default filter
110+
// we use is the copy filter, which is just an identity: it emits the output
111+
// frame unchanged. We supply such a filter because we can't supply just the
112+
// empty-string; we must supply SOME filter.
113+
//
114+
// See also [Transform and Format Conversion Order] for more on filters.
111115
std::string filters_ = "copy";
112116

113-
// The flags we supply to swsContext_, if it used. The flags control the
114-
// resizing algorithm. We default to bilinear. Users can override this with a
115-
// ResizeTransform.
116-
int swsFlags_ = SWS_BILINEAR;
117-
118117
// Values set during initialization and referred to in
119118
// getColorConversionLibrary().
120119
bool areTransformsSwScaleCompatible_;

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -399,68 +399,65 @@ SwrContext* createSwrContext(
399399
return swrContext;
400400
}
401401

402-
AVFilterContext* createBuffersinkFilter(
402+
AVFilterContext* createAVFilterContextWithOptions(
403403
AVFilterGraph* filterGraph,
404-
enum AVPixelFormat outputFormat) {
405-
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
406-
TORCH_CHECK(buffersink != nullptr, "Failed to get buffersink filter.");
407-
408-
AVFilterContext* sinkContext = nullptr;
409-
int status;
404+
const AVFilter* buffer,
405+
const enum AVPixelFormat outputFormat) {
406+
AVFilterContext* avFilterContext = nullptr;
410407
const char* filterName = "out";
411408

412-
enum AVPixelFormat pix_fmts[] = {outputFormat, AV_PIX_FMT_NONE};
409+
enum AVPixelFormat pixFmts[] = {outputFormat, AV_PIX_FMT_NONE};
413410

414411
// av_opt_set_int_list was replaced by av_opt_set_array() in FFmpeg 8.
415412
#if LIBAVUTIL_VERSION_MAJOR >= 60 // FFmpeg >= 8
416413
// Output options like pixel_formats must be set before filter init
417-
sinkContext =
418-
avfilter_graph_alloc_filter(filterGraph, buffersink, filterName);
414+
avFilterContext =
415+
avfilter_graph_alloc_filter(filterGraph, buffer, filterName);
419416
TORCH_CHECK(
420-
sinkContext != nullptr, "Failed to allocate buffersink filter context.");
417+
avFilterContext != nullptr, "Failed to allocate buffer filter context.");
421418

422419
// When setting pix_fmts, only the first element is used, so nb_elems = 1
423420
// AV_PIX_FMT_NONE acts as a terminator for the array in av_opt_set_int_list
424-
status = av_opt_set_array(
425-
sinkContext,
421+
int status = av_opt_set_array(
422+
avFilterContext,
426423
"pixel_formats",
427424
AV_OPT_SEARCH_CHILDREN,
428425
0, // start_elem
429426
1, // nb_elems
430427
AV_OPT_TYPE_PIXEL_FMT,
431-
pix_fmts);
428+
pixFmts);
432429
TORCH_CHECK(
433430
status >= 0,
434-
"Failed to set pixel format for buffersink filter: ",
431+
"Failed to set pixel format for buffer filter: ",
435432
getFFMPEGErrorStringFromErrorCode(status));
436433

437-
status = avfilter_init_str(sinkContext, nullptr);
434+
status = avfilter_init_str(avFilterContext, nullptr);
438435
TORCH_CHECK(
439436
status >= 0,
440-
"Failed to initialize buffersink filter: ",
437+
"Failed to initialize buffer filter: ",
441438
getFFMPEGErrorStringFromErrorCode(status));
442439
#else // FFmpeg <= 7
443440
// For older FFmpeg versions, create filter and then set options
444-
status = avfilter_graph_create_filter(
445-
&sinkContext, buffersink, filterName, nullptr, nullptr, filterGraph);
441+
int status = avfilter_graph_create_filter(
442+
&avFilterContext, buffer, filterName, nullptr, nullptr, filterGraph);
446443
TORCH_CHECK(
447444
status >= 0,
448-
"Failed to create buffersink filter: ",
445+
"Failed to create buffer filter: ",
449446
getFFMPEGErrorStringFromErrorCode(status));
450447

451448
status = av_opt_set_int_list(
452-
sinkContext,
449+
avFilterContext,
453450
"pix_fmts",
454-
pix_fmts,
451+
pixFmts,
455452
AV_PIX_FMT_NONE,
456453
AV_OPT_SEARCH_CHILDREN);
457454
TORCH_CHECK(
458455
status >= 0,
459-
"Failed to set pixel formats for buffersink filter: ",
456+
"Failed to set pixel formats for buffer filter: ",
460457
getFFMPEGErrorStringFromErrorCode(status));
461458
#endif
462459

463-
return sinkContext;
460+
return avFilterContext;
464461
}
465462

466463
UniqueAVFrame convertAudioAVFrameSamples(

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,10 @@ int64_t computeSafeDuration(
247247
const AVRational& frameRate,
248248
const AVRational& timeBase);
249249

250-
AVFilterContext* createBuffersinkFilter(
250+
AVFilterContext* createAVFilterContextWithOptions(
251251
AVFilterGraph* filterGraph,
252-
enum AVPixelFormat outputFormat);
252+
const AVFilter* buffer,
253+
const enum AVPixelFormat outputFormat);
253254

254255
struct SwsFrameContext {
255256
int inputWidth = 0;
@@ -274,7 +275,7 @@ struct SwsFrameContext {
274275
UniqueSwsContext createSwsContext(
275276
const SwsFrameContext& swsFrameContext,
276277
AVColorSpace colorspace,
277-
AVPixelFormat outputFormat = AV_PIX_FMT_RGB24,
278-
int swsFlags = SWS_BILINEAR);
278+
AVPixelFormat outputFormat,
279+
int swsFlags);
279280

280281
} // namespace facebook::torchcodec

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ FilterGraph::FilterGraph(
6363
filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
6464
}
6565

66-
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
67-
66+
// Configure the source context.
67+
const AVFilter* bufferSrc = avfilter_get_by_name("buffer");
6868
UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
6969
TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
7070

@@ -78,7 +78,7 @@ FilterGraph::FilterGraph(
7878
}
7979

8080
sourceContext_ =
81-
avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
81+
avfilter_graph_alloc_filter(filterGraph_.get(), bufferSrc, "in");
8282
TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");
8383

8484
int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
@@ -93,23 +93,31 @@ FilterGraph::FilterGraph(
9393
"Failed to create filter graph : ",
9494
getFFMPEGErrorStringFromErrorCode(status));
9595

96-
sinkContext_ =
97-
createBuffersinkFilter(filterGraph_.get(), filtersContext.outputFormat);
96+
// Configure the sink context.
97+
const AVFilter* bufferSink = avfilter_get_by_name("buffersink");
98+
TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter.");
99+
100+
sinkContext_ = createAVFilterContextWithOptions(
101+
filterGraph_.get(), bufferSink, filtersContext.outputFormat);
98102
TORCH_CHECK(
99103
sinkContext_ != nullptr, "Failed to create and configure buffersink");
100104

105+
// Create the filtergraph nodes based on the source and sink contexts.
101106
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
102-
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
103-
104107
outputs->name = av_strdup("in");
105108
outputs->filter_ctx = sourceContext_;
106109
outputs->pad_idx = 0;
107110
outputs->next = nullptr;
111+
112+
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
108113
inputs->name = av_strdup("out");
109114
inputs->filter_ctx = sinkContext_;
110115
inputs->pad_idx = 0;
111116
inputs->next = nullptr;
112117

118+
// Create the filtergraph specified by the filtergraph string in the context
119+
// of the inputs and outputs. Note the dance we have to do with release and
120+
// resetting the output and input nodes because FFmpeg modifies them in place.
113121
AVFilterInOut* outputsTmp = outputs.release();
114122
AVFilterInOut* inputsTmp = inputs.release();
115123
status = avfilter_graph_parse_ptr(
@@ -126,6 +134,7 @@ FilterGraph::FilterGraph(
126134
getFFMPEGErrorStringFromErrorCode(status),
127135
", provided filters: " + filtersContext.filtergraphStr);
128136

137+
// Check filtergraph validity and configure links and formats.
129138
status = avfilter_graph_config(filterGraph_.get(), nullptr);
130139
TORCH_CHECK(
131140
status >= 0,

src/torchcodec/_core/Transform.cpp

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,18 @@ std::string toFilterGraphInterpolation(
2525
}
2626
}
2727

28-
int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
29-
switch (mode) {
30-
case ResizeTransform::InterpolationMode::BILINEAR:
31-
return SWS_BILINEAR;
32-
default:
33-
TORCH_CHECK(
34-
false,
35-
"Unknown interpolation mode: " +
36-
std::to_string(static_cast<int>(mode)));
37-
}
38-
}
39-
4028
} // namespace
4129

4230
std::string ResizeTransform::getFilterGraphCpu() const {
4331
return "scale=" + std::to_string(outputDims_.width) + ":" +
4432
std::to_string(outputDims_.height) +
45-
":sws_flags=" + toFilterGraphInterpolation(interpolationMode_);
33+
":flags=" + toFilterGraphInterpolation(interpolationMode_);
4634
}
4735

4836
std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
4937
return outputDims_;
5038
}
5139

52-
bool ResizeTransform::isResize() const {
53-
return true;
54-
}
55-
56-
int ResizeTransform::getSwsFlags() const {
57-
return toSwsInterpolation(interpolationMode_);
58-
}
59-
6040
CropTransform::CropTransform(const FrameDims& dims, int x, int y)
6141
: outputDims_(dims), x_(x), y_(y) {
6242
TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);

src/torchcodec/_core/Transform.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ class Transform {
2929
return std::nullopt;
3030
}
3131

32-
// The ResizeTransform is special, because it is the only transform that
33-
// swscale can handle.
34-
virtual bool isResize() const {
35-
return false;
36-
}
37-
3832
// The validity of some transforms depends on the characteristics of the
3933
// AVStream they're being applied to. For example, some transforms will
4034
// specify coordinates inside a frame, we need to validate that those are
@@ -58,9 +52,6 @@ class ResizeTransform : public Transform {
5852

5953
std::string getFilterGraphCpu() const override;
6054
std::optional<FrameDims> getOutputFrameDims() const override;
61-
bool isResize() const override;
62-
63-
int getSwsFlags() const;
6455

6556
private:
6657
FrameDims outputDims_;

0 commit comments

Comments
 (0)