@@ -83,6 +83,21 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
8383 enum AVPixelFormat frameFormat =
8484 static_cast <enum AVPixelFormat>(avFrame->format );
8585
86+ if (frameFormat == AV_PIX_FMT_RGB24 &&
87+ avFrame->width == expectedOutputWidth &&
88+ avFrame->height == expectedOutputHeight) {
89+ outputTensor = toTensor (avFrame);
90+ if (preAllocatedOutputTensor.has_value ()) {
91+ // We have already validated that preAllocatedOutputTensor and
92+ // outputTensor have the same shape.
93+ preAllocatedOutputTensor.value ().copy_ (outputTensor);
94+ frameOutput.data = preAllocatedOutputTensor.value ();
95+ } else {
96+ frameOutput.data = outputTensor;
97+ }
98+ return ;
99+ }
100+
86101 // By default, we want to use swscale for color conversion because it is
87102 // faster. However, it has width requirements, so we may need to fall back
88103 // to filtergraph. We also need to respect what was requested from the
@@ -159,7 +174,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
159174 std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
160175 prevFiltersContext_ = std::move (filtersContext);
161176 }
162- outputTensor = convertAVFrameToTensorUsingFilterGraph ( avFrame);
177+ outputTensor = toTensor (filterGraphContext_-> convert ( avFrame) );
163178
164179 // Similarly to above, if this check fails it means the frame wasn't
165180 // reshaped to its expected dimensions by filtergraph.
@@ -208,23 +223,20 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
208223 return resultHeight;
209224}
210225
211- torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph (
212- const UniqueAVFrame& avFrame) {
213- UniqueAVFrame filteredAVFrame = filterGraphContext_->convert (avFrame);
214-
215- TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
226+ torch::Tensor CpuDeviceInterface::toTensor (const UniqueAVFrame& avFrame) {
227+ TORCH_CHECK_EQ (avFrame->format , AV_PIX_FMT_RGB24);
216228
217- auto frameDims = getHeightAndWidthFromResizedAVFrame (*filteredAVFrame .get ());
229+ auto frameDims = getHeightAndWidthFromResizedAVFrame (*avFrame .get ());
218230 int height = frameDims.height ;
219231 int width = frameDims.width ;
220232 std::vector<int64_t > shape = {height, width, 3 };
221- std::vector<int64_t > strides = {filteredAVFrame ->linesize [0 ], 3 , 1 };
222- AVFrame* filteredAVFramePtr = filteredAVFrame. release ( );
223- auto deleter = [filteredAVFramePtr ](void *) {
224- UniqueAVFrame avFrameToDelete (filteredAVFramePtr );
233+ std::vector<int64_t > strides = {avFrame ->linesize [0 ], 3 , 1 };
234+ AVFrame* avFrameClone = av_frame_clone (avFrame. get () );
235+ auto deleter = [avFrameClone ](void *) {
236+ UniqueAVFrame avFrameToDelete (avFrameClone );
225237 };
226238 return torch::from_blob (
227- filteredAVFramePtr ->data [0 ], shape, strides, deleter, {torch::kUInt8 });
239+ avFrameClone ->data [0 ], shape, strides, deleter, {torch::kUInt8 });
228240}
229241
230242void CpuDeviceInterface::createSwsContext (
0 commit comments