Skip to content

Commit 78ab058

Browse files
committed
Let's just commit 3k loc in a single commit
1 parent 6f906f4 commit 78ab058

20 files changed

+3204
-72
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 563 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
// BETA CUDA device interface that provides direct control over NVDEC
8+
// while keeping FFmpeg for demuxing. A lot of the logic, particularly the use
9+
// of a cache for the decoders, is inspired by DALI's implementation which is
10+
// APACHE 2.0:
11+
// https://github.com/NVIDIA/DALI/blob/c7539676a24a8e9e99a6e8665e277363c5445259/dali/operators/video/frames_decoder_gpu.cc#L1
12+
//
13+
// NVDEC / NVCUVID docs:
14+
// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#using-nvidia-video-decoder-nvdecode-api
15+
16+
#pragma once
17+
18+
#include "src/torchcodec/_core/Cache.h"
19+
#include "src/torchcodec/_core/DeviceInterface.h"
20+
#include "src/torchcodec/_core/FFMPEGCommon.h"
21+
#include "src/torchcodec/_core/NVDECCache.h"
22+
23+
#include <map>
24+
#include <memory>
25+
#include <mutex>
26+
#include <queue>
27+
#include <unordered_map>
28+
#include <vector>
29+
30+
#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
31+
#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
32+
33+
namespace facebook::torchcodec {
34+
35+
class BetaCudaDeviceInterface : public DeviceInterface {
36+
public:
37+
explicit BetaCudaDeviceInterface(const torch::Device& device);
38+
virtual ~BetaCudaDeviceInterface();
39+
40+
void initializeInterface(AVStream* stream) override;
41+
42+
void convertAVFrameToFrameOutput(
43+
const VideoStreamOptions& videoStreamOptions,
44+
const AVRational& timeBase,
45+
UniqueAVFrame& avFrame,
46+
FrameOutput& frameOutput,
47+
std::optional<torch::Tensor> preAllocatedOutputTensor =
48+
std::nullopt) override;
49+
50+
bool canDecodePacketDirectly() const override {
51+
return true;
52+
}
53+
54+
int sendPacket(ReferenceAVPacket& packet) override;
55+
int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override;
56+
void flush() override;
57+
ReferenceAVPacket* applyBSF(
58+
ReferenceAVPacket& packet,
59+
AutoAVPacket& filteredAutoPacket,
60+
ReferenceAVPacket& filteredPacket) override;
61+
62+
// NVDEC callback functions (must be public for C callbacks)
63+
unsigned char streamPropertyChange(CUVIDEOFORMAT* videoFormat);
64+
int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams);
65+
66+
private:
67+
UniqueAVFrame convertCudaFrameToAVFrame(
68+
CUdeviceptr framePtr,
69+
unsigned int pitch,
70+
const CUVIDPARSERDISPINFO& dispInfo);
71+
72+
CUvideoparser videoParser_ = nullptr;
73+
UniqueCUvideodecoder decoder_;
74+
CUVIDEOFORMAT videoFormat_ = {};
75+
76+
struct FrameBufferSlot {
77+
CUVIDPARSERDISPINFO dispInfo;
78+
int64_t guessedPts;
79+
bool occupied = false;
80+
81+
FrameBufferSlot() : guessedPts(-1), occupied(false) {
82+
memset(&dispInfo, 0, sizeof(dispInfo));
83+
}
84+
};
85+
86+
std::vector<FrameBufferSlot> frameBuffer_;
87+
FrameBufferSlot* findEmptySlot();
88+
FrameBufferSlot* findFrameWithExactPts(int64_t desiredPts);
89+
90+
std::queue<int64_t> packetsPtsQueue;
91+
92+
bool eofSent_ = false;
93+
94+
// Flush flag to prevent decode operations during flush (like DALI's
95+
// isFlushing_)
96+
bool isFlushing_ = false;
97+
98+
AVRational timeBase_ = {0, 0};
99+
100+
UniqueAVBSFContext bitstreamFilter_;
101+
102+
// Default CUDA interface for color conversion.
103+
// TODONVDEC P2: we shouldn't need to keep a separate instance of the default.
104+
// See other TODO there about how interfaces should be completely independent.
105+
std::unique_ptr<DeviceInterface> defaultCudaInterface_;
106+
};
107+
108+
} // namespace facebook::torchcodec

src/torchcodec/_core/CMakeLists.txt

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function(make_torchcodec_libraries
9898
)
9999

100100
if(ENABLE_CUDA)
101-
list(APPEND core_sources CudaDeviceInterface.cpp)
101+
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp)
102102
endif()
103103

