@@ -162,12 +162,11 @@ void VideoDecoder::initializeDecoder() {
162162 av_q2d (avStream->time_base ) * avStream->duration ;
163163 }
164164
165- double fps = av_q2d (avStream->r_frame_rate );
166- if (fps > 0 ) {
167- streamMetadata.averageFps = fps;
168- }
169-
170165 if (avStream->codecpar ->codec_type == AVMEDIA_TYPE_VIDEO) {
166+ double fps = av_q2d (avStream->r_frame_rate );
167+ if (fps > 0 ) {
168+ streamMetadata.averageFps = fps;
169+ }
171170 containerMetadata_.numVideoStreams ++;
172171 } else if (avStream->codecpar ->codec_type == AVMEDIA_TYPE_AUDIO) {
173172 containerMetadata_.numAudioStreams ++;
@@ -340,7 +339,7 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
340339}
341340
342341torch::Tensor VideoDecoder::getKeyFrameIndices () {
343- validateActiveStream ();
342+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
344343 validateScannedAllStreams (" getKeyFrameIndices" );
345344
346345 const std::vector<FrameInfo>& keyFrames =
@@ -409,84 +408,76 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
409408 }
410409}
411410
412- void VideoDecoder::addVideoStream (
411+ void VideoDecoder::addStream (
413412 int streamIndex,
414- const VideoStreamOptions& videoStreamOptions) {
413+ AVMediaType mediaType,
414+ const torch::Device& device,
415+ std::optional<int > ffmpegThreadCount) {
415416 TORCH_CHECK (
416417 activeStreamIndex_ == NO_ACTIVE_STREAM,
417418 " Can only add one single stream." );
419+ TORCH_CHECK (
420+ mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO,
421+ " Can only add video or audio streams." );
418422 TORCH_CHECK (formatContext_.get () != nullptr );
419423
420424 AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr ;
421425
422426 activeStreamIndex_ = av_find_best_stream (
423- formatContext_.get (), AVMEDIA_TYPE_VIDEO, streamIndex, -1 , &avCodec, 0 );
427+ formatContext_.get (), mediaType, streamIndex, -1 , &avCodec, 0 );
428+
424429 if (activeStreamIndex_ < 0 ) {
425- throw std::invalid_argument (" No valid stream found in input file." );
430+ throw std::invalid_argument (
431+ " No valid stream found in input file. Is " +
432+ std::to_string (streamIndex) + " of the desired media type?" );
426433 }
434+
427435 TORCH_CHECK (avCodec != nullptr );
428436
429437 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
430438 streamInfo.streamIndex = activeStreamIndex_;
431439 streamInfo.timeBase = formatContext_->streams [activeStreamIndex_]->time_base ;
432440 streamInfo.stream = formatContext_->streams [activeStreamIndex_];
441+ streamInfo.avMediaType = mediaType;
433442
434- if (streamInfo.stream ->codecpar ->codec_type != AVMEDIA_TYPE_VIDEO) {
435- throw std::invalid_argument (
436- " Stream with index " + std::to_string (activeStreamIndex_) +
437- " is not a video stream." );
438- }
439-
440- if (videoStreamOptions.device .type () == torch::kCUDA ) {
443+ // This should never happen, checking just to be safe.
444+ TORCH_CHECK (
445+ streamInfo.stream ->codecpar ->codec_type == mediaType,
446+ " FFmpeg found stream with index " ,
447+ activeStreamIndex_,
448+ " which is of the wrong media type." );
449+
450+ // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
451+ // addStream() which is supposed to be generic
452+ if (mediaType == AVMEDIA_TYPE_VIDEO && device.type () == torch::kCUDA ) {
441453 avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream (
442- findCudaCodec (
443- videoStreamOptions.device , streamInfo.stream ->codecpar ->codec_id )
454+ findCudaCodec (device, streamInfo.stream ->codecpar ->codec_id )
444455 .value_or (avCodec));
445456 }
446457
447- StreamMetadata& streamMetadata =
448- containerMetadata_.allStreamMetadata [activeStreamIndex_];
449- if (seekMode_ == SeekMode::approximate &&
450- !streamMetadata.averageFps .has_value ()) {
451- throw std::runtime_error (
452- " Seek mode is approximate, but stream " +
453- std::to_string (activeStreamIndex_) +
454- " does not have an average fps in its metadata." );
455- }
456-
457458 AVCodecContext* codecContext = avcodec_alloc_context3 (avCodec);
458459 TORCH_CHECK (codecContext != nullptr );
459- codecContext->thread_count = videoStreamOptions.ffmpegThreadCount .value_or (0 );
460460 streamInfo.codecContext .reset (codecContext);
461461
462462 int retVal = avcodec_parameters_to_context (
463463 streamInfo.codecContext .get (), streamInfo.stream ->codecpar );
464464 TORCH_CHECK_EQ (retVal, AVSUCCESS);
465465
466- if (videoStreamOptions.device .type () == torch::kCPU ) {
467- // No more initialization needed for CPU.
468- } else if (videoStreamOptions.device .type () == torch::kCUDA ) {
469- initializeContextOnCuda (videoStreamOptions.device , codecContext);
470- } else {
471- TORCH_CHECK (
472- false , " Invalid device type: " + videoStreamOptions.device .str ());
466+ streamInfo.codecContext ->thread_count = ffmpegThreadCount.value_or (0 );
467+
468+ // TODO_CODE_QUALITY same as above.
469+ if (mediaType == AVMEDIA_TYPE_VIDEO && device.type () == torch::kCUDA ) {
470+ initializeContextOnCuda (device, codecContext);
473471 }
474- streamInfo.videoStreamOptions = videoStreamOptions;
475472
476473 retVal = avcodec_open2 (streamInfo.codecContext .get (), avCodec, nullptr );
477474 if (retVal < AVSUCCESS) {
478475 throw std::invalid_argument (getFFMPEGErrorStringFromErrorCode (retVal));
479476 }
480477
481478 codecContext->time_base = streamInfo.stream ->time_base ;
482-
483- containerMetadata_.allStreamMetadata [activeStreamIndex_].width =
484- codecContext->width ;
485- containerMetadata_.allStreamMetadata [activeStreamIndex_].height =
486- codecContext->height ;
487- auto codedId = codecContext->codec_id ;
488479 containerMetadata_.allStreamMetadata [activeStreamIndex_].codecName =
489- std::string (avcodec_get_name (codedId ));
480+ std::string (avcodec_get_name (codecContext-> codec_id ));
490481
491482 // We will only need packets from the active stream, so we tell FFmpeg to
492483 // discard packets from the other streams. Note that av_read_frame() may still
@@ -497,6 +488,38 @@ void VideoDecoder::addVideoStream(
497488 formatContext_->streams [i]->discard = AVDISCARD_ALL;
498489 }
499490 }
491+ }
492+
493+ void VideoDecoder::addVideoStream (
494+ int streamIndex,
495+ const VideoStreamOptions& videoStreamOptions) {
496+ TORCH_CHECK (
497+ videoStreamOptions.device .type () == torch::kCPU ||
498+ videoStreamOptions.device .type () == torch::kCUDA ,
499+ " Invalid device type: " + videoStreamOptions.device .str ());
500+
501+ addStream (
502+ streamIndex,
503+ AVMEDIA_TYPE_VIDEO,
504+ videoStreamOptions.device ,
505+ videoStreamOptions.ffmpegThreadCount );
506+
507+ auto & streamMetadata =
508+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
509+
510+ if (seekMode_ == SeekMode::approximate &&
511+ !streamMetadata.averageFps .has_value ()) {
512+ throw std::runtime_error (
513+ " Seek mode is approximate, but stream " +
514+ std::to_string (activeStreamIndex_) +
515+ " does not have an average fps in its metadata." );
516+ }
517+
518+ auto & streamInfo = streamInfos_[activeStreamIndex_];
519+ streamInfo.videoStreamOptions = videoStreamOptions;
520+
521+ streamMetadata.width = streamInfo.codecContext ->width ;
522+ streamMetadata.height = streamInfo.codecContext ->height ;
500523
501524 // By default, we want to use swscale for color conversion because it is
502525 // faster. However, it has width requirements, so we may need to fall back
@@ -505,7 +528,7 @@ void VideoDecoder::addVideoStream(
505528 // swscale's width requirements to be violated. We don't expose the ability to
506529 // choose color conversion library publicly; we only use this ability
507530 // internally.
508- int width = videoStreamOptions.width .value_or (codecContext->width );
531+ int width = videoStreamOptions.width .value_or (streamInfo. codecContext ->width );
509532
510533 // swscale requires widths to be multiples of 32:
511534 // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
@@ -518,6 +541,21 @@ void VideoDecoder::addVideoStream(
518541 videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
519542}
520543
// Registers the audio stream `streamIndex` as the decoder's active stream and
// records its audio-specific metadata (sample rate and channel count).
//
// Fails a TORCH_CHECK unless the decoder's seek mode is SeekMode::approximate.
// Stream lookup and codec-context setup are delegated to addStream().
// NOTE(review): addStream() is called here without the device and
// ffmpegThreadCount arguments; presumably its declaration supplies defaults.
// Confirm in the header.
544+ void VideoDecoder::addAudioStream (int streamIndex) {
545+ TORCH_CHECK (
546+ seekMode_ == SeekMode::approximate,
547+ " seek_mode must be 'approximate' for audio streams." );
548+
// addStream() sets activeStreamIndex_, which the lookups below rely on.
549+ addStream (streamIndex, AVMEDIA_TYPE_AUDIO);
550+
551+ auto & streamInfo = streamInfos_[activeStreamIndex_];
552+ auto & streamMetadata =
553+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
// Copy the audio properties from the codec context into the active stream's
// metadata entry.
554+ streamMetadata.sampleRate =
555+ static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
556+ streamMetadata.numChannels = getNumChannels (streamInfo.codecContext );
557+ }
558+
521559// --------------------------------------------------------------------------
522560// HIGH-LEVEL DECODING ENTRY-POINTS
523561// --------------------------------------------------------------------------
@@ -546,7 +584,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
546584VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal (
547585 int64_t frameIndex,
548586 std::optional<torch::Tensor> preAllocatedOutputTensor) {
549- validateActiveStream ();
587+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
550588
551589 const auto & streamInfo = streamInfos_[activeStreamIndex_];
552590 const auto & streamMetadata =
@@ -560,7 +598,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
560598
561599VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices (
562600 const std::vector<int64_t >& frameIndices) {
563- validateActiveStream ();
601+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
564602
565603 auto indicesAreSorted =
566604 std::is_sorted (frameIndices.begin (), frameIndices.end ());
@@ -619,7 +657,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices(
619657
620658VideoDecoder::FrameBatchOutput
621659VideoDecoder::getFramesInRange (int64_t start, int64_t stop, int64_t step) {
622- validateActiveStream ();
660+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
623661
624662 const auto & streamMetadata =
625663 containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -690,7 +728,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
690728
691729VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt (
692730 const std::vector<double >& timestamps) {
693- validateActiveStream ();
731+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
694732
695733 const auto & streamMetadata =
696734 containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -721,7 +759,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
721759VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange (
722760 double startSeconds,
723761 double stopSeconds) {
724- validateActiveStream ();
762+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
725763
726764 const auto & streamMetadata =
727765 containerMetadata_.allStreamMetadata [activeStreamIndex_];
@@ -860,7 +898,7 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
860898// AVFormatContext if it is needed. We can skip seeking in certain cases. See
861899// the comment of canWeAvoidSeeking() for details.
862900void VideoDecoder::maybeSeekToBeforeDesiredPts () {
863- validateActiveStream ();
901+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
864902 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
865903
866904 int64_t desiredPts =
@@ -907,7 +945,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
907945
908946VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
909947 std::function<bool (AVFrame*)> filterFunction) {
910- validateActiveStream ();
948+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
911949
912950 resetDecodeStats ();
913951
@@ -1587,7 +1625,8 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
15871625// VALIDATION UTILS
15881626// --------------------------------------------------------------------------
15891627
1590- void VideoDecoder::validateActiveStream () {
1628+ void VideoDecoder::validateActiveStream (
1629+ std::optional<AVMediaType> avMediaType) {
15911630 auto errorMsg =
15921631 " Provided stream index=" + std::to_string (activeStreamIndex_) +
15931632 " was not previously added." ;
@@ -1601,6 +1640,14 @@ void VideoDecoder::validateActiveStream() {
16011640 " Invalid stream index=" + std::to_string (activeStreamIndex_) +
16021641 " ; valid indices are in the range [0, " +
16031642 std::to_string (allStreamMetadataSize) + " )." );
1643+
1644+ if (avMediaType.has_value ()) {
1645+ TORCH_CHECK (
1646+ streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value (),
1647+ " The method you called isn't supported. " ,
1648+ " If you're seeing this error, you are probably trying to call an " ,
1649+ " unsupported method on an audio stream." );
1650+ }
16041651}
16051652
16061653void VideoDecoder::validateScannedAllStreams (const std::string& msg) {
@@ -1648,7 +1695,7 @@ void VideoDecoder::resetDecodeStats() {
16481695}
16491696
16501697double VideoDecoder::getPtsSecondsForFrame (int64_t frameIndex) {
1651- validateActiveStream ();
1698+ validateActiveStream (AVMEDIA_TYPE_VIDEO );
16521699 validateScannedAllStreams (" getPtsSecondsForFrame" );
16531700
16541701 const auto & streamInfo = streamInfos_[activeStreamIndex_];
0 commit comments