Skip to content

Commit 0ad7370

Browse files
committed
abstract frameBuffer_ into a FrameBuffer class
1 parent dcf3124 commit 0ad7370

File tree

2 files changed

+42
-24
lines changed

2 files changed

+42
-24
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,6 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
133133
TORCH_CHECK(g_cuda_beta, "BetaCudaDeviceInterface was not registered!");
134134
TORCH_CHECK(
135135
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
136-
137-
// TODONVDEC P1: init size should probably be min_num_decode_surfaces from
138-
// video format
139-
frameBuffer_.resize(4);
140136
}
141137

142138
BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
@@ -344,7 +340,7 @@ int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) {
344340
dispInfo.repeat_first_field = 0;
345341
dispInfo.timestamp = guessedPts;
346342

347-
FrameBufferSlot* slot = findEmptySlot();
343+
FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot();
348344
slot->dispInfo = dispInfo;
349345
slot->guessedPts = guessedPts;
350346
slot->occupied = true;
@@ -358,7 +354,7 @@ int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) {
358354
int BetaCudaDeviceInterface::receiveFrame(
359355
UniqueAVFrame& avFrame,
360356
int64_t desiredPts) {
361-
FrameBufferSlot* slot = findFrameWithExactPts(desiredPts);
357+
FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts(desiredPts);
362358
if (slot == nullptr) {
363359
// No frame found, instruct caller to try again later after sending more
364360
// packets.
@@ -532,9 +528,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
532528
preAllocatedOutputTensor);
533529
}
534530

535-
// TODONVDEC P0: Don't let buffer grow indefinitely.
536-
BetaCudaDeviceInterface::FrameBufferSlot*
537-
BetaCudaDeviceInterface::findEmptySlot() {
531+
BetaCudaDeviceInterface::FrameBuffer::Slot*
532+
BetaCudaDeviceInterface::FrameBuffer::findEmptySlot() {
538533
for (auto& slot : frameBuffer_) {
539534
if (!slot.occupied) {
540535
return &slot;
@@ -544,8 +539,9 @@ BetaCudaDeviceInterface::findEmptySlot() {
544539
return &frameBuffer_.back();
545540
}
546541

547-
BetaCudaDeviceInterface::FrameBufferSlot*
548-
BetaCudaDeviceInterface::findFrameWithExactPts(int64_t desiredPts) {
542+
BetaCudaDeviceInterface::FrameBuffer::Slot*
543+
BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts(
544+
int64_t desiredPts) {
549545
for (auto& slot : frameBuffer_) {
550546
if (slot.occupied && slot.guessedPts == desiredPts) {
551547
return &slot;

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,40 @@ class BetaCudaDeviceInterface : public DeviceInterface {
6464
int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams);
6565

6666
private:
67+
class FrameBuffer {
68+
public:
69+
struct Slot {
70+
CUVIDPARSERDISPINFO dispInfo;
71+
int64_t guessedPts;
72+
bool occupied = false;
73+
74+
Slot() : guessedPts(-1), occupied(false) {
75+
std::memset(&dispInfo, 0, sizeof(dispInfo));
76+
}
77+
};
78+
79+
// TODONVDEC P1: init size should probably be min_num_decode_surfaces from
80+
// video format
81+
FrameBuffer() : frameBuffer_(4) {}
82+
83+
~FrameBuffer() = default;
84+
85+
Slot* findEmptySlot();
86+
Slot* findFrameWithExactPts(int64_t desiredPts);
87+
88+
// Iterator support for range-based for loops
89+
auto begin() {
90+
return frameBuffer_.begin();
91+
}
92+
93+
auto end() {
94+
return frameBuffer_.end();
95+
}
96+
97+
private:
98+
std::vector<Slot> frameBuffer_;
99+
};
100+
67101
UniqueAVFrame convertCudaFrameToAVFrame(
68102
CUdeviceptr framePtr,
69103
unsigned int pitch,
@@ -73,19 +107,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
73107
UniqueCUvideodecoder decoder_;
74108
CUVIDEOFORMAT videoFormat_ = {};
75109

76-
struct FrameBufferSlot {
77-
CUVIDPARSERDISPINFO dispInfo;
78-
int64_t guessedPts;
79-
bool occupied = false;
80-
81-
FrameBufferSlot() : guessedPts(-1), occupied(false) {
82-
std::memset(&dispInfo, 0, sizeof(dispInfo));
83-
}
84-
};
85-
86-
std::vector<FrameBufferSlot> frameBuffer_;
87-
FrameBufferSlot* findEmptySlot();
88-
FrameBufferSlot* findFrameWithExactPts(int64_t desiredPts);
110+
FrameBuffer frameBuffer_;
89111

90112
std::queue<int64_t> packetsPtsQueue;
91113

0 commit comments

Comments
 (0)