104104
set(core_library_dependencies
@@ -111,6 +111,29 @@ function(make_torchcodec_libraries
111111
${CUDA_nppi_LIBRARY}
112112
${CUDA_nppicc_LIBRARY}
113113
)
114+
115+
# Try to find NVCUVID. Try the normal way first. This should work locally.
116+
find_library(NVCUVID_LIBRARY NAMES nvcuvid)
117+
# If not found, try with version suffix, or hardcoded path. Appears
118+
# to be necessary on the CI.
119+
if(NOT NVCUVID_LIBRARY)
120+
find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib)
121+
endif()
122+
if(NOT NVCUVID_LIBRARY)
123+
set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1")
124+
endif()
125+
126+
if(NVCUVID_LIBRARY)
127+
message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}")
128+
else()
129+
message(FATAL_ERROR "Could not find NVCUVID library")
130+
endif()
131+
132+
# Add CUDA Driver library (needed for cuCtxGetCurrent, etc.)
133+
find_library(CUDA_DRIVER_LIBRARY NAMES cuda REQUIRED)
134+
message(STATUS "Found CUDA Driver library: ${CUDA_DRIVER_LIBRARY}")
135+
136+
list(APPEND core_library_dependencies ${NVCUVID_LIBRARY} ${CUDA_DRIVER_LIBRARY})
114137
endif()
115138

116139
make_torchcodec_sublibrary(

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ extern "C" {
1313
#include <libavutil/pixdesc.h>
1414
}
1515

16+
// TODONVDEC P1 Changes were made to this file to accommodate the BETA CUDA
17+
// interface (see other TODONVDEC below). That's because the BETA CUDA interface
18+
// relies on this default CUDA interface to do the color conversion. That's
19+
// hacky, ugly, and leads to complicated code. We should refactor all this so
20+
// that an interface doesn't need to know anything about any other interface.
21+
// Note - this is more than just about the BETA CUDA interface: this default
22+
// interface already relies on the CPU interface to do software decoding when
23+
// needed, and that's already leading to similar complications.
24+
1625
namespace facebook::torchcodec {
1726
namespace {
1827

@@ -216,10 +225,11 @@ std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
216225
return nullptr;
217226
}
218227

219-
TORCH_CHECK(
220-
avFrame->hw_frames_ctx != nullptr,
221-
"The AVFrame does not have a hw_frames_ctx. "
222-
"That's unexpected, please report this to the TorchCodec repo.");
228+
if (avFrame->hw_frames_ctx == nullptr) {
229+
// TODONVDEC P2 return early for the beta interface where avFrames don't
230+
// have a hw_frames_ctx. We should get rid of this or improve the logic.
231+
return nullptr;
232+
}
223233

224234
auto hwFramesCtx =
225235
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
@@ -347,22 +357,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
347357
// Above we checked that the AVFrame was on GPU, but that's not enough, we
348358
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
349359
// because this is what the NPP color conversion routines expect.
350-
TORCH_CHECK(
351-
avFrame->hw_frames_ctx != nullptr,
352-
"The AVFrame does not have a hw_frames_ctx. "
353-
"That's unexpected, please report this to the TorchCodec repo.");
354-
355-
auto hwFramesCtx =
356-
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
357-
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
360+
// TODONVDEC P2 this can be hit from the beta interface, but there's no
361+
// hw_frames_ctx in this case. We should try to understand how that affects
362+
// this validation.
363+
AVHWFramesContext* hwFramesCtx = nullptr;
364+
if (avFrame->hw_frames_ctx != nullptr) {
365+
hwFramesCtx =
366+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
367+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
358368

359-
TORCH_CHECK(
360-
actualFormat == AV_PIX_FMT_NV12,
361-
"The AVFrame is ",
362-
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
363-
: "unknown"),
364-
", but we expected AV_PIX_FMT_NV12. "
365-
"That's unexpected, please report this to the TorchCodec repo.");
369+
TORCH_CHECK(
370+
actualFormat == AV_PIX_FMT_NV12,
371+
"The AVFrame is ",
372+
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
373+
: "unknown"),
374+
", but we expected AV_PIX_FMT_NV12. "
375+
"That's unexpected, please report this to the TorchCodec repo.");
376+
}
366377

