@@ -13,6 +13,34 @@ static bool g_cpu = registerDeviceInterface(
1313 torch::kCPU ,
1414 [](const torch::Device& device) { return new CpuDeviceInterface (device); });
1515
16+ ColorConversionLibrary getColorConversionLibrary (
17+ const VideoStreamOptions& videoStreamOptions,
18+ int width) {
19+ // By default, we want to use swscale for color conversion because it is
20+ // faster. However, it has width requirements, so we may need to fall back
21+ // to filtergraph. We also need to respect what was requested from the
22+ // options; we respect the options unconditionally, so it's possible for
23+ // swscale's width requirements to be violated. We don't expose the ability to
24+ // choose color conversion library publicly; we only use this ability
25+ // internally.
26+
27+ // swscale requires widths to be multiples of 32:
28+ // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
29+ // so we fall back to filtergraph if the width is not a multiple of 32.
30+ auto defaultLibrary = (width % 32 == 0 ) ? ColorConversionLibrary::SWSCALE
31+ : ColorConversionLibrary::FILTERGRAPH;
32+
33+ ColorConversionLibrary colorConversionLibrary =
34+ videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
35+
36+ TORCH_CHECK (
37+ colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
38+ colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
39+ " Invalid color conversion library: " ,
40+ static_cast <int >(colorConversionLibrary));
41+ return colorConversionLibrary;
42+ }
43+
1644} // namespace
1745
1846CpuDeviceInterface::SwsFrameContext::SwsFrameContext (
@@ -46,6 +74,38 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4674 device_.type () == torch::kCPU , " Unsupported device: " , device_.str ());
4775}
4876
77+ std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext (
78+ const VideoStreamOptions& videoStreamOptions,
79+ const UniqueAVFrame& avFrame,
80+ const AVRational& timeBase) {
81+ enum AVPixelFormat frameFormat =
82+ static_cast <enum AVPixelFormat>(avFrame->format );
83+ auto frameDims =
84+ getHeightAndWidthFromOptionsOrAVFrame (videoStreamOptions, avFrame);
85+ int expectedOutputHeight = frameDims.height ;
86+ int expectedOutputWidth = frameDims.width ;
87+
88+ if (getColorConversionLibrary (videoStreamOptions, expectedOutputWidth) ==
89+ ColorConversionLibrary::SWSCALE) {
90+ return nullptr ;
91+ }
92+
93+ std::stringstream filters;
94+ filters << " scale=" << expectedOutputWidth << " :" << expectedOutputHeight;
95+ filters << " :sws_flags=bilinear" ;
96+
97+ return std::make_unique<FiltersContext>(
98+ avFrame->width ,
99+ avFrame->height ,
100+ frameFormat,
101+ avFrame->sample_aspect_ratio ,
102+ expectedOutputWidth,
103+ expectedOutputHeight,
104+ AV_PIX_FMT_RGB24,
105+ filters.str (),
106+ timeBase);
107+ }
108+
49109// Note [preAllocatedOutputTensor with swscale and filtergraph]:
50110// Callers may pass a pre-allocated tensor, where the output.data tensor will
51111// be stored. This parameter is honored in any case, but it only leads to a
@@ -57,7 +117,7 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
57117// `dimension_order` parameter. It's up to callers to re-shape it if needed.
58118void CpuDeviceInterface::convertAVFrameToFrameOutput (
59119 const VideoStreamOptions& videoStreamOptions,
60- const AVRational& timeBase,
120+ [[maybe_unused]] const AVRational& timeBase,
61121 UniqueAVFrame& avFrame,
62122 FrameOutput& frameOutput,
63123 std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -83,23 +143,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
83143 enum AVPixelFormat frameFormat =
84144 static_cast <enum AVPixelFormat>(avFrame->format );
85145
86- // By default, we want to use swscale for color conversion because it is
87- // faster. However, it has width requirements, so we may need to fall back
88- // to filtergraph. We also need to respect what was requested from the
89- // options; we respect the options unconditionally, so it's possible for
90- // swscale's width requirements to be violated. We don't expose the ability to
91- // choose color conversion library publicly; we only use this ability
92- // internally.
93-
94- // swscale requires widths to be multiples of 32:
95- // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
96- // so we fall back to filtergraph if the width is not a multiple of 32.
97- auto defaultLibrary = (expectedOutputWidth % 32 == 0 )
98- ? ColorConversionLibrary::SWSCALE
99- : ColorConversionLibrary::FILTERGRAPH;
100-
101146 ColorConversionLibrary colorConversionLibrary =
102- videoStreamOptions. colorConversionLibrary . value_or (defaultLibrary );
147+ getColorConversionLibrary (videoStreamOptions, expectedOutputWidth );
103148
104149 if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
105150 // We need to compare the current frame context with our previous frame
@@ -137,42 +182,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
137182
138183 frameOutput.data = outputTensor;
139184 } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
140- // See comment above in swscale branch about the filterGraphContext_
141- // creation. creation
142- std::stringstream filters;
143- filters << " scale=" << expectedOutputWidth << " :" << expectedOutputHeight;
144- filters << " :sws_flags=bilinear" ;
185+ TORCH_CHECK_EQ (avFrame->format , AV_PIX_FMT_RGB24);
145186
146- FiltersContext filtersContext (
147- avFrame->width ,
148- avFrame->height ,
149- frameFormat,
150- avFrame->sample_aspect_ratio ,
151- expectedOutputWidth,
152- expectedOutputHeight,
153- AV_PIX_FMT_RGB24,
154- filters.str (),
155- timeBase);
156-
157- if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
158- filterGraphContext_ =
159- std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
160- prevFiltersContext_ = std::move (filtersContext);
161- }
162- outputTensor = convertAVFrameToTensorUsingFilterGraph (avFrame);
163-
164- // Similarly to above, if this check fails it means the frame wasn't
165- // reshaped to its expected dimensions by filtergraph.
166- auto shape = outputTensor.sizes ();
167- TORCH_CHECK (
168- (shape.size () == 3 ) && (shape[0 ] == expectedOutputHeight) &&
169- (shape[1 ] == expectedOutputWidth) && (shape[2 ] == 3 ),
170- " Expected output tensor of shape " ,
171- expectedOutputHeight,
172- " x" ,
173- expectedOutputWidth,
174- " x3, got " ,
175- shape);
187+ std::vector<int64_t > shape = {expectedOutputHeight, expectedOutputWidth, 3 };
188+ std::vector<int64_t > strides = {avFrame->linesize [0 ], 3 , 1 };
189+ AVFrame* avFramePtr = avFrame.release ();
190+ auto deleter = [avFramePtr](void *) {
191+ UniqueAVFrame avFrameToDelete (avFramePtr);
192+ };
193+ outputTensor = torch::from_blob (
194+ avFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
176195
177196 if (preAllocatedOutputTensor.has_value ()) {
178197 // We have already validated that preAllocatedOutputTensor and
@@ -182,11 +201,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
182201 } else {
183202 frameOutput.data = outputTensor;
184203 }
185- } else {
186- TORCH_CHECK (
187- false ,
188- " Invalid color conversion library: " ,
189- static_cast <int >(colorConversionLibrary));
190204 }
191205}
192206
@@ -208,25 +222,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
208222 return resultHeight;
209223}
210224
211- torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph (
212- const UniqueAVFrame& avFrame) {
213- UniqueAVFrame filteredAVFrame = filterGraphContext_->convert (avFrame);
214-
215- TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
216-
217- auto frameDims = getHeightAndWidthFromResizedAVFrame (*filteredAVFrame.get ());
218- int height = frameDims.height ;
219- int width = frameDims.width ;
220- std::vector<int64_t > shape = {height, width, 3 };
221- std::vector<int64_t > strides = {filteredAVFrame->linesize [0 ], 3 , 1 };
222- AVFrame* filteredAVFramePtr = filteredAVFrame.release ();
223- auto deleter = [filteredAVFramePtr](void *) {
224- UniqueAVFrame avFrameToDelete (filteredAVFramePtr);
225- };
226- return torch::from_blob (
227- filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
228- }
229-
230225void CpuDeviceInterface::createSwsContext (
231226 const SwsFrameContext& swsFrameContext,
232227 const enum AVColorSpace colorspace) {
0 commit comments