@@ -511,4 +511,283 @@ void AudioEncoder::flushBuffers() {
511511
512512 encodeFrame (autoAVPacket, UniqueAVFrame (nullptr ));
513513}
514+
515+ namespace {
516+
517+ torch::Tensor validateFrames (const torch::Tensor& frames) {
518+ TORCH_CHECK (
519+ frames.dtype () == torch::kUInt8 ,
520+ " frames must have uint8 dtype, got " ,
521+ frames.dtype ());
522+ TORCH_CHECK (
523+ frames.dim () == 4 ,
524+ " frames must have 4 dimensions (N, C, H, W), got " ,
525+ frames.dim ());
526+ TORCH_CHECK (
527+ frames.sizes ()[1 ] == 3 ,
528+ " frame must have 3 channels (R, G, B), got " ,
529+ frames.sizes ()[1 ]);
530+ // TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
531+ return frames.contiguous ();
532+ }
533+
534+ } // namespace
535+
536+ VideoEncoder::~VideoEncoder () {
537+ if (avFormatContext_ && avFormatContext_->pb ) {
538+ avio_flush (avFormatContext_->pb );
539+ avio_close (avFormatContext_->pb );
540+ avFormatContext_->pb = nullptr ;
541+ }
542+ }
543+
544+ VideoEncoder::VideoEncoder (
545+ const torch::Tensor& frames,
546+ int frameRate,
547+ std::string_view fileName,
548+ const VideoStreamOptions& videoStreamOptions)
549+ : frames_(validateFrames(frames)), inFrameRate_(frameRate) {
550+ setFFmpegLogLevel ();
551+
552+ // Allocate output format context
553+ AVFormatContext* avFormatContext = nullptr ;
554+ int status = avformat_alloc_output_context2 (
555+ &avFormatContext, nullptr , nullptr , fileName.data ());
556+
557+ TORCH_CHECK (
558+ avFormatContext != nullptr ,
559+ " Couldn't allocate AVFormatContext. " ,
560+ " The destination file is " ,
561+ fileName,
562+ " , check the desired extension? " ,
563+ getFFMPEGErrorStringFromErrorCode (status));
564+ avFormatContext_.reset (avFormatContext);
565+
566+ status = avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
567+ TORCH_CHECK (
568+ status >= 0 ,
569+ " avio_open failed. The destination file is " ,
570+ fileName,
571+ " , make sure it's a valid path? " ,
572+ getFFMPEGErrorStringFromErrorCode (status));
573+ // TODO-VideoEncoder: Add tests for above fileName related checks
574+
575+ initializeEncoder (videoStreamOptions);
576+ }
577+
578+ void VideoEncoder::initializeEncoder (
579+ const VideoStreamOptions& videoStreamOptions) {
580+ const AVCodec* avCodec =
581+ avcodec_find_encoder (avFormatContext_->oformat ->video_codec );
582+ TORCH_CHECK (avCodec != nullptr , " Video codec not found" );
583+
584+ AVCodecContext* avCodecContext = avcodec_alloc_context3 (avCodec);
585+ TORCH_CHECK (avCodecContext != nullptr , " Couldn't allocate codec context." );
586+ avCodecContext_.reset (avCodecContext);
587+
588+ // Set encoding options
589+ // TODO-VideoEncoder: Allow bitrate to be set
590+ std::optional<int > desiredBitRate = videoStreamOptions.bitRate ;
591+ if (desiredBitRate.has_value ()) {
592+ TORCH_CHECK (
593+ *desiredBitRate >= 0 , " bit_rate=" , *desiredBitRate, " must be >= 0." );
594+ }
595+ avCodecContext_->bit_rate = desiredBitRate.value_or (0 );
596+
597+ // Store dimension order and input pixel format
598+ // TODO-VideoEncoder: Remove assumption that tensor in NCHW format
599+ auto sizes = frames_.sizes ();
600+ inPixelFormat_ = AV_PIX_FMT_GBRP;
601+ inHeight_ = sizes[2 ];
602+ inWidth_ = sizes[3 ];
603+
604+ // Use specified dimensions or input dimensions
605+ // TODO-VideoEncoder: Allow height and width to be set
606+ outWidth_ = videoStreamOptions.width .value_or (inWidth_);
607+ outHeight_ = videoStreamOptions.height .value_or (inHeight_);
608+
609+ // Use YUV420P as default output format
610+ // TODO-VideoEncoder: Enable other pixel formats
611+ outPixelFormat_ = AV_PIX_FMT_YUV420P;
612+
613+ // Configure codec parameters
614+ avCodecContext_->codec_id = avCodec->id ;
615+ avCodecContext_->width = outWidth_;
616+ avCodecContext_->height = outHeight_;
617+ avCodecContext_->pix_fmt = outPixelFormat_;
618+ // TODO-VideoEncoder: Verify that frame_rate and time_base are correct
619+ avCodecContext_->time_base = {1 , inFrameRate_};
620+ avCodecContext_->framerate = {inFrameRate_, 1 };
621+
622+ // TODO-VideoEncoder: Allow GOP size and max B-frames to be set
623+ if (videoStreamOptions.gopSize .has_value ()) {
624+ avCodecContext_->gop_size = *videoStreamOptions.gopSize ;
625+ } else {
626+ avCodecContext_->gop_size = 12 ; // Default GOP size
627+ }
628+
629+ if (videoStreamOptions.maxBFrames .has_value ()) {
630+ avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames ;
631+ } else {
632+ avCodecContext_->max_b_frames = 0 ; // No max B-frames to reduce compression
633+ }
634+
635+ int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
636+ TORCH_CHECK (
637+ status == AVSUCCESS,
638+ " avcodec_open2 failed: " ,
639+ getFFMPEGErrorStringFromErrorCode (status));
640+
641+ AVStream* avStream = avformat_new_stream (avFormatContext_.get (), nullptr );
642+ TORCH_CHECK (avStream != nullptr , " Couldn't create new stream." );
643+
644+ // Set the stream time base to encode correct frame timestamps
645+ avStream->time_base = avCodecContext_->time_base ;
646+ status = avcodec_parameters_from_context (
647+ avStream->codecpar , avCodecContext_.get ());
648+ TORCH_CHECK (
649+ status == AVSUCCESS,
650+ " avcodec_parameters_from_context failed: " ,
651+ getFFMPEGErrorStringFromErrorCode (status));
652+ streamIndex_ = avStream->index ;
653+ }
654+
655+ void VideoEncoder::encode () {
656+ // To be on the safe side we enforce that encode() can only be called once
657+ TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
658+ encodeWasCalled_ = true ;
659+
660+ int status = avformat_write_header (avFormatContext_.get (), nullptr );
661+ TORCH_CHECK (
662+ status == AVSUCCESS,
663+ " Error in avformat_write_header: " ,
664+ getFFMPEGErrorStringFromErrorCode (status));
665+
666+ AutoAVPacket autoAVPacket;
667+ int numFrames = frames_.sizes ()[0 ];
668+ for (int i = 0 ; i < numFrames; ++i) {
669+ torch::Tensor currFrame = frames_[i];
670+ UniqueAVFrame avFrame = convertTensorToAVFrame (currFrame, i);
671+ encodeFrame (autoAVPacket, avFrame);
672+ }
673+
674+ flushBuffers ();
675+
676+ status = av_write_trailer (avFormatContext_.get ());
677+ TORCH_CHECK (
678+ status == AVSUCCESS,
679+ " Error in av_write_trailer: " ,
680+ getFFMPEGErrorStringFromErrorCode (status));
681+ }
682+
683+ UniqueAVFrame VideoEncoder::convertTensorToAVFrame (
684+ const torch::Tensor& frame,
685+ int frameIndex) {
686+ // Initialize and cache scaling context if it does not exist
687+ if (!swsContext_) {
688+ swsContext_.reset (sws_getContext (
689+ inWidth_,
690+ inHeight_,
691+ inPixelFormat_,
692+ outWidth_,
693+ outHeight_,
694+ outPixelFormat_,
695+ SWS_BILINEAR,
696+ nullptr ,
697+ nullptr ,
698+ nullptr ));
699+ TORCH_CHECK (swsContext_ != nullptr , " Failed to create scaling context" );
700+ }
701+
702+ UniqueAVFrame avFrame (av_frame_alloc ());
703+ TORCH_CHECK (avFrame != nullptr , " Failed to allocate AVFrame" );
704+
705+ // Set output frame properties
706+ avFrame->format = outPixelFormat_;
707+ avFrame->width = outWidth_;
708+ avFrame->height = outHeight_;
709+ avFrame->pts = frameIndex;
710+
711+ int status = av_frame_get_buffer (avFrame.get (), 0 );
712+ TORCH_CHECK (status >= 0 , " Failed to allocate frame buffer" );
713+
714+ // Need to convert/scale the frame
715+ // Create temporary frame with input format
716+ UniqueAVFrame inputFrame (av_frame_alloc ());
717+ TORCH_CHECK (inputFrame != nullptr , " Failed to allocate input AVFrame" );
718+
719+ inputFrame->format = inPixelFormat_;
720+ inputFrame->width = inWidth_;
721+ inputFrame->height = inHeight_;
722+
723+ uint8_t * tensorData = static_cast <uint8_t *>(frame.data_ptr ());
724+
725+ // TODO-VideoEncoder: Reorder tensor if in NHWC format
726+ int channelSize = inHeight_ * inWidth_;
727+ // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
728+ // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
729+ inputFrame->data [0 ] = tensorData + channelSize;
730+ inputFrame->data [1 ] = tensorData + (2 * channelSize);
731+ inputFrame->data [2 ] = tensorData;
732+
733+ inputFrame->linesize [0 ] = inWidth_;
734+ inputFrame->linesize [1 ] = inWidth_;
735+ inputFrame->linesize [2 ] = inWidth_;
736+
737+ status = sws_scale (
738+ swsContext_.get (),
739+ inputFrame->data ,
740+ inputFrame->linesize ,
741+ 0 ,
742+ inputFrame->height ,
743+ avFrame->data ,
744+ avFrame->linesize );
745+ TORCH_CHECK (status == outHeight_, " sws_scale failed" );
746+ return avFrame;
747+ }
748+
749+ void VideoEncoder::encodeFrame (
750+ AutoAVPacket& autoAVPacket,
751+ const UniqueAVFrame& avFrame) {
752+ auto status = avcodec_send_frame (avCodecContext_.get (), avFrame.get ());
753+ TORCH_CHECK (
754+ status == AVSUCCESS,
755+ " Error while sending frame: " ,
756+ getFFMPEGErrorStringFromErrorCode (status));
757+
758+ while (true ) {
759+ ReferenceAVPacket packet (autoAVPacket);
760+ status = avcodec_receive_packet (avCodecContext_.get (), packet.get ());
761+ if (status == AVERROR (EAGAIN) || status == AVERROR_EOF) {
762+ if (status == AVERROR_EOF) {
763+ // Flush remaining buffered packets
764+ status = av_interleaved_write_frame (avFormatContext_.get (), nullptr );
765+ TORCH_CHECK (
766+ status == AVSUCCESS,
767+ " Failed to flush packet: " ,
768+ getFFMPEGErrorStringFromErrorCode (status));
769+ }
770+ return ;
771+ }
772+ TORCH_CHECK (
773+ status >= 0 ,
774+ " Error receiving packet: " ,
775+ getFFMPEGErrorStringFromErrorCode (status));
776+
777+ packet->stream_index = streamIndex_;
778+
779+ status = av_interleaved_write_frame (avFormatContext_.get (), packet.get ());
780+ TORCH_CHECK (
781+ status == AVSUCCESS,
782+ " Error in av_interleaved_write_frame: " ,
783+ getFFMPEGErrorStringFromErrorCode (status));
784+ }
785+ }
786+
787+ void VideoEncoder::flushBuffers () {
788+ AutoAVPacket autoAVPacket;
789+ // Send null frame to signal end of input
790+ encodeFrame (autoAVPacket, UniqueAVFrame (nullptr ));
791+ }
792+
514793} // namespace facebook::torchcodec
0 commit comments