367378
auto frameDims =
368379
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
@@ -396,19 +407,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
396407
// arbitrary, but unfortunately we know it's hardcoded to be the default
397408
// stream by FFmpeg:
398409
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
399-
TORCH_CHECK(
400-
hwFramesCtx->device_ctx != nullptr,
401-
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
402-
auto cudaDeviceCtx =
403-
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
404-
at::cuda::CUDAEvent nvdecDoneEvent;
405-
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
406-
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
407-
nvdecDoneEvent.record(nvdecStream);
408-
409-
// Don't start NPP work before NVDEC is done decoding the frame!
410410
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
411-
nvdecDoneEvent.block(nppStream);
411+
if (hwFramesCtx) {
412+
// TODONVDEC P2 this block won't be hit from the beta interface because
413+
// there is no hwFramesCtx, but we should still make sure there's no CUDA
414+
// stream sync issue in the beta interface.
415+
TORCH_CHECK(
416+
hwFramesCtx->device_ctx != nullptr,
417+
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
418+
auto cudaDeviceCtx =
419+
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
420+
at::cuda::CUDAEvent nvdecDoneEvent;
421+
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
422+
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
423+
nvdecDoneEvent.record(nvdecStream);
424+
// Don't start NPP work before NVDEC is done decoding the frame!
425+
nvdecDoneEvent.block(nppStream);
426+
}
412427

413428
// Create the NPP context if we haven't yet.
414429
nppCtx_->hStream = nppStream.stream();

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
namespace facebook::torchcodec {
1212

1313
namespace {
14-
using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
14+
using DeviceInterfaceMap =
15+
std::map<DeviceInterfaceKey, CreateDeviceInterfaceFn>;
1516
static std::mutex g_interface_mutex;
1617

1718
DeviceInterfaceMap& getDeviceMap() {
@@ -30,50 +31,79 @@ std::string getDeviceType(const std::string& device) {
3031
} // namespace
3132

3233
bool registerDeviceInterface(
33-
torch::DeviceType deviceType,
34+
const DeviceInterfaceKey& key,
3435
CreateDeviceInterfaceFn createInterface) {
3536
std::scoped_lock lock(g_interface_mutex);
3637
DeviceInterfaceMap& deviceMap = getDeviceMap();
3738

3839
TORCH_CHECK(
39-
deviceMap.find(deviceType) == deviceMap.end(),
40-
"Device interface already registered for ",
41-
deviceType);
42-
deviceMap.insert({deviceType, createInterface});
40+
deviceMap.find(key) == deviceMap.end(),
41+
"Device interface already registered for device type ",
42+
key.deviceType,
43+
" variant '",
44+
key.variant,
45+
"'");
46+
deviceMap.insert({key, createInterface});
4347

4448
return true;
4549
}
4650

47-
torch::Device createTorchDevice(const std::string device) {
51+
// Overload keyed on the device type alone: wraps `deviceType` in a
// DeviceInterfaceKey and forwards to the key-based overload, returning its
// result.
bool registerDeviceInterface(
    torch::DeviceType deviceType,
    CreateDeviceInterfaceFn createInterface) {
  const DeviceInterfaceKey typeOnlyKey(deviceType);
  return registerDeviceInterface(typeOnlyKey, createInterface);
}
57+
58+
void validateDeviceInterface(
59+
const std::string device,
60+
const std::string variant) {
4861
std::scoped_lock lock(g_interface_mutex);
4962
std::string deviceType = getDeviceType(device);
63+
5064
DeviceInterfaceMap& deviceMap = getDeviceMap();
5165

66+
// Find device interface that matches device type and variant
67+
torch::DeviceType deviceTypeEnum = torch::Device(deviceType).type();
68+
5269
auto deviceInterface = std::find_if(
5370
deviceMap.begin(),
5471
deviceMap.end(),
55-
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
56-
return device.rfind(
57-
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
72+
[&](const std::pair<DeviceInterfaceKey, CreateDeviceInterfaceFn>& arg) {
73+
return arg.first.deviceType == deviceTypeEnum &&
74+
arg.first.variant == variant;
5875
});
59-
TORCH_CHECK(
60-
deviceInterface != deviceMap.end(), "Unsupported device: ", device);
6176

62-
return torch::Device(device);
77+
TORCH_CHECK(
78+
deviceInterface != deviceMap.end(),
79+
"Unsupported device: ",
80+
device,
81+
" (device type: ",
82+
deviceType,
83+
", variant: ",
84+
variant,
85+
")");
6386
}
6487

6588
std::unique_ptr<DeviceInterface> createDeviceInterface(
66-
const torch::Device& device) {
67-
auto deviceType = device.type();
89+
const torch::Device& device,
90+
const std::string_view variant) {
91+
DeviceInterfaceKey key(device.type(), variant);
6892
std::scoped_lock lock(g_interface_mutex);
6993
DeviceInterfaceMap& deviceMap = getDeviceMap();
7094

71-
TORCH_CHECK(
72-
deviceMap.find(deviceType) != deviceMap.end(),
73-
"Unsupported device: ",
74-
device);
95+
auto it = deviceMap.find(key);
96+
if (it != deviceMap.end()) {
97+
return std::unique_ptr<DeviceInterface>(it->second(device));
98+
}
7599

76-
return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
100+
TORCH_CHECK(
101+
false,
102+
"No device interface found for device type: ",
103+
device.type(),
104+
" variant: '",
105+
variant,
106+
"'");
77107
}
78108

79109
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)