@@ -13,22 +13,26 @@ extern "C" {
1313
1414namespace facebook ::torchcodec {
1515
16- bool DecodedFrameContext::operator ==(const DecodedFrameContext& other) {
17- return decodedWidth == other.decodedWidth &&
18- decodedHeight == other.decodedHeight &&
19- decodedFormat == other.decodedFormat &&
20- expectedWidth == other.expectedWidth &&
21- expectedHeight == other.expectedHeight ;
// Equality for AVRational: true only when both the numerator and the
// denominator match exactly. No reduction is performed, so fractions that
// are mathematically equal but written with different terms (e.g. 1/2 and
// 2/4) compare as unequal here.
// NOTE(review): this is a non-static free function declared inside the
// enclosing namespace; if it is only needed by this file, internal linkage
// may be preferable -- confirm no other translation unit relies on it.
16+ bool operator ==(const AVRational& lhs, const AVRational& rhs) {
17+ return lhs.num == rhs.num && lhs.den == rhs.den ;
2218}
2319
24- bool DecodedFrameContext::operator !=(const DecodedFrameContext& other) {
// Returns true when this context and `other` agree on input width, height and
// format, output width, height and format, the filters string, the time base,
// and the hardware frames context.
// hwFramesCtx is compared through .get(), i.e. by pointer identity: two
// contexts holding distinct objects with equal contents compare as unequal.
// NOTE(review): inputAspectRatio is not part of this comparison, although the
// FilterGraph constructor passes it to the buffer source as
// sample_aspect_ratio -- confirm the omission is intentional.
// NOTE(review): the operator is not const-qualified; presumably this mirrors
// the declaration in the header -- verify.
20+ bool FiltersContext::operator ==(const FiltersContext& other) {
21+ return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
22+ inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
23+ outputHeight == other.outputHeight &&
24+ outputFormat == other.outputFormat && filters == other.filters &&
25+ timeBase == other.timeBase &&
26+ hwFramesCtx.get () == other.hwFramesCtx .get ();
27+ }
28+
// Inequality, defined as the negation of operator== so that the two
// operators cannot drift apart.
29+ bool FiltersContext::operator !=(const FiltersContext& other) {
2530 return !(*this == other);
2631}
2732
2833FilterGraph::FilterGraph (
29- const DecodedFrameContext& frameContext,
30- const VideoStreamOptions& videoStreamOptions,
31- const AVRational& timeBase) {
34+ const FiltersContext& filtersContext,
35+ const VideoStreamOptions& videoStreamOptions) {
3236 filterGraph_.reset (avfilter_graph_alloc ());
3337 TORCH_CHECK (filterGraph_.get () != nullptr );
3438
@@ -39,26 +43,40 @@ FilterGraph::FilterGraph(
3943 const AVFilter* buffersrc = avfilter_get_by_name (" buffer" );
4044 const AVFilter* buffersink = avfilter_get_by_name (" buffersink" );
4145
42- std::stringstream filterArgs;
43- filterArgs << " video_size=" << frameContext.decodedWidth << " x"
44- << frameContext.decodedHeight ;
45- filterArgs << " :pix_fmt=" << frameContext.decodedFormat ;
46- filterArgs << " :time_base=" << timeBase.num << " /" << timeBase.den ;
47- filterArgs << " :pixel_aspect=" << frameContext.decodedAspectRatio .num << " /"
48- << frameContext.decodedAspectRatio .den ;
49-
50- int status = avfilter_graph_create_filter (
51- &sourceContext_,
52- buffersrc,
53- " in" ,
54- filterArgs.str ().c_str (),
55- nullptr ,
56- filterGraph_.get ());
46+ auto deleter = [](AVBufferSrcParameters* p) {
47+ if (p) {
48+ av_freep (&p);
49+ }
50+ };
51+ std::unique_ptr<AVBufferSrcParameters, decltype (deleter)> srcParams (
52+ nullptr , deleter);
53+
54+ srcParams.reset (av_buffersrc_parameters_alloc ());
55+ TORCH_CHECK (srcParams, " Failed to allocate buffersrc params" );
56+
57+ srcParams->format = filtersContext.inputFormat ;
58+ srcParams->width = filtersContext.inputWidth ;
59+ srcParams->height = filtersContext.inputHeight ;
60+ srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio ;
61+ srcParams->time_base = filtersContext.timeBase ;
62+ if (filtersContext.hwFramesCtx ) {
63+ srcParams->hw_frames_ctx = av_buffer_ref (filtersContext.hwFramesCtx .get ());
64+ }
65+
66+ sourceContext_ =
67+ avfilter_graph_alloc_filter (filterGraph_.get (), buffersrc, " in" );
68+ TORCH_CHECK (sourceContext_, " Failed to allocate filter graph" );
69+
70+ int status = av_buffersrc_parameters_set (sourceContext_, srcParams.get ());
5771 TORCH_CHECK (
5872 status >= 0 ,
5973 " Failed to create filter graph: " ,
60- filterArgs.str (),
61- " : " ,
74+ getFFMPEGErrorStringFromErrorCode (status));
75+
76+ status = avfilter_init_str (sourceContext_, nullptr );
77+ TORCH_CHECK (
78+ status >= 0 ,
79+ " Failed to create filter graph : " ,
6280 getFFMPEGErrorStringFromErrorCode (status));
6381
6482 status = avfilter_graph_create_filter (
@@ -68,7 +86,8 @@ FilterGraph::FilterGraph(
6886 " Failed to create filter graph: " ,
6987 getFFMPEGErrorStringFromErrorCode (status));
7088
71- enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
89+ enum AVPixelFormat pix_fmts[] = {
90+ filtersContext.outputFormat , AV_PIX_FMT_NONE};
7291
7392 status = av_opt_set_int_list (
7493 sinkContext_,
@@ -93,16 +112,11 @@ FilterGraph::FilterGraph(
93112 inputs->pad_idx = 0 ;
94113 inputs->next = nullptr ;
95114
96- std::stringstream description;
97- description << " scale=" << frameContext.expectedWidth << " :"
98- << frameContext.expectedHeight ;
99- description << " :sws_flags=bilinear" ;
100-
101115 AVFilterInOut* outputsTmp = outputs.release ();
102116 AVFilterInOut* inputsTmp = inputs.release ();
103117 status = avfilter_graph_parse_ptr (
104118 filterGraph_.get (),
105- description. str () .c_str (),
119+ filtersContext. filters .c_str (),
106120 &inputsTmp,
107121 &outputsTmp,
108122 nullptr );
@@ -128,8 +142,7 @@ UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
128142 UniqueAVFrame filteredAVFrame (av_frame_alloc ());
129143 status = av_buffersink_get_frame (sinkContext_, filteredAVFrame.get ());
130144 TORCH_CHECK (
131- status >= AVSUCCESS, " Failed to fet frame from buffer sink context" );
132- TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
145+ status >= AVSUCCESS, " Failed to get frame from buffer sink context" );
133146
134147 return filteredAVFrame;
135148}
0 commit comments