@@ -109,7 +109,7 @@ AudioEncoder::AudioEncoder(
109109 int sampleRate,
110110 std::string_view fileName,
111111 const AudioStreamOptions& audioStreamOptions)
112- : samples_(validateSamples(samples)) {
112+ : samples_(validateSamples(samples)), inSampleRate_(sampleRate) {
113113 setFFmpegLogLevel ();
114114 AVFormatContext* avFormatContext = nullptr ;
115115 int status = avformat_alloc_output_context2 (
@@ -132,7 +132,7 @@ AudioEncoder::AudioEncoder(
132132 " , make sure it's a valid path? " ,
133133 getFFMPEGErrorStringFromErrorCode (status));
134134
135- initializeEncoder (sampleRate, audioStreamOptions);
135+ initializeEncoder (audioStreamOptions);
136136}
137137
138138AudioEncoder::AudioEncoder (
@@ -142,6 +142,7 @@ AudioEncoder::AudioEncoder(
142142 std::unique_ptr<AVIOToTensorContext> avioContextHolder,
143143 const AudioStreamOptions& audioStreamOptions)
144144 : samples_(validateSamples(samples)),
145+ inSampleRate_ (sampleRate),
145146 avioContextHolder_(std::move(avioContextHolder)) {
146147 setFFmpegLogLevel ();
147148 AVFormatContext* avFormatContext = nullptr ;
@@ -159,11 +160,10 @@ AudioEncoder::AudioEncoder(
159160
160161 avFormatContext_->pb = avioContextHolder_->getAVIOContext ();
161162
162- initializeEncoder (sampleRate, audioStreamOptions);
163+ initializeEncoder (audioStreamOptions);
163164}
164165
165166void AudioEncoder::initializeEncoder (
166- int sampleRate,
167167 const AudioStreamOptions& audioStreamOptions) {
168168 // We use the AVFormatContext's default codec for that
169169 // specific format/container.
@@ -191,8 +191,9 @@ void AudioEncoder::initializeEncoder(
191191 // not related to the input samples.
192192 setDefaultChannelLayout (avCodecContext_, outNumChannels_);
193193
194- validateSampleRate (*avCodec, sampleRate);
195- avCodecContext_->sample_rate = sampleRate;
194+ outSampleRate_ = audioStreamOptions.sampleRate .value_or (inSampleRate_);
195+ validateSampleRate (*avCodec, outSampleRate_);
196+ avCodecContext_->sample_rate = outSampleRate_;
196197
197198 // Input samples are expected to be FLTP. Not all encoders support FLTP, so we
198199 // may need to convert the samples into a supported output sample format,
@@ -217,6 +218,21 @@ void AudioEncoder::initializeEncoder(
217218 " avcodec_parameters_from_context failed: " ,
218219 getFFMPEGErrorStringFromErrorCode (status));
219220 streamIndex_ = avStream->index ;
221+
222+ // If sample rate conversion is needed and the encoder doesn't support
223+ // variable frame size, we need to create an intermediate FIFO. See
224+ // [Encoding loop, sample rate conversion and FIFO].
225+ if (((avCodec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) == 0 ) &&
226+ (inSampleRate_ != outSampleRate_)) {
227+ // frame_size * 2 is a decent default size. FFmpeg automatically
228+ // re-allocates the fifo if more space is needed.
229+ auto avAudioFifo = av_audio_fifo_alloc (
230+ avCodecContext_->sample_fmt ,
231+ outNumChannels_,
232+ avCodecContext_->frame_size * 2 );
233+ TORCH_CHECK (avAudioFifo != nullptr , " Couldn't create AVAudioFifo." );
234+ avAudioFifo_.reset (avAudioFifo);
235+ }
220236}
221237
222238torch::Tensor AudioEncoder::encodeToTensor () {
@@ -234,24 +250,15 @@ void AudioEncoder::encode() {
234250 TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
235251 encodeWasCalled_ = true ;
236252
237- UniqueAVFrame avFrame (av_frame_alloc ());
238- TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
239253 // Default to 256 like in torchaudio
240254 int numSamplesAllocatedPerFrame =
241255 avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256 ;
242- avFrame->nb_samples = numSamplesAllocatedPerFrame;
243- avFrame->format = AV_SAMPLE_FMT_FLTP;
244- avFrame->sample_rate = avCodecContext_->sample_rate ;
256+ UniqueAVFrame avFrame = allocateAVFrame (
257+ numSamplesAllocatedPerFrame,
258+ inSampleRate_,
259+ static_cast <int >(samples_.sizes ()[0 ]),
260+ AV_SAMPLE_FMT_FLTP);
245261 avFrame->pts = 0 ;
246- // We set the channel layout of the frame to the default layout corresponding
247- // to the input samples' number of channels
248- setDefaultChannelLayout (avFrame, static_cast <int >(samples_.sizes ()[0 ]));
249-
250- auto status = av_frame_get_buffer (avFrame.get (), 0 );
251- TORCH_CHECK (
252- status == AVSUCCESS,
253- " Couldn't allocate avFrame's buffers: " ,
254- getFFMPEGErrorStringFromErrorCode (status));
255262
256263 AutoAVPacket autoAVPacket;
257264
@@ -261,19 +268,13 @@ void AudioEncoder::encode() {
261268 int numBytesPerSample = static_cast <int >(samples_.element_size ());
262269 int numBytesPerChannel = numSamples * numBytesPerSample;
263270
264- status = avformat_write_header (avFormatContext_.get (), nullptr );
271+ auto status = avformat_write_header (avFormatContext_.get (), nullptr );
265272 TORCH_CHECK (
266273 status == AVSUCCESS,
267274 " Error in avformat_write_header: " ,
268275 getFFMPEGErrorStringFromErrorCode (status));
269276
270277 while (numEncodedSamples < numSamples) {
271- status = av_frame_make_writable (avFrame.get ());
272- TORCH_CHECK (
273- status == AVSUCCESS,
274- " Couldn't make AVFrame writable: " ,
275- getFFMPEGErrorStringFromErrorCode (status));
276-
277278 int numSamplesToEncode =
278279 std::min (numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
279280 int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
@@ -294,10 +295,9 @@ void AudioEncoder::encode() {
294295 avFrame->nb_samples = numSamplesToEncode;
295296
296297 UniqueAVFrame convertedAVFrame = maybeConvertAVFrame (avFrame);
297- encodeInnerLoop (autoAVPacket, convertedAVFrame);
298+ encodeFrameThroughFifo (autoAVPacket, convertedAVFrame);
298299
299300 numEncodedSamples += numSamplesToEncode;
300- avFrame->pts += static_cast <int64_t >(numSamplesToEncode);
301301 }
302302 TORCH_CHECK (numEncodedSamples == numSamples, " Hmmmmmm something went wrong." );
303303
@@ -313,7 +313,8 @@ void AudioEncoder::encode() {
313313UniqueAVFrame AudioEncoder::maybeConvertAVFrame (const UniqueAVFrame& avFrame) {
314314 if (static_cast <AVSampleFormat>(avFrame->format ) ==
315315 avCodecContext_->sample_fmt &&
316- getNumChannels (avFrame) == outNumChannels_) {
316+ getNumChannels (avFrame) == outNumChannels_ &&
317+ avFrame->sample_rate == outSampleRate_) {
317318 // Note: the clone references the same underlying data, it's a cheap copy.
318319 return UniqueAVFrame (av_frame_clone (avFrame.get ()));
319320 }
@@ -322,31 +323,99 @@ UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) {
322323 swrContext_.reset (createSwrContext (
323324 static_cast <AVSampleFormat>(avFrame->format ),
324325 avCodecContext_->sample_fmt ,
325- avFrame->sample_rate , // No sample rate conversion
326326 avFrame->sample_rate ,
327+ outSampleRate_,
327328 avFrame,
328329 outNumChannels_));
329330 }
330331 UniqueAVFrame convertedAVFrame = convertAudioAVFrameSamples (
331332 swrContext_,
332333 avFrame,
333334 avCodecContext_->sample_fmt ,
334- avFrame-> sample_rate , // No sample rate conversion
335+ outSampleRate_,
335336 outNumChannels_);
337+
338+ if (avFrame->sample_rate == outSampleRate_) {
339+ TORCH_CHECK (
340+ convertedAVFrame->nb_samples == avFrame->nb_samples ,
341+ " convertedAVFrame->nb_samples=" ,
342+ convertedAVFrame->nb_samples ,
343+ " differs from " ,
344+ " avFrame->nb_samples=" ,
345+ avFrame->nb_samples ,
346+ " This is unexpected, please report on the TorchCodec bug tracker." );
347+ }
348+ return convertedAVFrame;
349+ }
350+
351+ void AudioEncoder::encodeFrameThroughFifo (
352+ AutoAVPacket& autoAVPacket,
353+ const UniqueAVFrame& avFrame,
354+ // flushFifo is only set to true in maybeFlushSwrBuffers(), i.e. at the very
355+ // end of the encoding process when we're flushing buffers. We also want to
356+ // flush the FIFO so as to not leave any remaining samples in it.
357+ bool flushFifo) {
358+ if (avAudioFifo_ == nullptr ) {
359+ encodeFrame (autoAVPacket, avFrame);
360+ return ;
361+ }
362+ int numSamplesWritten = av_audio_fifo_write (
363+ avAudioFifo_.get (),
364+ reinterpret_cast <void **>(avFrame->data ),
365+ avFrame->nb_samples );
336366 TORCH_CHECK (
337- convertedAVFrame->nb_samples == avFrame->nb_samples ,
338- " convertedAVFrame->nb_samples=" ,
339- convertedAVFrame->nb_samples ,
340- " differs from " ,
341- " avFrame->nb_samples=" ,
367+ numSamplesWritten == avFrame->nb_samples ,
368+ " Tried to write " ,
342369 avFrame->nb_samples ,
343- " This is unexpected, please report on the TorchCodec bug tracker." );
344- return convertedAVFrame;
370+ " samples, but only wrote " ,
371+ numSamplesWritten);
372+
373+ UniqueAVFrame newavFrame = allocateAVFrame (
374+ avCodecContext_->frame_size ,
375+ outSampleRate_,
376+ outNumChannels_,
377+ avCodecContext_->sample_fmt );
378+
379+ // Explaining the while bound:
380+ // - if we're not flushing the FIFO, i.e. in most cases, we want to pull
381+ // exactly `frame_size` samples from the FIFO, so we have to stop before it
382+ // contains less than `frame_size` samples.
383+ // - if we're flushing the FIFO, we want to read from the FIFO until the very
384+ // last sample it contains.
385+ //
386+ // In both cases, for as long as we can, we're trying to pull exactly
387+ // `frame_size` samples from the FIFO and send each `frame_size`-sized avFrame
388+ // to encodeFrame(). Only the very last avFrame of the encoding process is
389+ // allowed to contain fewer than frame_size samples. That only happens when
390+ // flushFifo is true.
391+ while (av_audio_fifo_size (avAudioFifo_.get ()) >=
392+ (flushFifo ? 1 : avCodecContext_->frame_size )) {
393+ int samplesToRead = std::min (
394+ av_audio_fifo_size (avAudioFifo_.get ()), newavFrame->nb_samples );
395+ int numSamplesRead = av_audio_fifo_read (
396+ avAudioFifo_.get (),
397+ reinterpret_cast <void **>(newavFrame->data ),
398+ samplesToRead);
399+ TORCH_CHECK (
400+ numSamplesRead == samplesToRead,
401+ " Tried to read " ,
402+ samplesToRead,
403+ " samples, but only read " ,
404+ numSamplesRead);
405+
406+ newavFrame->nb_samples = numSamplesRead;
407+ encodeFrame (autoAVPacket, newavFrame);
408+ }
345409}
346410
347- void AudioEncoder::encodeInnerLoop (
411+ void AudioEncoder::encodeFrame (
348412 AutoAVPacket& autoAVPacket,
349413 const UniqueAVFrame& avFrame) {
414+ if (avFrame != nullptr ) {
415+ avFrame->pts = lastEncodedAVFramePts_;
416+ lastEncodedAVFramePts_ += avFrame->nb_samples ;
417+ }
418+
350419 auto status = avcodec_send_frame (avCodecContext_.get (), avFrame.get ());
351420 TORCH_CHECK (
352421 status == AVSUCCESS,
@@ -385,11 +454,41 @@ void AudioEncoder::encodeInnerLoop(
385454 }
386455}
387456
457+ void AudioEncoder::maybeFlushSwrBuffers (AutoAVPacket& autoAVPacket) {
458+ // Similar to the decoder's method with the same name, but for encoding this
459+ // time. That is, when sample conversion is involved, libswresample may have
460+ // buffered some samples that we now need to flush and send to the encoder.
461+ if (swrContext_ == nullptr && inSampleRate_ == outSampleRate_) {
462+ return ;
463+ }
464+ TORCH_CHECK (
465+ swrContext_ != nullptr ,
466+ " swrContext is null, but sample rate conversion is needed. " ,
467+ " This is unexpected, please report on the TorchCodec bug tracker." );
468+
469+ int numRemainingSamples = // this is an upper bound
470+ swr_get_out_samples (swrContext_.get (), 0 );
471+ if (numRemainingSamples == 0 ) {
472+ return ;
473+ }
474+
475+ UniqueAVFrame avFrame = allocateAVFrame (
476+ numRemainingSamples,
477+ outSampleRate_,
478+ outNumChannels_,
479+ avCodecContext_->sample_fmt );
480+ int actualNumRemainingSamples = swr_convert (
481+ swrContext_.get (), avFrame->data , avFrame->nb_samples , NULL , 0 );
482+ avFrame->nb_samples = actualNumRemainingSamples;
483+
484+ // We're potentially sending avFrame through the FIFO (if it exists), in which
485+ // case we also want to flush the FIFO itself.
486+ encodeFrameThroughFifo (autoAVPacket, avFrame, /* flushFifo=*/ true );
487+ }
488+
388489void AudioEncoder::flushBuffers () {
389- // We flush the main FFmpeg buffers, but not swresample buffers. Flushing
390- // swresample is only necessary when converting sample rates, which we don't
391- // do for encoding.
392490 AutoAVPacket autoAVPacket;
393- encodeInnerLoop (autoAVPacket, UniqueAVFrame (nullptr ));
491+ maybeFlushSwrBuffers (autoAVPacket);
492+ encodeFrame (autoAVPacket, UniqueAVFrame (nullptr ));
394493}
395494} // namespace facebook::torchcodec
0 commit comments