@@ -13,6 +13,35 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also respect an explicitly requested library
+  // unconditionally, so it's possible for swscale's width requirements to be
+  // violated. We don't expose the ability to choose the color conversion
+  // library publicly; we only use it internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
@@ -22,6 +51,52 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext>
+CpuDeviceInterface::initializeFiltersContextInternal(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
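+  // Output dimensions come from the options when set, otherwise from the
+  // decoded frame itself.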
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  std::unique_ptr<FiltersContext> filtersContext =
+      std::make_unique<FiltersContext>();
+
+  filtersContext->inputWidth = avFrame->width;
+  filtersContext->inputHeight = avFrame->height;
+  filtersContext->inputFormat = frameFormat;
+  filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext->outputWidth = expectedOutputWidth;
+  filtersContext->outputHeight = expectedOutputHeight;
+  filtersContext->outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext->timeBase = timeBase;
+
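+  // Build the filtergraph description, e.g. "scale=640:480:sws_flags=bilinear".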
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext->filters = filters.str();
+  return filtersContext;
+}
+
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputWidth = frameDims.width;
+
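+  // A null context tells the caller that no filtergraph is needed; swscale
+  // will handle the color conversion in convertAVFrameToFrameOutput().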
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
+      ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  return initializeFiltersContextInternal(
+      videoStreamOptions, avFrame, timeBase);
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
@@ -56,56 +131,25 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 
   torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  enum AVPixelFormat frameFormat =
-      static_cast<enum AVPixelFormat>(avFrame->format);
-  FiltersContext filtersContext;
-
-  filtersContext.inputWidth = avFrame->width;
-  filtersContext.inputHeight = avFrame->height;
-  filtersContext.inputFormat = frameFormat;
-  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-  filtersContext.outputWidth = expectedOutputWidth;
-  filtersContext.outputHeight = expectedOutputHeight;
-  filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-  filtersContext.timeBase = timeBase;
-
-  std::stringstream filters;
-  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-  filters << ":sws_flags=bilinear";
-
-  filtersContext.filters = filters.str();
-
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
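+    // For swscale, the filters context only serves as the cache key and as
+    // the source of conversion parameters for createSwsContext(); no
+    // filtergraph is created here.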
+    std::unique_ptr<FiltersContext> filtersContext =
+        initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
+
-    if (!swsContext_ || prevFiltersContext_ != filtersContext) {
-      createSwsContext(filtersContext, avFrame->colorspace);
+    // Compare the contexts by value: the unique_ptrs themselves always
+    // differ, which would needlessly re-create the swscale context on
+    // every frame.
+    if (!swsContext_ || !prevFiltersContext_ ||
+        *prevFiltersContext_ != *filtersContext) {
+      createSwsContext(*filtersContext, avFrame->colorspace);
       prevFiltersContext_ = std::move(filtersContext);
     }
     int resultHeight =
@@ -122,25 +166,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
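+    // The caller is expected to have already pushed the frame through the
+    // filtergraph built from initializeFiltersContext(), so it arrives here
+    // as RGB24.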
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
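+    // Wrap the frame's data in a tensor without copying. The row stride comes
+    // from FFmpeg's linesize, which may include padding; the deleter takes
+    // ownership of the AVFrame and frees it once the tensor is destroyed.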
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
@@ -150,11 +185,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
@@ -176,25 +206,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const FiltersContext& filtersContext,
     const enum AVColorSpace colorspace) {