@@ -580,14 +580,18 @@ void VideoDecoder::addVideoStream(
580580 videoStreamOptions.colorConversionLibrary .value_or (defaultLibrary);
581581}
582582
583- void VideoDecoder::addAudioStream (int streamIndex) {
583+ void VideoDecoder::addAudioStream (
584+ int streamIndex,
585+ const AudioStreamOptions& audioStreamOptions) {
584586 TORCH_CHECK (
585587 seekMode_ == SeekMode::approximate,
586588 " seek_mode must be 'approximate' for audio streams." );
587589
588590 addStream (streamIndex, AVMEDIA_TYPE_AUDIO);
589591
590592 auto & streamInfo = streamInfos_[activeStreamIndex_];
593+ streamInfo.audioStreamOptions = audioStreamOptions;
594+
591595 auto & streamMetadata =
592596 containerMetadata_.allStreamMetadata [activeStreamIndex_];
593597 streamMetadata.sampleRate =
@@ -947,6 +951,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
947951 (stopPts <= lastDecodedAvFrameEnd);
948952 }
949953
954+ auto lastSamples = maybeFlushSwrBuffers ();
955+ if (lastSamples.has_value ()) {
956+ frames.push_back (*lastSamples);
957+ }
958+
950959 return AudioFramesOutput{torch::cat (frames, 1 ), firstFramePtsSeconds};
951960}
952961
@@ -1200,8 +1209,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12001209 getDuration (avFrame),
12011210 formatContext_->streams [activeStreamIndex_]->time_base );
12021211 if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1203- convertAudioAVFrameToFrameOutputOnCPU (
1204- avFrame, frameOutput, preAllocatedOutputTensor);
1212+ convertAudioAVFrameToFrameOutputOnCPU (avFrame, frameOutput);
12051213 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
12061214 convertAVFrameToFrameOutputOnCPU (
12071215 avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1379,24 +1387,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13791387
13801388void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
13811389 UniqueAVFrame& srcAVFrame,
1382- FrameOutput& frameOutput,
1383- std::optional<torch::Tensor> preAllocatedOutputTensor) {
1384- TORCH_CHECK (
1385- !preAllocatedOutputTensor.has_value (),
1386- " pre-allocated audio tensor not supported yet." );
1387-
1390+ FrameOutput& frameOutput) {
13881391 AVSampleFormat sourceSampleFormat =
13891392 static_cast <AVSampleFormat>(srcAVFrame->format );
13901393 AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13911394
1395+ int sourceSampleRate = srcAVFrame->sample_rate ;
1396+ int desiredSampleRate =
1397+ streamInfos_[activeStreamIndex_].audioStreamOptions .sampleRate .value_or (
1398+ sourceSampleRate);
1399+
1400+ bool mustConvert =
1401+ (sourceSampleFormat != desiredSampleFormat ||
1402+ sourceSampleRate != desiredSampleRate);
1403+
13921404 UniqueAVFrame convertedAVFrame;
1393- if (sourceSampleFormat != desiredSampleFormat) {
1394- convertedAVFrame = convertAudioAVFrameSampleFormat (
1395- srcAVFrame, sourceSampleFormat, desiredSampleFormat);
1405+ if (mustConvert) {
1406+ convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate (
1407+ srcAVFrame,
1408+ sourceSampleFormat,
1409+ desiredSampleFormat,
1410+ sourceSampleRate,
1411+ desiredSampleRate);
13961412 }
1397- const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1398- ? convertedAVFrame
1399- : srcAVFrame;
1413+ const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
14001414
14011415 AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
14021416 TORCH_CHECK (
@@ -1419,55 +1433,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
14191433 memcpy (
14201434 outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
14211435 }
1436+
14221437 frameOutput.data = outputData;
14231438}
14241439
// Reading guide: this span is a diff rendering. Lines marked '-' belong to the
// previous revision (convertAudioAVFrameSampleFormat), lines marked '+' to the
// new revision (convertAudioAVFrameSampleFormatAndSampleRate); unmarked lines
// are shared by both.
//
// Purpose (new revision): converts srcAVFrame to desiredSampleFormat and
// desiredSampleRate through the active stream's SwrContext and returns a newly
// allocated frame holding the converted samples.
//
// Parameters (new revision):
//   srcAVFrame          - decoded frame to convert; it is only read.
//   sourceSampleFormat  - forwarded to createSwrContext() on first use.
//   desiredSampleFormat - format written into the returned frame.
//   sourceSampleRate    - forwarded to createSwrContext(); also used to size
//                         the output buffer.
//   desiredSampleRate   - sample rate written into the returned frame.
//
// Failure: every step is guarded by TORCH_CHECK (frame allocation, buffer
// allocation, swr_convert result).
1425- UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat (
1426- const UniqueAVFrame& avFrame ,
1440+ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate (
1441+ const UniqueAVFrame& srcAVFrame ,
14271442 AVSampleFormat sourceSampleFormat,
1428- AVSampleFormat desiredSampleFormat
1429-
1430- ) {
1443+ AVSampleFormat desiredSampleFormat,
1444+ int sourceSampleRate,
1445+ int desiredSampleRate ) {
14311446 auto & streamInfo = streamInfos_[activeStreamIndex_];
1432- const auto & streamMetadata =
1433- containerMetadata_.allStreamMetadata [activeStreamIndex_];
1434- int sampleRate = static_cast <int >(streamMetadata.sampleRate .value ());
14351447

// The SwrContext is created lazily, only when the stream does not have one
// yet, using the formats and rates of the frame being converted right now.
// NOTE(review): later frames reuse that context as-is; presumably all frames
// of a stream share the same source format and rate - confirm.
14361448 if (!streamInfo.swrContext ) {
14371449 createSwrContext (
1438- streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1450+ streamInfo,
1451+ sourceSampleFormat,
1452+ desiredSampleFormat,
1453+ sourceSampleRate,
1454+ desiredSampleRate);
14391455 }
14401456
14411457 UniqueAVFrame convertedAVFrame (av_frame_alloc ());
14421458 TORCH_CHECK (
14431459 convertedAVFrame,
14441460 " Could not allocate frame for sample format conversion." );
14451461
// Describe the output frame (layout, format, rate, capacity) before asking
// FFmpeg to allocate its buffers below.
1446- setChannelLayout (convertedAVFrame, avFrame );
1462+ setChannelLayout (convertedAVFrame, srcAVFrame );
14471463 convertedAVFrame->format = static_cast <int >(desiredSampleFormat);
1448- convertedAVFrame->sample_rate = avFrame->sample_rate ;
1449- convertedAVFrame->nb_samples = avFrame->nb_samples ;
1464+ convertedAVFrame->sample_rate = desiredSampleRate;
1465+ if (sourceSampleRate != desiredSampleRate) {
1466+ // Note that this is an upper bound on the number of output samples.
1467+ // `swr_convert()` will likely not fill convertedAVFrame with that many
1468+ // samples if sample rate conversion is needed. It will buffer the last few
1469+ // ones because those require future samples. That's also why we reset
1470+ // nb_samples after the call to `swr_convert()`.
1471+ // We could also use `swr_get_out_samples()` to determine the number of
1472+ // output samples, but empirically `av_rescale_rnd()` seems to provide a
1473+ // tighter bound.
// NOTE(review): av_rescale_rnd() is documented as returning a 64-bit value
// while AVFrame::nb_samples is an int, so this assignment narrows implicitly -
// confirm the value always fits.
1474+ convertedAVFrame->nb_samples = av_rescale_rnd (
1475+ swr_get_delay (streamInfo.swrContext .get (), sourceSampleRate) +
1476+ srcAVFrame->nb_samples ,
1477+ desiredSampleRate,
1478+ sourceSampleRate,
1479+ AV_ROUND_UP);
1480+ } else {
1481+ convertedAVFrame->nb_samples = srcAVFrame->nb_samples ;
1482+ }
14501483
14511484 auto status = av_frame_get_buffer (convertedAVFrame.get (), 0 );
14521485 TORCH_CHECK (
14531486 status == AVSUCCESS,
14541487 " Could not allocate frame buffers for sample format conversion: " ,
14551488 getFFMPEGErrorStringFromErrorCode (status));
14561489
// Run the conversion: the output capacity is the nb_samples chosen above, the
// input is every sample of the source frame.
1457- auto numSampleConverted = swr_convert (
1490+ auto numConvertedSamples = swr_convert (
14581491 streamInfo.swrContext .get (),
14591492 convertedAVFrame->data ,
14601493 convertedAVFrame->nb_samples ,
1461- static_cast <const uint8_t **>(const_cast <const uint8_t **>(avFrame->data )),
1462- avFrame->nb_samples );
1494+ static_cast <const uint8_t **>(
1495+ const_cast <const uint8_t **>(srcAVFrame->data )),
1496+ srcAVFrame->nb_samples );
// NOTE(review): a result of exactly 0 is rejected here together with negative
// error codes. With resampling, swr_convert() may buffer input and emit
// nothing for a very short frame - confirm whether 0 should be an error.
14631497 TORCH_CHECK (
1464- numSampleConverted > 0 ,
1498+ numConvertedSamples > 0 ,
14651499 " Error in swr_convert: " ,
1466- getFFMPEGErrorStringFromErrorCode (numSampleConverted));
1500+ getFFMPEGErrorStringFromErrorCode (numConvertedSamples));
1501+
1502+ // See comment above about nb_samples
1503+ convertedAVFrame->nb_samples = numConvertedSamples;
14671504
14681505 return convertedAVFrame;
14691506}
14701507
1508+ std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers () {
1509+ // When sample rate conversion is involved, swresample buffers some of the
1510+ // samples in-between calls to swr_convert (see the libswresample docs).
1511+ // That's because the last few samples in a given frame require future samples
1512+ // from the next frame to be properly converted. This function flushes out the
1513+ // samples that are stored in swresample's buffers.
1514+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1515+ if (!streamInfo.swrContext ) {
1516+ return std::nullopt ;
1517+ }
1518+ auto numRemainingSamples = // this is an upper bound
1519+ swr_get_out_samples (streamInfo.swrContext .get (), 0 );
1520+
1521+ if (numRemainingSamples == 0 ) {
1522+ return std::nullopt ;
1523+ }
1524+
1525+ torch::Tensor lastSamples = torch::empty (
1526+ {getNumChannels (streamInfo.codecContext ), numRemainingSamples},
1527+ torch::kFloat32 );
1528+ uint8_t * lastSamplesData = static_cast <uint8_t *>(lastSamples.data_ptr ());
1529+
1530+ auto actualNumRemainingSamples = swr_convert (
1531+ streamInfo.swrContext .get (),
1532+ &lastSamplesData,
1533+ numRemainingSamples,
1534+ nullptr ,
1535+ 0 );
1536+ return lastSamples.narrow (
1537+ /* dim=*/ 1 , /* start=*/ 0 , /* length=*/ actualNumRemainingSamples);
1538+ }
1539+
14711540// --------------------------------------------------------------------------
14721541// OUTPUT ALLOCATION AND SHAPE CONVERSION
14731542// --------------------------------------------------------------------------
@@ -1703,14 +1772,16 @@ void VideoDecoder::createSwsContext(
17031772
17041773void VideoDecoder::createSwrContext (
17051774 StreamInfo& streamInfo,
1706- int sampleRate,
17071775 AVSampleFormat sourceSampleFormat,
1708- AVSampleFormat desiredSampleFormat) {
1776+ AVSampleFormat desiredSampleFormat,
1777+ int sourceSampleRate,
1778+ int desiredSampleRate) {
17091779 auto swrContext = allocateSwrContext (
17101780 streamInfo.codecContext ,
1711- sampleRate,
17121781 sourceSampleFormat,
1713- desiredSampleFormat);
1782+ desiredSampleFormat,
1783+ sourceSampleRate,
1784+ desiredSampleRate);
17141785
17151786 auto status = swr_init (swrContext);
17161787 TORCH_CHECK (
0 commit comments