Skip to content

Commit 48b07cc

Browse files
mollyxuMolly Xu
andauthored
Refactor CudaDeviceInterface::getCudaContext (#956)
Co-authored-by: Molly Xu <mollyxu@fb.com>
1 parent b084768 commit 48b07cc

File tree

2 files changed

+33
-61
lines changed

2 files changed

+33
-61
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 32 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,44 @@ const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
4141
PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
4242
g_cached_hw_device_ctxs(MAX_CUDA_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);
4343

44+
int getFlagsAVHardwareDeviceContextCreate() {
45+
// 58.26.100 introduced the concept of reusing the existing cuda context
46+
// which is much faster and lower memory than creating a new cuda context.
4447
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
48+
return AV_CUDA_USE_CURRENT_CONTEXT;
49+
#else
50+
return 0;
51+
#endif
52+
}
53+
54+
UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
55+
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
56+
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
57+
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
58+
59+
UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device);
60+
if (hardwareDeviceCtx) {
61+
return hardwareDeviceCtx;
62+
}
4563

46-
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
47-
const torch::Device& device,
48-
torch::DeviceIndex nonNegativeDeviceIndex,
49-
enum AVHWDeviceType type) {
64+
// Create hardware device context
5065
c10::cuda::CUDAGuard deviceGuard(device);
5166
// Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
5267
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
5368
// So we ensure the deviceIndex is not negative.
5469
// We set the device because we may be called from a different thread than
5570
// the one that initialized the cuda context.
5671
cudaSetDevice(nonNegativeDeviceIndex);
57-
AVBufferRef* hw_device_ctx = nullptr;
72+
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
5873
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
74+
5975
int err = av_hwdevice_ctx_create(
60-
&hw_device_ctx,
76+
&hardwareDeviceCtxRaw,
6177
type,
6278
deviceOrdinal.c_str(),
6379
nullptr,
64-
AV_CUDA_USE_CURRENT_CONTEXT);
80+
getFlagsAVHardwareDeviceContextCreate());
81+
6582
if (err < 0) {
6683
/* clang-format off */
6784
TORCH_CHECK(
@@ -72,53 +89,8 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext(
7289
"). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err));
7390
/* clang-format on */
7491
}
75-
return hw_device_ctx;
76-
}
77-
78-
#else
79-
80-
AVBufferRef* getFFMPEGContextFromNewCudaContext(
81-
[[maybe_unused]] const torch::Device& device,
82-
torch::DeviceIndex nonNegativeDeviceIndex,
83-
enum AVHWDeviceType type) {
84-
AVBufferRef* hw_device_ctx = nullptr;
85-
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
86-
int err = av_hwdevice_ctx_create(
87-
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
88-
if (err < 0) {
89-
TORCH_CHECK(
90-
false,
91-
"Failed to create specified HW device",
92-
getFFMPEGErrorStringFromErrorCode(err));
93-
}
94-
return hw_device_ctx;
95-
}
9692

97-
#endif
98-
99-
UniqueAVBufferRef getCudaContext(const torch::Device& device) {
100-
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
101-
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
102-
torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
103-
104-
UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
105-
if (hw_device_ctx) {
106-
return hw_device_ctx;
107-
}
108-
109-
// 58.26.100 introduced the concept of reusing the existing cuda context
110-
// which is much faster and lower memory than creating a new cuda context.
111-
// So we try to use that if it is available.
112-
// FFMPEG 6.1.2 appears to be the earliest release that contains version
113-
// 58.26.100 of avutil.
114-
// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
115-
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
116-
return UniqueAVBufferRef(getFFMPEGContextFromExistingCudaContext(
117-
device, nonNegativeDeviceIndex, type));
118-
#else
119-
return UniqueAVBufferRef(
120-
getFFMPEGContextFromNewCudaContext(device, nonNegativeDeviceIndex, type));
121-
#endif
93+
return UniqueAVBufferRef(hardwareDeviceCtxRaw);
12294
}
12395

12496
} // namespace
@@ -131,15 +103,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
131103

132104
initializeCudaContextWithPytorch(device_);
133105

134-
// TODO rename this, this is a hardware device context, not a CUDA context!
135-
// See https://github.com/meta-pytorch/torchcodec/issues/924
136-
ctx_ = getCudaContext(device_);
106+
hardwareDeviceCtx_ = getHardwareDeviceContext(device_);
137107
nppCtx_ = getNppStreamContext(device_);
138108
}
139109

140110
CudaDeviceInterface::~CudaDeviceInterface() {
141-
if (ctx_) {
142-
g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
111+
if (hardwareDeviceCtx_) {
112+
g_cached_hw_device_ctxs.addIfCacheHasCapacity(
113+
device_, std::move(hardwareDeviceCtx_));
143114
}
144115
returnNppStreamContextToCache(device_, std::move(nppCtx_));
145116
}
@@ -170,9 +141,10 @@ void CudaDeviceInterface::initializeVideo(
170141

171142
void CudaDeviceInterface::registerHardwareDeviceWithCodec(
172143
AVCodecContext* codecContext) {
173-
TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized");
144+
TORCH_CHECK(
145+
hardwareDeviceCtx_, "Hardware device context has not been initialized");
174146
TORCH_CHECK(codecContext != nullptr, "codecContext is null");
175-
codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
147+
codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get());
176148
}
177149

178150
UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class CudaDeviceInterface : public DeviceInterface {
5252
VideoStreamOptions videoStreamOptions_;
5353
AVRational timeBase_;
5454

55-
UniqueAVBufferRef ctx_;
55+
UniqueAVBufferRef hardwareDeviceCtx_;
5656
UniqueNppContext nppCtx_;
5757

5858
// This filtergraph instance is only used for NV12 format conversion in

0 commit comments

Comments
 (0)