88#include < cstdint>
99#include < cstdio>
1010#include < iostream>
11+ #include < limits>
1112#include < sstream>
1213#include < stdexcept>
1314#include < string_view>
@@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
552553 containerMetadata_.allStreamMetadata [activeStreamIndex_];
553554 streamMetadata.sampleRate =
554555 static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
555- streamMetadata.numChannels = getNumChannels (streamInfo.codecContext );
556+ streamMetadata.numChannels =
557+ static_cast <int64_t >(getNumChannels (streamInfo.codecContext ));
556558}
557559
558560// --------------------------------------------------------------------------
@@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
567569
568570VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal (
569571 std::optional<torch::Tensor> preAllocatedOutputTensor) {
572+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
570573 AVFrameStream avFrameStream = decodeAVFrame (
571574 [this ](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
572575 return convertAVFrameToFrameOutput (avFrameStream, preAllocatedOutputTensor);
@@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
685688}
686689
687690VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt (double seconds) {
691+ validateActiveStream (AVMEDIA_TYPE_VIDEO);
688692 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
689693 double frameStartTime =
690694 ptsToSeconds (streamInfo.lastDecodedAvFramePts , streamInfo.timeBase );
@@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
757761 double startSeconds,
758762 double stopSeconds) {
759763 validateActiveStream (AVMEDIA_TYPE_VIDEO);
760-
761764 const auto & streamMetadata =
762765 containerMetadata_.allStreamMetadata [activeStreamIndex_];
763766 TORCH_CHECK (
@@ -835,6 +838,68 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
835838 return frameBatchOutput;
836839}
837840
841+ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio (
842+ double startSeconds,
843+ std::optional<double > stopSecondsOptional) {
844+ validateActiveStream (AVMEDIA_TYPE_AUDIO);
845+
846+ double stopSeconds =
847+ stopSecondsOptional.value_or (std::numeric_limits<double >::max ());
848+
849+ TORCH_CHECK (
850+ startSeconds <= stopSeconds,
851+ " Start seconds (" + std::to_string (startSeconds) +
852+ " ) must be less than or equal to stop seconds (" +
853+ std::to_string (stopSeconds) + " ." );
854+
855+ if (startSeconds == stopSeconds) {
856+ // For consistency with video
857+ return torch::empty ({0 });
858+ }
859+
860+ StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861+
862+ // TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
863+ // We should remove it and seek back to the stream's beginning when needed.
864+ // See test_multiple_calls
865+ TORCH_CHECK (
866+ streamInfo.lastDecodedAvFramePts +
867+ streamInfo.lastDecodedAvFrameDuration <=
868+ secondsToClosestPts (startSeconds, streamInfo.timeBase ),
869+ " Audio decoder cannot seek backwards, or start from the last decoded frame." );
870+
871+ setCursorPtsInSeconds (startSeconds);
872+
873+ // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
874+ // cat(). This would save a copy. We know the duration of the output and the
875+ // sample rate, so in theory we know the number of output samples.
876+ std::vector<torch::Tensor> tensors;
877+
878+ auto stopPts = secondsToClosestPts (stopSeconds, streamInfo.timeBase );
879+ auto finished = false ;
880+ while (!finished) {
881+ try {
882+ AVFrameStream avFrameStream = decodeAVFrame ([this ](AVFrame* avFrame) {
883+ return cursor_ < avFrame->pts + getDuration (avFrame);
884+ });
885+ auto frameOutput = convertAVFrameToFrameOutput (avFrameStream);
886+ tensors.push_back (frameOutput.data );
887+ } catch (const EndOfFileException& e) {
888+ finished = true ;
889+ }
890+
891+ // If stopSeconds is in [begin, end] of the last decoded frame, we should
892+ // stop decoding more frames. Note that if we were to use [begin, end),
893+ // which may seem more natural, then we would decode the frame starting at
894+ // stopSeconds, which isn't what we want!
895+ auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
896+ streamInfo.lastDecodedAvFrameDuration ;
897+ finished |= (streamInfo.lastDecodedAvFramePts ) <= stopPts &&
898+ (stopPts <= lastDecodedAvFrameEnd);
899+ }
900+ return torch::cat (tensors, 1 );
901+ }
902+
838903// --------------------------------------------------------------------------
839904// SEEKING APIs
840905// --------------------------------------------------------------------------
@@ -871,6 +936,10 @@ I P P P I P P P I P P I P P I P
871936(2) is more efficient than (1) if there is an I frame between x and y.
872937*/
873938bool VideoDecoder::canWeAvoidSeeking () const {
939+ const StreamInfo& streamInfo = streamInfos_.at (activeStreamIndex_);
940+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
941+ return true ;
942+ }
874943 int64_t lastDecodedAvFramePts =
875944 streamInfos_.at (activeStreamIndex_).lastDecodedAvFramePts ;
876945 if (cursor_ < lastDecodedAvFramePts) {
@@ -897,7 +966,7 @@ bool VideoDecoder::canWeAvoidSeeking() const {
897966// AVFormatContext if it is needed. We can skip seeking in certain cases. See
898967// the comment of canWeAvoidSeeking() for details.
899968void VideoDecoder::maybeSeekToBeforeDesiredPts () {
900- validateActiveStream (AVMEDIA_TYPE_VIDEO );
969+ validateActiveStream ();
901970 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
902971
903972 decodeStats_.numSeeksAttempted ++;
@@ -942,7 +1011,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
9421011
9431012VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
9441013 std::function<bool (AVFrame*)> filterFunction) {
945- validateActiveStream (AVMEDIA_TYPE_VIDEO );
1014+ validateActiveStream ();
9461015
9471016 resetDecodeStats ();
9481017
@@ -1071,13 +1140,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10711140 AVFrame* avFrame = avFrameStream.avFrame .get ();
10721141 frameOutput.streamIndex = streamIndex;
10731142 auto & streamInfo = streamInfos_[streamIndex];
1074- TORCH_CHECK (streamInfo.stream ->codecpar ->codec_type == AVMEDIA_TYPE_VIDEO);
10751143 frameOutput.ptsSeconds = ptsToSeconds (
10761144 avFrame->pts , formatContext_->streams [streamIndex]->time_base );
10771145 frameOutput.durationSeconds = ptsToSeconds (
10781146 getDuration (avFrame), formatContext_->streams [streamIndex]->time_base );
1079- // TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080- if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1147+ if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1148+ convertAudioAVFrameToFrameOutputOnCPU (
1149+ avFrameStream, frameOutput, preAllocatedOutputTensor);
1150+ } else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
10811151 convertAVFrameToFrameOutputOnCPU (
10821152 avFrameStream, frameOutput, preAllocatedOutputTensor);
10831153 } else if (streamInfo.videoStreamOptions .device .type () == torch::kCUDA ) {
@@ -1253,6 +1323,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12531323 filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
12541324}
12551325
1326+ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1327+ VideoDecoder::AVFrameStream& avFrameStream,
1328+ FrameOutput& frameOutput,
1329+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
1330+ TORCH_CHECK (
1331+ !preAllocatedOutputTensor.has_value (),
1332+ " pre-allocated audio tensor not supported yet." );
1333+
1334+ const AVFrame* avFrame = avFrameStream.avFrame .get ();
1335+
1336+ auto numSamples = avFrame->nb_samples ; // per channel
1337+ auto numChannels = getNumChannels (avFrame);
1338+ torch::Tensor outputData =
1339+ torch::empty ({numChannels, numSamples}, torch::kFloat32 );
1340+
1341+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1342+ // TODO-AUDIO Implement all formats.
1343+ switch (format) {
1344+ case AV_SAMPLE_FMT_FLTP: {
1345+ uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1346+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1347+ for (auto channel = 0 ; channel < numChannels;
1348+ ++channel, outputChannelData += numBytesPerChannel) {
1349+ memcpy (
1350+ outputChannelData,
1351+ avFrame->extended_data [channel],
1352+ numBytesPerChannel);
1353+ }
1354+ break ;
1355+ }
1356+ default :
1357+ TORCH_CHECK (
1358+ false ,
1359+ " Unsupported audio format (yet!): " ,
1360+ av_get_sample_fmt_name (format));
1361+ }
1362+ frameOutput.data = outputData;
1363+ }
1364+
12561365// --------------------------------------------------------------------------
12571366// OUTPUT ALLOCATION AND SHAPE CONVERSION
12581367// --------------------------------------------------------------------------
0 commit comments