66
77#include " src/torchcodec/_core/CpuDeviceInterface.h"
88
9- extern " C" {
10- #include < libavfilter/buffersink.h>
11- #include < libavfilter/buffersrc.h>
12- }
13-
149namespace facebook ::torchcodec {
1510namespace {
1611
@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(
2015
2116} // namespace
2217
23- bool CpuDeviceInterface::DecodedFrameContext::operator ==(
24- const CpuDeviceInterface::DecodedFrameContext& other) {
25- return decodedWidth == other.decodedWidth &&
26- decodedHeight == other.decodedHeight &&
27- decodedFormat == other.decodedFormat &&
28- expectedWidth == other.expectedWidth &&
29- expectedHeight == other.expectedHeight ;
18+ bool CpuDeviceInterface::SwsFrameContext::operator ==(
19+ const CpuDeviceInterface::SwsFrameContext& other) const {
20+ return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
21+ inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
22+ outputHeight == other.outputHeight ;
3023}
3124
32- bool CpuDeviceInterface::DecodedFrameContext ::operator !=(
33- const CpuDeviceInterface::DecodedFrameContext & other) {
25+ bool CpuDeviceInterface::SwsFrameContext ::operator !=(
26+ const CpuDeviceInterface::SwsFrameContext & other) const {
3427 return !(*this == other);
3528}
3629
@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
7568 }
7669
7770 torch::Tensor outputTensor;
78- // We need to compare the current frame context with our previous frame
79- // context. If they are different, then we need to re-create our colorspace
80- // conversion objects. We create our colorspace conversion objects late so
81- // that we don't have to depend on the unreliable metadata in the header.
82- // And we sometimes re-create them because it's possible for frame
83- // resolution to change mid-stream. Finally, we want to reuse the colorspace
84- // conversion objects as much as possible for performance reasons.
8571 enum AVPixelFormat frameFormat =
8672 static_cast <enum AVPixelFormat>(avFrame->format );
87- auto frameContext = DecodedFrameContext{
88- avFrame->width ,
89- avFrame->height ,
90- frameFormat,
91- avFrame->sample_aspect_ratio ,
92- expectedOutputWidth,
93- expectedOutputHeight};
9473
9574 // By default, we want to use swscale for color conversion because it is
9675 // faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
11190 videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
11291
11392 if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
93+ // We need to compare the current frame context with our previous frame
94+ // context. If they are different, then we need to re-create our colorspace
95+ // conversion objects. We create our colorspace conversion objects late so
96+ // that we don't have to depend on the unreliable metadata in the header.
97+ // And we sometimes re-create them because it's possible for frame
98+ // resolution to change mid-stream. Finally, we want to reuse the colorspace
99+ // conversion objects as much as possible for performance reasons.
100+ SwsFrameContext swsFrameContext;
101+
102+ swsFrameContext.inputWidth = avFrame->width ;
103+ swsFrameContext.inputHeight = avFrame->height ;
104+ swsFrameContext.inputFormat = frameFormat;
105+ swsFrameContext.outputWidth = expectedOutputWidth;
106+ swsFrameContext.outputHeight = expectedOutputHeight;
107+
114108 outputTensor = preAllocatedOutputTensor.value_or (allocateEmptyHWCTensor (
115109 expectedOutputHeight, expectedOutputWidth, torch::kCPU ));
116110
117- if (!swsContext_ || prevFrameContext_ != frameContext ) {
118- createSwsContext (frameContext , avFrame->colorspace );
119- prevFrameContext_ = frameContext ;
111+ if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext ) {
112+ createSwsContext (swsFrameContext , avFrame->colorspace );
113+ prevSwsFrameContext_ = swsFrameContext ;
120114 }
121115 int resultHeight =
122116 convertAVFrameToTensorUsingSwsScale (avFrame, outputTensor);
@@ -132,9 +126,29 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
132126
133127 frameOutput.data = outputTensor;
134128 } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
135- if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
136- createFilterGraph (frameContext, videoStreamOptions, timeBase);
137- prevFrameContext_ = frameContext;
    // See the comment above in the swscale branch about why we create the
    // colorspace conversion objects late and cache them; the same applies
    // to filterGraphContext_.
131+ FiltersContext filtersContext;
132+
133+ filtersContext.inputWidth = avFrame->width ;
134+ filtersContext.inputHeight = avFrame->height ;
135+ filtersContext.inputFormat = frameFormat;
136+ filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio ;
137+ filtersContext.outputWidth = expectedOutputWidth;
138+ filtersContext.outputHeight = expectedOutputHeight;
139+ filtersContext.outputFormat = AV_PIX_FMT_RGB24;
140+ filtersContext.timeBase = timeBase;
141+
142+ std::stringstream filters;
143+ filters << " scale=" << expectedOutputWidth << " :" << expectedOutputHeight;
144+ filters << " :sws_flags=bilinear" ;
145+
146+ filtersContext.filtergraphStr = filters.str ();
147+
148+ if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
149+ filterGraphContext_ =
150+ std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
151+ prevFiltersContext_ = std::move (filtersContext);
138152 }
139153 outputTensor = convertAVFrameToTensorUsingFilterGraph (avFrame);
140154
@@ -187,14 +201,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
187201
188202torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph (
189203 const UniqueAVFrame& avFrame) {
190- int status = av_buffersrc_write_frame (
191- filterGraphContext_.sourceContext , avFrame.get ());
192- TORCH_CHECK (
193- status >= AVSUCCESS, " Failed to add frame to buffer source context" );
204+ UniqueAVFrame filteredAVFrame = filterGraphContext_->convert (avFrame);
194205
195- UniqueAVFrame filteredAVFrame (av_frame_alloc ());
196- status = av_buffersink_get_frame (
197- filterGraphContext_.sinkContext , filteredAVFrame.get ());
198206 TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
199207
200208 auto frameDims = getHeightAndWidthFromResizedAVFrame (*filteredAVFrame.get ());
@@ -210,117 +218,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
210218 filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
211219}
212220
213- void CpuDeviceInterface::createFilterGraph (
214- const DecodedFrameContext& frameContext,
215- const VideoStreamOptions& videoStreamOptions,
216- const AVRational& timeBase) {
217- filterGraphContext_.filterGraph .reset (avfilter_graph_alloc ());
218- TORCH_CHECK (filterGraphContext_.filterGraph .get () != nullptr );
219-
220- if (videoStreamOptions.ffmpegThreadCount .has_value ()) {
221- filterGraphContext_.filterGraph ->nb_threads =
222- videoStreamOptions.ffmpegThreadCount .value ();
223- }
224-
225- const AVFilter* buffersrc = avfilter_get_by_name (" buffer" );
226- const AVFilter* buffersink = avfilter_get_by_name (" buffersink" );
227-
228- std::stringstream filterArgs;
229- filterArgs << " video_size=" << frameContext.decodedWidth << " x"
230- << frameContext.decodedHeight ;
231- filterArgs << " :pix_fmt=" << frameContext.decodedFormat ;
232- filterArgs << " :time_base=" << timeBase.num << " /" << timeBase.den ;
233- filterArgs << " :pixel_aspect=" << frameContext.decodedAspectRatio .num << " /"
234- << frameContext.decodedAspectRatio .den ;
235-
236- int status = avfilter_graph_create_filter (
237- &filterGraphContext_.sourceContext ,
238- buffersrc,
239- " in" ,
240- filterArgs.str ().c_str (),
241- nullptr ,
242- filterGraphContext_.filterGraph .get ());
243- TORCH_CHECK (
244- status >= 0 ,
245- " Failed to create filter graph: " ,
246- filterArgs.str (),
247- " : " ,
248- getFFMPEGErrorStringFromErrorCode (status));
249-
250- status = avfilter_graph_create_filter (
251- &filterGraphContext_.sinkContext ,
252- buffersink,
253- " out" ,
254- nullptr ,
255- nullptr ,
256- filterGraphContext_.filterGraph .get ());
257- TORCH_CHECK (
258- status >= 0 ,
259- " Failed to create filter graph: " ,
260- getFFMPEGErrorStringFromErrorCode (status));
261-
262- enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
263-
264- status = av_opt_set_int_list (
265- filterGraphContext_.sinkContext ,
266- " pix_fmts" ,
267- pix_fmts,
268- AV_PIX_FMT_NONE,
269- AV_OPT_SEARCH_CHILDREN);
270- TORCH_CHECK (
271- status >= 0 ,
272- " Failed to set output pixel formats: " ,
273- getFFMPEGErrorStringFromErrorCode (status));
274-
275- UniqueAVFilterInOut outputs (avfilter_inout_alloc ());
276- UniqueAVFilterInOut inputs (avfilter_inout_alloc ());
277-
278- outputs->name = av_strdup (" in" );
279- outputs->filter_ctx = filterGraphContext_.sourceContext ;
280- outputs->pad_idx = 0 ;
281- outputs->next = nullptr ;
282- inputs->name = av_strdup (" out" );
283- inputs->filter_ctx = filterGraphContext_.sinkContext ;
284- inputs->pad_idx = 0 ;
285- inputs->next = nullptr ;
286-
287- std::stringstream description;
288- description << " scale=" << frameContext.expectedWidth << " :"
289- << frameContext.expectedHeight ;
290- description << " :sws_flags=bilinear" ;
291-
292- AVFilterInOut* outputsTmp = outputs.release ();
293- AVFilterInOut* inputsTmp = inputs.release ();
294- status = avfilter_graph_parse_ptr (
295- filterGraphContext_.filterGraph .get (),
296- description.str ().c_str (),
297- &inputsTmp,
298- &outputsTmp,
299- nullptr );
300- outputs.reset (outputsTmp);
301- inputs.reset (inputsTmp);
302- TORCH_CHECK (
303- status >= 0 ,
304- " Failed to parse filter description: " ,
305- getFFMPEGErrorStringFromErrorCode (status));
306-
307- status =
308- avfilter_graph_config (filterGraphContext_.filterGraph .get (), nullptr );
309- TORCH_CHECK (
310- status >= 0 ,
311- " Failed to configure filter graph: " ,
312- getFFMPEGErrorStringFromErrorCode (status));
313- }
314-
315221void CpuDeviceInterface::createSwsContext (
316- const DecodedFrameContext& frameContext ,
222+ const SwsFrameContext& swsFrameContext ,
317223 const enum AVColorSpace colorspace) {
318224 SwsContext* swsContext = sws_getContext (
319- frameContext. decodedWidth ,
320- frameContext. decodedHeight ,
321- frameContext. decodedFormat ,
322- frameContext. expectedWidth ,
323- frameContext. expectedHeight ,
225+ swsFrameContext. inputWidth ,
226+ swsFrameContext. inputHeight ,
227+ swsFrameContext. inputFormat ,
228+ swsFrameContext. outputWidth ,
229+ swsFrameContext. outputHeight ,
324230 AV_PIX_FMT_RGB24,
325231 SWS_BILINEAR,
326232 nullptr ,
0 commit comments