@@ -35,16 +35,20 @@ static bool g_cuda_beta = registerDeviceInterface(
3535
3636static int CUDAAPI
3737pfnSequenceCallback (void * pUserData, CUVIDEOFORMAT* videoFormat) {
38- BetaCudaDeviceInterface* decoder =
39- static_cast <BetaCudaDeviceInterface*>(pUserData);
38+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
4039 return decoder->streamPropertyChange (videoFormat);
4140}
4241
4342static int CUDAAPI
44- pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* pPicParams) {
45- BetaCudaDeviceInterface* decoder =
46- static_cast <BetaCudaDeviceInterface*>(pUserData);
47- return decoder->frameReadyForDecoding (pPicParams);
43+ pfnDecodePictureCallback (void * pUserData, CUVIDPICPARAMS* picParams) {
44+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
45+ return decoder->frameReadyForDecoding (picParams);
46+ }
47+
48+ static int CUDAAPI
49+ pfnDisplayPictureCallback (void * pUserData, CUVIDPARSERDISPINFO* dispInfo) {
50+ auto decoder = static_cast <BetaCudaDeviceInterface*>(pUserData);
51+ return decoder->frameReadyInDisplayOrder (dispInfo);
4852}
4953
5054static UniqueCUvideodecoder createDecoder (CUVIDEOFORMAT* videoFormat) {
@@ -142,7 +146,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
142146
143147BetaCudaDeviceInterface::~BetaCudaDeviceInterface () {
144148 // TODONVDEC P0: we probably need to free the frames that have been decoded by
145- // NVDEC but not yet "mapped" - i.e. those that are still in frameBuffer_ ?
149+ // NVDEC but not yet "mapped" - i.e. those that are still in readyFrames_ ?
146150
147151 if (decoder_) {
148152 NVDECCache::getCache (device_.index ())
@@ -218,7 +222,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
218222 parserParams.pUserData = this ;
219223 parserParams.pfnSequenceCallback = pfnSequenceCallback;
220224 parserParams.pfnDecodePicture = pfnDecodePictureCallback;
221- parserParams.pfnDisplayPicture = nullptr ;
225+ parserParams.pfnDisplayPicture = pfnDisplayPictureCallback ;
222226
223227 CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
224228 TORCH_CHECK (
@@ -274,10 +278,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
274278 cuvidPacket.flags = CUVID_PKT_TIMESTAMP;
275279 cuvidPacket.timestamp = packet->pts ;
276280
277- // Like DALI: store packet PTS in queue to later assign to frames as they
278- // come out
279- packetsPtsQueue.push (packet->pts );
280-
281281 } else {
282282 // End of stream packet
283283 cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM;
@@ -329,70 +329,38 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) {
329329// ready to be decoded, i.e. the parser received all the necessary packets for a
330330// given frame. It means we can send that frame to be decoded by the hardware
331331// NVDEC decoder by calling cuvidDecodePicture which is non-blocking.
332- int BetaCudaDeviceInterface::frameReadyForDecoding (CUVIDPICPARAMS* pPicParams ) {
332+ int BetaCudaDeviceInterface::frameReadyForDecoding (CUVIDPICPARAMS* picParams ) {
333333 if (isFlushing_) {
334334 return 0 ;
335335 }
336336
337- TORCH_CHECK (pPicParams != nullptr , " Invalid picture parameters" );
337+ TORCH_CHECK (picParams != nullptr , " Invalid picture parameters" );
338338 TORCH_CHECK (decoder_, " Decoder not initialized before picture decode" );
339339
340340 // Send frame to be decoded by NVDEC - non-blocking call.
341- CUresult result = cuvidDecodePicture (*decoder_.get (), pPicParams);
342- if (result != CUDA_SUCCESS) {
343- return 0 ; // Yes, you're reading that right, 0 mean error.
344- }
341+ CUresult result = cuvidDecodePicture (*decoder_.get (), picParams);
345342
346- // The frame was sent to be decoded on the NVDEC hardware. Now we store some
347- // relevant info into our frame buffer so that we can retrieve the decoded
348- // frame later when receiveFrame() is called.
349- // Importantly we need to 'guess' the PTS of that frame. The heuristic we use
350- // (like in DALI) is that the frames are ready to be decoded in the same order
351- // as the packets were sent to the parser. So we assign the PTS of the frame
352- // by popping the PTS of the oldest packet in our packetsPtsQueue (note:
353- // oldest doesn't necessarily mean lowest PTS!).
343+ // Yes, you're reading that right, 0 means error, 1 means success
344+ return (result == CUDA_SUCCESS);
345+ }
354346
355- TORCH_CHECK (
356- // TODONVDEC P0 the queue may be empty, handle that.
357- !packetsPtsQueue.empty (),
358- " PTS queue is empty when decoding a frame" );
359- int64_t guessedPts = packetsPtsQueue.front ();
360- packetsPtsQueue.pop ();
361-
362- // Field values taken from DALI
363- CUVIDPARSERDISPINFO dispInfo = {};
364- dispInfo.picture_index = pPicParams->CurrPicIdx ;
365- dispInfo.progressive_frame = !pPicParams->field_pic_flag ;
366- dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1 ;
367- dispInfo.repeat_first_field = 0 ;
368- dispInfo.timestamp = guessedPts;
369-
370- FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot ();
371- slot->dispInfo = dispInfo;
372- slot->guessedPts = guessedPts;
373- slot->occupied = true ;
374-
375- return 1 ;
347+ int BetaCudaDeviceInterface::frameReadyInDisplayOrder (
348+ CUVIDPARSERDISPINFO* dispInfo) {
349+ readyFrames_.push (*dispInfo);
350+ return 1 ; // success
376351}
377352
378- // Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded
379- // frame with the exact desired PTS in our frame buffer. This logic is only
380- // valid in exact seek_mode, for now.
381- int BetaCudaDeviceInterface::receiveFrame (
382- UniqueAVFrame& avFrame,
383- int64_t desiredPts) {
384- FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts (desiredPts);
385- if (slot == nullptr ) {
353+ // Moral equivalent of avcodec_receive_frame().
354+ int BetaCudaDeviceInterface::receiveFrame (UniqueAVFrame& avFrame) {
355+ if (readyFrames_.empty ()) {
386356 // No frame found, instruct caller to try again later after sending more
387357 // packets.
388358 return AVERROR (EAGAIN);
389359 }
390-
391- slot->occupied = false ;
392- slot->guessedPts = -1 ;
360+ CUVIDPARSERDISPINFO dispInfo = readyFrames_.front ();
361+ readyFrames_.pop ();
393362
394363 CUVIDPROCPARAMS procParams = {};
395- CUVIDPARSERDISPINFO dispInfo = slot->dispInfo ;
396364 procParams.progressive_frame = dispInfo.progressive_frame ;
397365 procParams.top_field_first = dispInfo.top_field_first ;
398366 procParams.unpaired_field = dispInfo.repeat_first_field < 0 ;
@@ -452,7 +420,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
452420 avFrame->width = width;
453421 avFrame->height = height;
454422 avFrame->format = AV_PIX_FMT_CUDA;
455- avFrame->pts = dispInfo.timestamp ; // == guessedPts
423+ avFrame->pts = dispInfo.timestamp ;
456424
457425 // TODONVDEC P0: Zero division error!!!
458426 // TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the
@@ -518,13 +486,8 @@ void BetaCudaDeviceInterface::flush() {
518486
519487 isFlushing_ = false ;
520488
521- for (auto & slot : frameBuffer_) {
522- slot.occupied = false ;
523- slot.guessedPts = -1 ;
524- }
525-
526- std::queue<int64_t > empty;
527- packetsPtsQueue.swap (empty);
489+ std::queue<CUVIDPARSERDISPINFO> emptyQueue;
490+ std::swap (readyFrames_, emptyQueue);
528491
529492 eofSent_ = false ;
530493}
@@ -544,26 +507,4 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
544507 avFrame, frameOutput, preAllocatedOutputTensor);
545508}
546509
547- BetaCudaDeviceInterface::FrameBuffer::Slot*
548- BetaCudaDeviceInterface::FrameBuffer::findEmptySlot () {
549- for (auto & slot : frameBuffer_) {
550- if (!slot.occupied ) {
551- return &slot;
552- }
553- }
554- frameBuffer_.emplace_back ();
555- return &frameBuffer_.back ();
556- }
557-
558- BetaCudaDeviceInterface::FrameBuffer::Slot*
559- BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts (
560- int64_t desiredPts) {
561- for (auto & slot : frameBuffer_) {
562- if (slot.occupied && slot.guessedPts == desiredPts) {
563- return &slot;
564- }
565- }
566- return nullptr ;
567- }
568-
569510} // namespace facebook::torchcodec
0 commit comments