66
77#include " src/torchcodec/_core/CpuDeviceInterface.h"
88
9- extern " C" {
10- #include < libavfilter/buffersink.h>
11- #include < libavfilter/buffersrc.h>
12- }
13-
149namespace facebook ::torchcodec {
1510namespace {
1611
@@ -20,20 +15,6 @@ static bool g_cpu = registerDeviceInterface(
2015
2116} // namespace
2217
23- bool CpuDeviceInterface::DecodedFrameContext::operator ==(
24- const CpuDeviceInterface::DecodedFrameContext& other) {
25- return decodedWidth == other.decodedWidth &&
26- decodedHeight == other.decodedHeight &&
27- decodedFormat == other.decodedFormat &&
28- expectedWidth == other.expectedWidth &&
29- expectedHeight == other.expectedHeight ;
30- }
31-
32- bool CpuDeviceInterface::DecodedFrameContext::operator !=(
33- const CpuDeviceInterface::DecodedFrameContext& other) {
34- return !(*this == other);
35- }
36-
3718CpuDeviceInterface::CpuDeviceInterface (const torch::Device& device)
3819 : DeviceInterface(device) {
3920 TORCH_CHECK (g_cpu, " CpuDeviceInterface was not registered!" );
@@ -132,8 +113,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
132113
133114 frameOutput.data = outputTensor;
134115 } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
135- if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
136- createFilterGraph (frameContext, videoStreamOptions, timeBase);
116+ if (!filterGraphContext_ || prevFrameContext_ != frameContext) {
117+ filterGraphContext_ = std::make_unique<FilterGraph>(
118+ frameContext, videoStreamOptions, timeBase);
137119 prevFrameContext_ = frameContext;
138120 }
139121 outputTensor = convertAVFrameToTensorUsingFilterGraph (avFrame);
@@ -187,14 +169,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
187169
188170torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph (
189171 const UniqueAVFrame& avFrame) {
190- int status = av_buffersrc_write_frame (
191- filterGraphContext_.sourceContext , avFrame.get ());
192- TORCH_CHECK (
193- status >= AVSUCCESS, " Failed to add frame to buffer source context" );
172+ UniqueAVFrame filteredAVFrame = filterGraphContext_->convert (avFrame);
194173
195- UniqueAVFrame filteredAVFrame (av_frame_alloc ());
196- status = av_buffersink_get_frame (
197- filterGraphContext_.sinkContext , filteredAVFrame.get ());
198174 TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
199175
200176 auto frameDims = getHeightAndWidthFromResizedAVFrame (*filteredAVFrame.get ());
@@ -210,108 +186,6 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
210186 filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
211187}
212188
213- void CpuDeviceInterface::createFilterGraph (
214- const DecodedFrameContext& frameContext,
215- const VideoStreamOptions& videoStreamOptions,
216- const AVRational& timeBase) {
217- filterGraphContext_.filterGraph .reset (avfilter_graph_alloc ());
218- TORCH_CHECK (filterGraphContext_.filterGraph .get () != nullptr );
219-
220- if (videoStreamOptions.ffmpegThreadCount .has_value ()) {
221- filterGraphContext_.filterGraph ->nb_threads =
222- videoStreamOptions.ffmpegThreadCount .value ();
223- }
224-
225- const AVFilter* buffersrc = avfilter_get_by_name (" buffer" );
226- const AVFilter* buffersink = avfilter_get_by_name (" buffersink" );
227-
228- std::stringstream filterArgs;
229- filterArgs << " video_size=" << frameContext.decodedWidth << " x"
230- << frameContext.decodedHeight ;
231- filterArgs << " :pix_fmt=" << frameContext.decodedFormat ;
232- filterArgs << " :time_base=" << timeBase.num << " /" << timeBase.den ;
233- filterArgs << " :pixel_aspect=" << frameContext.decodedAspectRatio .num << " /"
234- << frameContext.decodedAspectRatio .den ;
235-
236- int status = avfilter_graph_create_filter (
237- &filterGraphContext_.sourceContext ,
238- buffersrc,
239- " in" ,
240- filterArgs.str ().c_str (),
241- nullptr ,
242- filterGraphContext_.filterGraph .get ());
243- TORCH_CHECK (
244- status >= 0 ,
245- " Failed to create filter graph: " ,
246- filterArgs.str (),
247- " : " ,
248- getFFMPEGErrorStringFromErrorCode (status));
249-
250- status = avfilter_graph_create_filter (
251- &filterGraphContext_.sinkContext ,
252- buffersink,
253- " out" ,
254- nullptr ,
255- nullptr ,
256- filterGraphContext_.filterGraph .get ());
257- TORCH_CHECK (
258- status >= 0 ,
259- " Failed to create filter graph: " ,
260- getFFMPEGErrorStringFromErrorCode (status));
261-
262- enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
263-
264- status = av_opt_set_int_list (
265- filterGraphContext_.sinkContext ,
266- " pix_fmts" ,
267- pix_fmts,
268- AV_PIX_FMT_NONE,
269- AV_OPT_SEARCH_CHILDREN);
270- TORCH_CHECK (
271- status >= 0 ,
272- " Failed to set output pixel formats: " ,
273- getFFMPEGErrorStringFromErrorCode (status));
274-
275- UniqueAVFilterInOut outputs (avfilter_inout_alloc ());
276- UniqueAVFilterInOut inputs (avfilter_inout_alloc ());
277-
278- outputs->name = av_strdup (" in" );
279- outputs->filter_ctx = filterGraphContext_.sourceContext ;
280- outputs->pad_idx = 0 ;
281- outputs->next = nullptr ;
282- inputs->name = av_strdup (" out" );
283- inputs->filter_ctx = filterGraphContext_.sinkContext ;
284- inputs->pad_idx = 0 ;
285- inputs->next = nullptr ;
286-
287- std::stringstream description;
288- description << " scale=" << frameContext.expectedWidth << " :"
289- << frameContext.expectedHeight ;
290- description << " :sws_flags=bilinear" ;
291-
292- AVFilterInOut* outputsTmp = outputs.release ();
293- AVFilterInOut* inputsTmp = inputs.release ();
294- status = avfilter_graph_parse_ptr (
295- filterGraphContext_.filterGraph .get (),
296- description.str ().c_str (),
297- &inputsTmp,
298- &outputsTmp,
299- nullptr );
300- outputs.reset (outputsTmp);
301- inputs.reset (inputsTmp);
302- TORCH_CHECK (
303- status >= 0 ,
304- " Failed to parse filter description: " ,
305- getFFMPEGErrorStringFromErrorCode (status));
306-
307- status =
308- avfilter_graph_config (filterGraphContext_.filterGraph .get (), nullptr );
309- TORCH_CHECK (
310- status >= 0 ,
311- " Failed to configure filter graph: " ,
312- getFFMPEGErrorStringFromErrorCode (status));
313- }
314-
315189void CpuDeviceInterface::createSwsContext (
316190 const DecodedFrameContext& frameContext,
317191 const enum AVColorSpace colorspace) {
0 commit comments