Skip to content

Commit 5605c90

Browse files
committed
Rework frame ordering and pts matching
1 parent b5fe9bc commit 5605c90

File tree

6 files changed

+158
-113
lines changed

6 files changed

+158
-113
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 69 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,17 @@ pfnSequenceCallback(void* pUserData, CUVIDEOFORMAT* videoFormat) {
4141
}
4242

4343
static int CUDAAPI
44-
pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* pPicParams) {
44+
pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* picParams) {
4545
BetaCudaDeviceInterface* decoder =
4646
static_cast<BetaCudaDeviceInterface*>(pUserData);
47-
return decoder->frameReadyForDecoding(pPicParams);
47+
return decoder->frameReadyForDecoding(picParams);
48+
}
49+
50+
static int CUDAAPI
51+
pfnDisplayPictureCallback(void* pUserData, CUVIDPARSERDISPINFO* dispInfo) {
52+
BetaCudaDeviceInterface* decoder =
53+
static_cast<BetaCudaDeviceInterface*>(pUserData);
54+
return decoder->frameReadyInDisplayOrder(dispInfo);
4855
}
4956

5057
static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
@@ -203,7 +210,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
203210
parserParams.pUserData = this;
204211
parserParams.pfnSequenceCallback = pfnSequenceCallback;
205212
parserParams.pfnDecodePicture = pfnDecodePictureCallback;
206-
parserParams.pfnDisplayPicture = nullptr;
213+
parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
207214

208215
CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams);
209216
TORCH_CHECK(
@@ -259,10 +266,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
259266
cuvidPacket.flags = CUVID_PKT_TIMESTAMP;
260267
cuvidPacket.timestamp = packet->pts;
261268

262-
// Like DALI: store packet PTS in queue to later assign to frames as they
263-
// come out
264-
packetsPtsQueue.push(packet->pts);
265-
266269
} else {
267270
// End of stream packet
268271
cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM;
@@ -312,68 +315,40 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) {
312315
// ready to be decoded, i.e. the parser received all the necessary packets for a
313316
// given frame. It means we can send that frame to be decoded by the hardware
314317
// NVDEC decoder by calling cuvidDecodePicture which is non-blocking.
315-
int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) {
318+
int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* picParams) {
316319
if (isFlushing_) {
317320
return 0;
318321
}
319322

320-
TORCH_CHECK(pPicParams != nullptr, "Invalid picture parameters");
323+
TORCH_CHECK(picParams != nullptr, "Invalid picture parameters");
321324
TORCH_CHECK(decoder_, "Decoder not initialized before picture decode");
322325

323326
// Send frame to be decoded by NVDEC - non-blocking call.
324-
CUresult result = cuvidDecodePicture(*decoder_.get(), pPicParams);
327+
CUresult result = cuvidDecodePicture(*decoder_.get(), picParams);
325328
if (result != CUDA_SUCCESS) {
326-
return 0; // Yes, you're reading that right, 0 mean error.
329+
return 0; // Yes, you're reading that right, 0 means error.
327330
}
328331

329-
// The frame was sent to be decoded on the NVDEC hardware. Now we store some
330-
// relevant info into our frame buffer so that we can retrieve the decoded
331-
// frame later when receiveFrame() is called.
332-
// Importantly we need to 'guess' the PTS of that frame. The heuristic we use
333-
// (like in DALI) is that the frames are ready to be decoded in the same order
334-
// as the packets were sent to the parser. So we assign the PTS of the frame
335-
// by popping the PTS of the oldest packet in our packetsPtsQueue (note:
336-
// oldest doesn't necessarily mean lowest PTS!).
337-
338-
TORCH_CHECK(
339-
// TODONVDEC P0 the queue may be empty, handle that.
340-
!packetsPtsQueue.empty(),
341-
"PTS queue is empty when decoding a frame");
342-
int64_t guessedPts = packetsPtsQueue.front();
343-
packetsPtsQueue.pop();
344-
345-
// Field values taken from DALI
346-
CUVIDPARSERDISPINFO dispInfo = {};
347-
dispInfo.picture_index = pPicParams->CurrPicIdx;
348-
dispInfo.progressive_frame = !pPicParams->field_pic_flag;
349-
dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1;
350-
dispInfo.repeat_first_field = 0;
351-
dispInfo.timestamp = guessedPts;
352-
353-
FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot();
354-
slot->dispInfo = dispInfo;
355-
slot->guessedPts = guessedPts;
356-
slot->occupied = true;
332+
frameBuffer_.markAsBeingDecoded(/*slotId=*/picParams->CurrPicIdx);
333+
return 1;
334+
}
357335

336+
int BetaCudaDeviceInterface::frameReadyInDisplayOrder(
337+
CUVIDPARSERDISPINFO* dispInfo) {
338+
frameBuffer_.markSlotReadyAndSetInfo(
339+
/*slotId=*/dispInfo->picture_index, dispInfo);
358340
return 1;
359341
}
360342

