@@ -13,6 +13,34 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested in the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose the color conversion library publicly; we only use this
+  // ability internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
+                                          : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 bool CpuDeviceInterface::SwsFrameContext::operator==(
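
The helper above centralizes a rule that previously lived inline in convertAVFrameToFrameOutput: default to swscale when the output width is 32-aligned, fall back to filtergraph otherwise, and honor an explicit option either way. For intuition, here is a minimal sketch of that rule in isolation; the standalone enum and Options struct are stand-ins for the real torchcodec types, not the actual API:

#include <cassert>
#include <optional>

enum class ColorConversionLibrary { SWSCALE, FILTERGRAPH };

struct Options {
  std::optional<ColorConversionLibrary> colorConversionLibrary;
};

// Same decision rule as getColorConversionLibrary(): swscale only when the
// width is a multiple of 32, unless the caller forces a library.
ColorConversionLibrary pickLibrary(const Options& options, int width) {
  auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
                                          : ColorConversionLibrary::FILTERGRAPH;
  return options.colorConversionLibrary.value_or(defaultLibrary);
}

int main() {
  assert(pickLibrary({}, 640) == ColorConversionLibrary::SWSCALE);
  assert(pickLibrary({}, 642) == ColorConversionLibrary::FILTERGRAPH);
  // An explicit request is honored even when the width is not 32-aligned,
  // which is why the width check only sets a default:
  assert(pickLibrary({ColorConversionLibrary::SWSCALE}, 642) ==
         ColorConversionLibrary::SWSCALE);
  return 0;
}
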
@@ -34,6 +62,42 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
+      ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  std::unique_ptr<FiltersContext> filtersContext =
+      std::make_unique<FiltersContext>();
+
+  filtersContext->inputWidth = avFrame->width;
+  filtersContext->inputHeight = avFrame->height;
+  filtersContext->inputFormat = frameFormat;
+  filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext->outputWidth = expectedOutputWidth;
+  filtersContext->outputHeight = expectedOutputHeight;
+  filtersContext->outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext->timeBase = timeBase;
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext->filtergraphStr = filters.str();
+  return filtersContext;
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
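
The filtergraph description that initializeFiltersContext builds is just an FFmpeg filter string. This throwaway sketch shows what the stringstream assembly above evaluates to for a hypothetical 640x480 output:

#include <iostream>
#include <sstream>

int main() {
  int expectedOutputWidth = 640;
  int expectedOutputHeight = 480;
  std::stringstream filters;
  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
  filters << ":sws_flags=bilinear";
  // Prints "scale=640:480:sws_flags=bilinear": a single scale filter with
  // bilinear interpolation, matching the string handed to FilterGraph.
  std::cout << filters.str() << std::endl;
  return 0;
}
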
@@ -45,7 +109,7 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
45109// `dimension_order` parameter. It's up to callers to re-shape it if needed.
46110void CpuDeviceInterface::convertAVFrameToFrameOutput (
47111 const VideoStreamOptions& videoStreamOptions,
48- const AVRational& timeBase,
112+ [[maybe_unused]] const AVRational& timeBase,
49113 UniqueAVFrame& avFrame,
50114 FrameOutput& frameOutput,
51115 std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -71,23 +135,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested in the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability
-  // to choose the color conversion library publicly; we only use this
-  // ability internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
@@ -126,44 +175,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation.
-    FiltersContext filtersContext;
-
-    filtersContext.inputWidth = avFrame->width;
-    filtersContext.inputHeight = avFrame->height;
-    filtersContext.inputFormat = frameFormat;
-    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-    filtersContext.outputWidth = expectedOutputWidth;
-    filtersContext.outputHeight = expectedOutputHeight;
-    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-    filtersContext.timeBase = timeBase;
-
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
-    filtersContext.filtergraphStr = filters.str();
-
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
@@ -173,11 +194,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
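The new filtergraph branch wraps the decoded frame's pixel buffer in a tensor without copying: torch::from_blob() views the AVFrame's data, and the custom deleter re-wraps the released pointer in a UniqueAVFrame so the frame is freed exactly when the tensor's storage dies. Below is a self-contained sketch of the same ownership-transfer pattern, with a malloc'd buffer standing in for the AVFrame so it compiles without FFmpeg:

#include <cstdint>
#include <cstdlib>
#include <torch/torch.h>

int main() {
  // Stand-in for AVFrame pixel data: 2 rows of 4 RGB24 pixels, with each row
  // padded to a 16-byte linesize (only 4 * 3 = 12 bytes per row are used).
  const int64_t height = 2, width = 4, linesize = 16;
  uint8_t* data = static_cast<uint8_t*>(std::malloc(height * linesize));

  // The deleter runs when the tensor's storage is released, so the tensor
  // takes ownership of the buffer; no pixel data is copied.
  auto deleter = [](void* ptr) { std::free(ptr); };
  torch::Tensor frame = torch::from_blob(
      data,
      {height, width, 3},
      {linesize, 3, 1}, // strides in uint8 elements: row, pixel, channel
      deleter,
      torch::kUInt8);

  // frame is HWC uint8 and aliases the original buffer; the row padding is
  // skipped by the row stride, just as linesize[0] is used above.
  return 0;
}
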
@@ -199,25 +215,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {