Skip to content

Commit 7b9fda4

Browse files
Dan-FloresDan-Flores
authored andcommitted
use device in Encoder
1 parent ec59356 commit 7b9fda4

File tree

4 files changed

+30
-2
lines changed

4 files changed

+30
-2
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,25 @@ VideoEncoder::VideoEncoder(
615615

616616
void VideoEncoder::initializeEncoder(
617617
const VideoStreamOptions& videoStreamOptions) {
618+
deviceInterface_ = createDeviceInterface(
619+
videoStreamOptions.device, videoStreamOptions.deviceVariant);
620+
TORCH_CHECK(
621+
deviceInterface_ != nullptr,
622+
"Failed to create device interface. This should never happen, please report.");
623+
618624
const AVCodec* avCodec =
619625
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
620626
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
621627

628+
// Try to find a hardware-accelerated encoder if not using CPU
629+
if (videoStreamOptions.device.type() != torch::kCPU) {
630+
auto hardwareCodec =
631+
deviceInterface_->findCodec(avFormatContext_->oformat->video_codec);
632+
if (hardwareCodec.has_value()) {
633+
avCodec = hardwareCodec.value();
634+
}
635+
}
636+
622637
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
623638
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
624639
avCodecContext_.reset(avCodecContext);
@@ -668,6 +683,11 @@ void VideoEncoder::initializeEncoder(
668683
std::to_string(videoStreamOptions.crf.value()).c_str(),
669684
0);
670685
}
686+
687+
// Register the hardware device context with the codec
688+
// context before calling avcodec_open2().
689+
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
690+
671691
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
672692
av_dict_free(&options);
673693

src/torchcodec/_core/Encoder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22
#include <torch/types.h>
33
#include "src/torchcodec/_core/AVIOContextHolder.h"
4+
#include "src/torchcodec/_core/DeviceInterface.h"
45
#include "src/torchcodec/_core/FFMPEGCommon.h"
56
#include "src/torchcodec/_core/StreamOptions.h"
67

@@ -177,6 +178,7 @@ class VideoEncoder {
177178
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
178179

179180
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
181+
std::unique_ptr<DeviceInterface> deviceInterface_;
180182

181183
bool encodeWasCalled_ = false;
182184
};

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() {
4040

4141
AVCodecOnlyUseForCallingAVFindBestStream
4242
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
43-
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)
43+
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3
4444
return const_cast<AVCodec*>(codec);
4545
#else
4646
return codec;

src/torchcodec/encoders/_video_encoder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ class VideoEncoder:
2222
Note: The "beta" CUDA backend is not supported for encoding.
2323
"""
2424

25-
def __init__(self, frames: Tensor, *, frame_rate: int):
25+
def __init__(
26+
self,
27+
frames: Tensor,
28+
*,
29+
frame_rate: int,
30+
device: Optional[Union[str, torch_device]] = "cpu",
31+
):
2632
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
2733
if not isinstance(frames, Tensor):
2834
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")

0 commit comments

Comments
 (0)