@@ -41,10 +41,17 @@ pfnSequenceCallback(void* pUserData, CUVIDEOFORMAT* videoFormat) {
4141}
4242
4343static int CUDAAPI
44- pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* pPicParams ) {
44+ pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* picParams ) {
4545 BetaCudaDeviceInterface* decoder =
4646 static_cast <BetaCudaDeviceInterface*>(pUserData);
47- return decoder->frameReadyForDecoding (pPicParams);
47+ return decoder->frameReadyForDecoding (picParams);
48+ }
49+
50+ static int CUDAAPI
51+ pfnDisplayPictureCallback (void * pUserData, CUVIDPARSERDISPINFO* dispInfo) {
52+ BetaCudaDeviceInterface* decoder =
53+ static_cast <BetaCudaDeviceInterface*>(pUserData);
54+ return decoder->frameReadyInDisplayOrder (dispInfo);
4855}
4956
5057static UniqueCUvideodecoder createDecoder (CUVIDEOFORMAT* videoFormat) {
@@ -203,7 +210,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
203210 parserParams.pUserData = this ;
204211 parserParams.pfnSequenceCallback = pfnSequenceCallback;
205212 parserParams.pfnDecodePicture = pfnDecodePictureCallback;
206- parserParams.pfnDisplayPicture = nullptr ;
213+ parserParams.pfnDisplayPicture = pfnDisplayPictureCallback ;
207214
208215 CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
209216 TORCH_CHECK (
@@ -259,10 +266,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
259266 cuvidPacket.flags = CUVID_PKT_TIMESTAMP;
260267 cuvidPacket.timestamp = packet->pts ;
261268
262- // Like DALI: store packet PTS in queue to later assign to frames as they
263- // come out
264- packetsPtsQueue.push (packet->pts );
265-
266269 } else {
267270 // End of stream packet
268271 cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM;
@@ -312,68 +315,40 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) {
312315// ready to be decoded, i.e. the parser received all the necessary packets for a
313316// given frame. It means we can send that frame to be decoded by the hardware
314317// NVDEC decoder by calling cuvidDecodePicture which is non-blocking.
315- int BetaCudaDeviceInterface::frameReadyForDecoding (CUVIDPICPARAMS* pPicParams ) {
318+ int BetaCudaDeviceInterface::frameReadyForDecoding (CUVIDPICPARAMS* picParams ) {
316319 if (isFlushing_) {
317320 return 0 ;
318321 }
319322
320- TORCH_CHECK (pPicParams != nullptr , " Invalid picture parameters" );
323+ TORCH_CHECK (picParams != nullptr , " Invalid picture parameters" );
321324 TORCH_CHECK (decoder_, " Decoder not initialized before picture decode" );
322325
323326 // Send frame to be decoded by NVDEC - non-blocking call.
324- CUresult result = cuvidDecodePicture (*decoder_.get (), pPicParams );
327+ CUresult result = cuvidDecodePicture (*decoder_.get (), picParams );
325328 if (result != CUDA_SUCCESS) {
326- return 0 ; // Yes, you're reading that right, 0 mean error.
329+ return 0 ; // Yes, you're reading that right, 0 means error.
327330 }
328331
329- // The frame was sent to be decoded on the NVDEC hardware. Now we store some
330- // relevant info into our frame buffer so that we can retrieve the decoded
331- // frame later when receiveFrame() is called.
332- // Importantly we need to 'guess' the PTS of that frame. The heuristic we use
333- // (like in DALI) is that the frames are ready to be decoded in the same order
334- // as the packets were sent to the parser. So we assign the PTS of the frame
335- // by popping the PTS of the oldest packet in our packetsPtsQueue (note:
336- // oldest doesn't necessarily mean lowest PTS!).
337-
338- TORCH_CHECK (
339- // TODONVDEC P0 the queue may be empty, handle that.
340- !packetsPtsQueue.empty (),
341- " PTS queue is empty when decoding a frame" );
342- int64_t guessedPts = packetsPtsQueue.front ();
343- packetsPtsQueue.pop ();
344-
345- // Field values taken from DALI
346- CUVIDPARSERDISPINFO dispInfo = {};
347- dispInfo.picture_index = pPicParams->CurrPicIdx ;
348- dispInfo.progressive_frame = !pPicParams->field_pic_flag ;
349- dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1 ;
350- dispInfo.repeat_first_field = 0 ;
351- dispInfo.timestamp = guessedPts;
352-
353- FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot ();
354- slot->dispInfo = dispInfo;
355- slot->guessedPts = guessedPts;
356- slot->occupied = true ;
332+ frameBuffer_.markAsBeingDecoded (/* slotId=*/ picParams->CurrPicIdx );
333+ return 1 ;
334+ }
357335
336+ int BetaCudaDeviceInterface::frameReadyInDisplayOrder (
337+ CUVIDPARSERDISPINFO* dispInfo) {
338+ frameBuffer_.markSlotReadyAndSetInfo (
339+ /* slotId=*/ dispInfo->picture_index , dispInfo);
358340 return 1 ;
359341}
360342
361- // Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded
362- // frame with the exact desired PTS in our frame buffer. This logic is only
363- // valid in exact seek_mode, for now.
364- int BetaCudaDeviceInterface::receiveFrame (
365- UniqueAVFrame& avFrame,
366- int64_t desiredPts) {
367- FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts (desiredPts);
343+ // Moral equivalent of avcodec_receive_frame().
344+ int BetaCudaDeviceInterface::receiveFrame (UniqueAVFrame& avFrame) {
345+ FrameBuffer::Slot* slot = frameBuffer_.findReadySlotWithLowestPts ();
368346 if (slot == nullptr ) {
369347 // No frame found, instruct caller to try again later after sending more
370348 // packets.
371349 return AVERROR (EAGAIN);
372350 }
373351
374- slot->occupied = false ;
375- slot->guessedPts = -1 ;
376-
377352 CUVIDPROCPARAMS procParams = {};
378353 CUVIDPARSERDISPINFO dispInfo = slot->dispInfo ;
379354 procParams.progressive_frame = dispInfo.progressive_frame ;
@@ -382,6 +357,8 @@ int BetaCudaDeviceInterface::receiveFrame(
382357 CUdeviceptr framePtr = 0 ;
383358 unsigned int pitch = 0 ;
384359
360+ frameBuffer_.free (slot->slotId );
361+
385362 // We know the frame we want was sent to the hardware decoder, but now we need
386363 // to "map" it to an "output surface" before we can use its data. This is a
387364 // blocking call that waits until the frame is fully decoded and ready to be
@@ -435,7 +412,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
435412 avFrame->width = width;
436413 avFrame->height = height;
437414 avFrame->format = AV_PIX_FMT_CUDA;
438- avFrame->pts = dispInfo.timestamp ; // == guessedPts
415+ avFrame->pts = dispInfo.timestamp ;
439416
440417 unsigned int frameRateNum = videoFormat_.frame_rate .numerator ;
441418 unsigned int frameRateDen = videoFormat_.frame_rate .denominator ;
@@ -498,13 +475,7 @@ void BetaCudaDeviceInterface::flush() {
498475
499476 isFlushing_ = false ;
500477
501- for (auto & slot : frameBuffer_) {
502- slot.occupied = false ;
503- slot.guessedPts = -1 ;
504- }
505-
506- std::queue<int64_t > empty;
507- packetsPtsQueue.swap (empty);
478+ frameBuffer_.clear ();
508479
509480 eofSent_ = false ;
510481}
@@ -538,26 +509,52 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
538509 preAllocatedOutputTensor);
539510}
540511
541- BetaCudaDeviceInterface::FrameBuffer::Slot*
542- BetaCudaDeviceInterface::FrameBuffer::findEmptySlot () {
543- for (auto & slot : frameBuffer_) {
544- if (!slot.occupied ) {
545- return &slot;
546- }
547- }
548- frameBuffer_.emplace_back ();
549- return &frameBuffer_.back ();
512+ void BetaCudaDeviceInterface::FrameBuffer::markAsBeingDecoded (int slotId) {
513+ auto it = frameBuffer_.find (slotId);
514+ TORCH_CHECK (
515+ it == frameBuffer_.end (),
516+ " Slot " ,
517+ slotId,
518+ " is already occupied. This should never happen." );
519+
520+ frameBuffer_.emplace (slotId, Slot (slotId, SlotState::BEING_DECODED));
521+ }
522+
523+ void BetaCudaDeviceInterface::FrameBuffer::markSlotReadyAndSetInfo (
524+ int slotId,
525+ CUVIDPARSERDISPINFO* dispInfo) {
526+ auto it = frameBuffer_.find (slotId);
527+ TORCH_CHECK (
528+ it != frameBuffer_.end (),
529+ " Could not find matching slot with slotId " ,
530+ slotId,
531+ " . This should never happen." );
532+
533+ it->second .state = SlotState::READY_FOR_OUTPUT;
534+ it->second .dispInfo = *dispInfo;
535+ }
536+
537+ void BetaCudaDeviceInterface::FrameBuffer::free (int slotId) {
538+ auto it = frameBuffer_.find (slotId);
539+ TORCH_CHECK (
540+ it != frameBuffer_.end (),
541+ " Tried to free non-existing slot with slotId" ,
542+ slotId,
543+ " . This should never happen." );
544+ frameBuffer_.erase (it);
550545}
551546
552547BetaCudaDeviceInterface::FrameBuffer::Slot*
553- BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts (
554- int64_t desiredPts) {
555- for (auto & slot : frameBuffer_) {
556- if (slot.occupied && slot.guessedPts == desiredPts) {
557- return &slot;
548+ BetaCudaDeviceInterface::FrameBuffer::findReadySlotWithLowestPts () {
549+ Slot* outputSlot = nullptr ;
550+ for (auto & [_, slot] : frameBuffer_) {
551+ if (slot.state == SlotState::READY_FOR_OUTPUT &&
552+ (outputSlot == nullptr ||
553+ slot.dispInfo .timestamp < outputSlot->dispInfo .timestamp )) {
554+ outputSlot = &slot;
558555 }
559556 }
560- return nullptr ;
557+ return outputSlot ;
561558}
562559
563560} // namespace facebook::torchcodec
0 commit comments