361-
// Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded
362-
// frame with the exact desired PTS in our frame buffer. This logic is only
363-
// valid in exact seek_mode, for now.
364-
int BetaCudaDeviceInterface::receiveFrame(
365-
UniqueAVFrame& avFrame,
366-
int64_t desiredPts) {
367-
FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts(desiredPts);
343+
// Moral equivalent of avcodec_receive_frame().
344+
int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
345+
FrameBuffer::Slot* slot = frameBuffer_.findReadySlotWithLowestPts();
368346
if (slot == nullptr) {
369347
// No frame found, instruct caller to try again later after sending more
370348
// packets.
371349
return AVERROR(EAGAIN);
372350
}
373351

374-
slot->occupied = false;
375-
slot->guessedPts = -1;
376-
377352
CUVIDPROCPARAMS procParams = {};
378353
CUVIDPARSERDISPINFO dispInfo = slot->dispInfo;
379354
procParams.progressive_frame = dispInfo.progressive_frame;
@@ -382,6 +357,8 @@ int BetaCudaDeviceInterface::receiveFrame(
382357
CUdeviceptr framePtr = 0;
383358
unsigned int pitch = 0;
384359

360+
frameBuffer_.free(slot->slotId);
361+
385362
// We know the frame we want was sent to the hardware decoder, but now we need
386363
// to "map" it to an "output surface" before we can use its data. This is a
387364
// blocking call that waits until the frame is fully decoded and ready to be
@@ -435,7 +412,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
435412
avFrame->width = width;
436413
avFrame->height = height;
437414
avFrame->format = AV_PIX_FMT_CUDA;
438-
avFrame->pts = dispInfo.timestamp; // == guessedPts
415+
avFrame->pts = dispInfo.timestamp;
439416

440417
unsigned int frameRateNum = videoFormat_.frame_rate.numerator;
441418
unsigned int frameRateDen = videoFormat_.frame_rate.denominator;
@@ -498,13 +475,7 @@ void BetaCudaDeviceInterface::flush() {
498475

499476
isFlushing_ = false;
500477

501-
for (auto& slot : frameBuffer_) {
502-
slot.occupied = false;
503-
slot.guessedPts = -1;
504-
}
505-
506-
std::queue<int64_t> empty;
507-
packetsPtsQueue.swap(empty);
478+
frameBuffer_.clear();
508479

509480
eofSent_ = false;
510481
}
@@ -538,26 +509,52 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
538509
preAllocatedOutputTensor);
539510
}
540511

