@@ -41,27 +41,44 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
4141PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void , av_buffer_unref>>
4242 g_cached_hw_device_ctxs (MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
4343
44+ int getFlagsAVHardwareDeviceContextCreate () {
45+ // 58.26.100 introduced the concept of reusing the existing cuda context
46+ // which is much faster and lower memory than creating a new cuda context.
4447#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
48+ return AV_CUDA_USE_CURRENT_CONTEXT;
49+ #else
50+ return 0 ;
51+ #endif
52+ }
53+
54+ UniqueAVBufferRef getHardwareDeviceContext (const torch::Device& device) {
55+ enum AVHWDeviceType type = av_hwdevice_find_type_by_name (" cuda" );
56+ TORCH_CHECK (type != AV_HWDEVICE_TYPE_NONE, " Failed to find cuda device" );
57+ torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex (device);
58+
59+ UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get (device);
60+ if (hardwareDeviceCtx) {
61+ return hardwareDeviceCtx;
62+ }
4563
46- AVBufferRef* getFFMPEGContextFromExistingCudaContext (
47- const torch::Device& device,
48- torch::DeviceIndex nonNegativeDeviceIndex,
49- enum AVHWDeviceType type) {
64+ // Create hardware device context
5065 c10::cuda::CUDAGuard deviceGuard (device);
5166 // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
5267 // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
5368 // So we ensure the deviceIndex is not negative.
5469 // We set the device because we may be called from a different thread than
5570 // the one that initialized the cuda context.
5671 cudaSetDevice (nonNegativeDeviceIndex);
57- AVBufferRef* hw_device_ctx = nullptr ;
72+ AVBufferRef* hardwareDeviceCtxRaw = nullptr ;
5873 std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex);
74+
5975 int err = av_hwdevice_ctx_create (
60- &hw_device_ctx ,
76+ &hardwareDeviceCtxRaw ,
6177 type,
6278 deviceOrdinal.c_str (),
6379 nullptr ,
64- AV_CUDA_USE_CURRENT_CONTEXT);
80+ getFlagsAVHardwareDeviceContextCreate ());
81+
6582 if (err < 0 ) {
6683 /* clang-format off */
6784 TORCH_CHECK (
@@ -72,53 +89,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
7289 " ). FFmpeg error: " , getFFMPEGErrorStringFromErrorCode (err));
7390 /* clang-format on */
7491 }
75- return hw_device_ctx;
76- }
77-
78- #else
79-
80- AVBufferRef* getFFMPEGContextFromNewCudaContext (
81- [[maybe_unused]] const torch::Device& device,
82- torch::DeviceIndex nonNegativeDeviceIndex,
83- enum AVHWDeviceType type) {
84- AVBufferRef* hw_device_ctx = nullptr ;
85- std::string deviceOrdinal = std::to_string (nonNegativeDeviceIndex);
86- int err = av_hwdevice_ctx_create (
87- &hw_device_ctx, type, deviceOrdinal.c_str (), nullptr , 0 );
88- if (err < 0 ) {
89- TORCH_CHECK (
90- false ,
91- " Failed to create specified HW device" ,
92- getFFMPEGErrorStringFromErrorCode (err));
93- }
94- return hw_device_ctx;
95- }
9692
97- #endif
98-
99- UniqueAVBufferRef getCudaContext (const torch::Device& device) {
100- enum AVHWDeviceType type = av_hwdevice_find_type_by_name (" cuda" );
101- TORCH_CHECK (type != AV_HWDEVICE_TYPE_NONE, " Failed to find cuda device" );
102- torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex (device);
103-
104- UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get (device);
105- if (hw_device_ctx) {
106- return hw_device_ctx;
107- }
108-
109- // 58.26.100 introduced the concept of reusing the existing cuda context
110- // which is much faster and lower memory than creating a new cuda context.
111- // So we try to use that if it is available.
112- // FFMPEG 6.1.2 appears to be the earliest release that contains version
113- // 58.26.100 of avutil.
114- // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
115- #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
116- return UniqueAVBufferRef (getFFMPEGContextFromExistingCudaContext (
117- device, nonNegativeDeviceIndex, type));
118- #else
119- return UniqueAVBufferRef (
120- getFFMPEGContextFromNewCudaContext (device, nonNegativeDeviceIndex, type));
121- #endif
93+ return UniqueAVBufferRef (hardwareDeviceCtxRaw);
12294}
12395
12496} // namespace
@@ -131,15 +103,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
131103
132104 initializeCudaContextWithPytorch (device_);
133105
134- // TODO rename this, this is a hardware device context, not a CUDA context!
135- // See https://github.com/meta-pytorch/torchcodec/issues/924
136- ctx_ = getCudaContext (device_);
106+ hardwareDeviceCtx_ = getHardwareDeviceContext (device_);
137107 nppCtx_ = getNppStreamContext (device_);
138108}
139109
140110CudaDeviceInterface::~CudaDeviceInterface () {
141- if (ctx_) {
142- g_cached_hw_device_ctxs.addIfCacheHasCapacity (device_, std::move (ctx_));
111+ if (hardwareDeviceCtx_) {
112+ g_cached_hw_device_ctxs.addIfCacheHasCapacity (
113+ device_, std::move (hardwareDeviceCtx_));
143114 }
144115 returnNppStreamContextToCache (device_, std::move (nppCtx_));
145116}
@@ -170,9 +141,10 @@ void CudaDeviceInterface::initializeVideo(
170141
171142void CudaDeviceInterface::registerHardwareDeviceWithCodec (
172143 AVCodecContext* codecContext) {
173- TORCH_CHECK (ctx_, " FFmpeg HW device has not been initialized" );
144+ TORCH_CHECK (
145+ hardwareDeviceCtx_, " Hardware device context has not been initialized" );
174146 TORCH_CHECK (codecContext != nullptr , " codecContext is null" );
175- codecContext->hw_device_ctx = av_buffer_ref (ctx_ .get ());
147+ codecContext->hw_device_ctx = av_buffer_ref (hardwareDeviceCtx_ .get ());
176148}
177149
178150UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24 (
0 commit comments