541-
BetaCudaDeviceInterface::FrameBuffer::Slot*
542-
BetaCudaDeviceInterface::FrameBuffer::findEmptySlot() {
543-
for (auto& slot : frameBuffer_) {
544-
if (!slot.occupied) {
545-
return &slot;
546-
}
547-
}
548-
frameBuffer_.emplace_back();
549-
return &frameBuffer_.back();
512+
void BetaCudaDeviceInterface::FrameBuffer::markAsBeingDecoded(int slotId) {
513+
auto it = frameBuffer_.find(slotId);
514+
TORCH_CHECK(
515+
it == frameBuffer_.end(),
516+
"Slot ",
517+
slotId,
518+
" is already occupied. This should never happen.");
519+
520+
frameBuffer_.emplace(slotId, Slot(slotId, SlotState::BEING_DECODED));
521+
}
522+
523+
void BetaCudaDeviceInterface::FrameBuffer::markSlotReadyAndSetInfo(
524+
int slotId,
525+
CUVIDPARSERDISPINFO* dispInfo) {
526+
auto it = frameBuffer_.find(slotId);
527+
TORCH_CHECK(
528+
it != frameBuffer_.end(),
529+
"Could not find matching slot with slotId ",
530+
slotId,
531+
". This should never happen.");
532+
533+
it->second.state = SlotState::READY_FOR_OUTPUT;
534+
it->second.dispInfo = *dispInfo;
535+
}
536+
537+
void BetaCudaDeviceInterface::FrameBuffer::free(int slotId) {
538+
auto it = frameBuffer_.find(slotId);
539+
TORCH_CHECK(
540+
it != frameBuffer_.end(),
541+
"Tried to free non-existing slot with slotId ",
542+
slotId,
543+
". This should never happen.");
544+
frameBuffer_.erase(it);
550545
}
551546

552547
BetaCudaDeviceInterface::FrameBuffer::Slot*
553-
BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts(
554-
int64_t desiredPts) {
555-
for (auto& slot : frameBuffer_) {
556-
if (slot.occupied && slot.guessedPts == desiredPts) {
557-
return &slot;
548+
BetaCudaDeviceInterface::FrameBuffer::findReadySlotWithLowestPts() {
549+
Slot* outputSlot = nullptr;
550+
for (auto& [_, slot] : frameBuffer_) {
551+
if (slot.state == SlotState::READY_FOR_OUTPUT &&
552+
(outputSlot == nullptr ||
553+
slot.dispInfo.timestamp < outputSlot->dispInfo.timestamp)) {
554+
outputSlot = &slot;
558555
}
559556
}
560-
return nullptr;
557+
return outputSlot;
561558
}
562559

563560
} // namespace facebook::torchcodec

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,49 +52,49 @@ class BetaCudaDeviceInterface : public DeviceInterface {
5252
}
5353

5454
int sendPacket(ReferenceAVPacket& packet) override;
55-
int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override;
55+
int receiveFrame(UniqueAVFrame& avFrame) override;
5656
void flush() override;
5757

5858
// NVDEC callback functions (must be public for C callbacks)
5959
int streamPropertyChange(CUVIDEOFORMAT* videoFormat);
60-
int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams);
60+
int frameReadyForDecoding(CUVIDPICPARAMS* picParams);
61+
int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo);
6162

6263
private:
6364
// Apply bitstream filter, modifies packet in-place
6465
void applyBSF(ReferenceAVPacket& packet);
6566

6667
class FrameBuffer {
6768
public:
69+
enum class SlotState { BEING_DECODED, READY_FOR_OUTPUT };
70+
6871
struct Slot {
6972
CUVIDPARSERDISPINFO dispInfo;
70-
int64_t guessedPts;
71-
bool occupied = false;
73+
SlotState state;
74+
int slotId;
7275

73-
Slot() : guessedPts(-1), occupied(false) {
76+
explicit Slot(int id, SlotState s) : state(s), slotId(id) {
7477
std::memset(&dispInfo, 0, sizeof(dispInfo));
78+
TORCH_CHECK(
79+
state == SlotState::BEING_DECODED,
80+
"Programmer: are you sure you want to create a slot that is not BEING_DECODED?");
7581
}
7682
};
7783

78-
// TODONVDEC P1: init size should probably be min_num_decode_surfaces from
79-
// video format
80-
FrameBuffer() : frameBuffer_(4) {}
81-
84+
FrameBuffer() = default;
8285
~FrameBuffer() = default;
8386

84-
Slot* findEmptySlot();
85-
Slot* findFrameWithExactPts(int64_t desiredPts);
87+
void markAsBeingDecoded(int slotId);
88+
void markSlotReadyAndSetInfo(int slotId, CUVIDPARSERDISPINFO* dispInfo);
89+
void free(int slotId);
90+
Slot* findReadySlotWithLowestPts();
8691

87-
// Iterator support for range-based for loops
88-
auto begin() {
89-
return frameBuffer_.begin();
90-
}
91-
92-
auto end() {
93-
return frameBuffer_.end();
92+
void clear() {
93+
frameBuffer_.clear();
9494
}
9595

9696
private:
97-
std::vector<Slot> frameBuffer_;
97+
std::unordered_map<int, Slot> frameBuffer_;
9898
};
9999

100100
UniqueAVFrame convertCudaFrameToAVFrame(
@@ -108,8 +108,6 @@ class BetaCudaDeviceInterface : public DeviceInterface {
108108

109109
FrameBuffer frameBuffer_;
110110

111-
std::queue<int64_t> packetsPtsQueue;
112-
113111
bool eofSent_ = false;
114112

115113
// Flush flag to prevent decode operations during flush (like DALI's

src/torchcodec/_core/DeviceInterface.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ class DeviceInterface {
8787
// Moral equivalent of avcodec_receive_frame()
8888
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
8989
// AVERROR_EOF if end of stream, or other AVERROR on failure
90-
virtual int receiveFrame(
91-
[[maybe_unused]] UniqueAVFrame& avFrame,
92-
[[maybe_unused]] int64_t desiredPts) {
90+
virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
9391
TORCH_CHECK(
9492
false,
9593
"Send/receive packet decoding not implemented for this device interface");

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11461146

11471147
while (true) {
11481148
if (useCustomInterface) {
1149-
status = deviceInterface_->receiveFrame(avFrame, cursor_);
1149+
status = deviceInterface_->receiveFrame(avFrame);
11501150
} else {
11511151
status =
11521152
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());

src/torchcodec/decoders/_video_decoder.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,6 @@ def __init__(
153153
device_variant = device_split[2]
154154
device = ":".join(device_split[0:2])
155155

156-
# TODONVDEC P0 Support approximate mode. Not ideal to validate that here
157-
# either, but validating this at a lower level forces to add yet another
158-
# (temprorary) validation API to the device inteface
159-
if device_variant == "beta" and seek_mode != "exact":
160-
raise ValueError("Seek mode must be exact for BETA CUDA interface.")
161-
162156
core.add_video_stream(
163157
self._decoder,
164158
stream_index=stream_index,

0 commit comments

Comments
 (0)