Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/torchcodec/_core/CpuDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ class CpuDeviceInterface : public DeviceInterface {

virtual ~CpuDeviceInterface() {}

std::optional<const AVCodec*> findCodec(
[[maybe_unused]] const AVCodecID& codecId) override {
return std::nullopt;
}

virtual void initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx,
Expand Down
39 changes: 32 additions & 7 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,40 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
}

namespace {
// Helper function to check if a codec supports CUDA hardware acceleration
bool codecSupportsCudaHardware(const AVCodec* codec) {
const AVCodecHWConfig* config = nullptr;
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) {
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
return true;
}
}
return false;
}
} // namespace

// Inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9.
// We have to do this because of an FFmpeg bug where hardware decoding is not
// set up appropriately, so we search for the matching codec for the CUDA
// device ourselves.
std::optional<const AVCodec*> CudaDeviceInterface::findCodec(

std::optional<const AVCodec*> CudaDeviceInterface::findEncoder(
const AVCodecID& codecId) {
void* i = nullptr;
const AVCodec* codec = nullptr;
while ((codec = av_codec_iterate(&i)) != nullptr) {
if (codec->id != codecId || !av_codec_is_encoder(codec)) {
continue;
}
if (codecSupportsCudaHardware(codec)) {
return codec;
}
}
return std::nullopt;
}

std::optional<const AVCodec*> CudaDeviceInterface::findDecoder(
const AVCodecID& codecId) {
void* i = nullptr;
const AVCodec* codec = nullptr;
Expand All @@ -342,12 +371,8 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
continue;
}

const AVCodecHWConfig* config = nullptr;
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
++j) {
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
return codec;
}
if (codecSupportsCudaHardware(codec)) {
return codec;
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/_core/CudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class CudaDeviceInterface : public DeviceInterface {

virtual ~CudaDeviceInterface();

std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId) override;
std::optional<const AVCodec*> findDecoder(const AVCodecID& codecId) override;

void initialize(
const AVStream* avStream,
Expand Down
7 changes: 6 additions & 1 deletion src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ class DeviceInterface {
return device_;
};

virtual std::optional<const AVCodec*> findCodec(
// Base-class hook for looking up an encoder for `codecId`. This default
// implementation ignores its argument and always returns an empty optional;
// subclasses may override it.
virtual std::optional<const AVCodec*> findEncoder(
    [[maybe_unused]] const AVCodecID& codecId) {
  return {};
}

// Base-class hook for looking up a decoder for `codecId`. This default
// implementation ignores its argument and always returns an empty optional;
// subclasses may override it.
virtual std::optional<const AVCodec*> findDecoder(
    [[maybe_unused]] const AVCodecID& codecId) {
  return {};
}
Expand Down
23 changes: 22 additions & 1 deletion src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,23 @@ VideoEncoder::VideoEncoder(

void VideoEncoder::initializeEncoder(
const VideoStreamOptions& videoStreamOptions) {
deviceInterface_ = createDeviceInterface(
videoStreamOptions.device, videoStreamOptions.deviceVariant);
TORCH_CHECK(
deviceInterface_ != nullptr,
"Failed to create device interface. This should never happen, please report.");

const AVCodec* avCodec =
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
TORCH_CHECK(avCodec != nullptr, "Video codec not found");

// Try to find a hardware-accelerated encoder if not using CPU
if (videoStreamOptions.device.type() != torch::kCPU) {
avCodec =
deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec)
.value_or(avCodec);
}

AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
avCodecContext_.reset(avCodecContext);
Expand Down Expand Up @@ -662,12 +675,20 @@ void VideoEncoder::initializeEncoder(
// Apply videoStreamOptions
AVDictionary* options = nullptr;
if (videoStreamOptions.crf.has_value()) {
// nvenc encoders use qp, others use crf (for C++ tests)
std::string_view quality_param =
(strstr(avCodec->name, "nvenc") == nullptr) ? "crf" : "qp";
av_dict_set(
&options,
"crf",
quality_param.data(),
std::to_string(videoStreamOptions.crf.value()).c_str(),
0);
}

// Register the hardware device context with the codec
// context before calling avcodec_open2().
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());

int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
av_dict_free(&options);

Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include <torch/types.h>
#include "AVIOContextHolder.h"
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "StreamOptions.h"

Expand Down Expand Up @@ -177,6 +178,7 @@ class VideoEncoder {
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;

std::unique_ptr<AVIOContextHolder> avioContextHolder_;
std::unique_ptr<DeviceInterface> deviceInterface_;

bool encodeWasCalled_ = false;
};
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() {

AVCodecOnlyUseForCallingAVFindBestStream
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3
return const_cast<AVCodec*>(codec);
#else
return codec;
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ void SingleStreamDecoder::addStream(
// addStream() which is supposed to be generic
if (mediaType == AVMEDIA_TYPE_VIDEO) {
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id)
.value_or(avCodec));
}

Expand Down
14 changes: 11 additions & 3 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", int? crf=None) -> ()");
m.def(
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", int? crf=None) -> Tensor");
m.def(
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",int? crf=None) -> ()");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -603,9 +603,12 @@ void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view file_name,
std::string_view device = "cpu",
std::optional<int64_t> crf = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;

videoStreamOptions.device = torch::Device(std::string(device));
VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
Expand All @@ -618,10 +621,13 @@ at::Tensor encode_video_to_tensor(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view format,
std::string_view device = "cpu",
std::optional<int64_t> crf = std::nullopt) {
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;

videoStreamOptions.device = torch::Device(std::string(device));
return VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
Expand All @@ -636,6 +642,7 @@ void _encode_video_to_file_like(
int64_t frame_rate,
std::string_view format,
int64_t file_like_context,
std::string_view device = "cpu",
std::optional<int64_t> crf = std::nullopt) {
auto fileLikeContext =
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
Expand All @@ -645,6 +652,7 @@ void _encode_video_to_file_like(

VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;
videoStreamOptions.device = torch::Device(std::string(device));

VideoEncoder encoder(
frames,
Expand Down
10 changes: 8 additions & 2 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def encode_video_to_file_like(
frame_rate: int,
format: str,
file_like: Union[io.RawIOBase, io.BufferedIOBase],
device: Optional[str] = "cpu",
crf: Optional[int] = None,
) -> None:
"""Encode video frames to a file-like object.
Expand All @@ -221,6 +222,7 @@ def encode_video_to_file_like(
frame_rate: Frame rate in frames per second
format: Video format (e.g., "mp4", "mov", "mkv")
file_like: File-like object that supports write() and seek() methods
device: Device to use for encoding (default: "cpu")
crf: Optional constant rate factor for encoding quality
"""
assert _pybind_ops is not None
Expand All @@ -230,6 +232,7 @@ def encode_video_to_file_like(
frame_rate,
format,
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
device,
crf,
)

Expand Down Expand Up @@ -318,7 +321,8 @@ def encode_video_to_file_abstract(
frames: torch.Tensor,
frame_rate: int,
filename: str,
crf: Optional[int],
device: str = "cpu",
crf: Optional[int] = None,
) -> None:
return

Expand All @@ -328,7 +332,8 @@ def encode_video_to_tensor_abstract(
frames: torch.Tensor,
frame_rate: int,
format: str,
crf: Optional[int],
device: str = "cpu",
crf: Optional[int] = None,
) -> torch.Tensor:
return torch.empty([], dtype=torch.long)

Expand All @@ -339,6 +344,7 @@ def _encode_video_to_file_like_abstract(
frame_rate: int,
format: str,
file_like_context: int,
device: str = "cpu",
crf: Optional[int] = None,
) -> None:
return
Expand Down
23 changes: 20 additions & 3 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pathlib import Path
from typing import Union
from typing import Optional, Union

import torch
from torch import Tensor
from torch import device as torch_device, Tensor

from torchcodec import _core

Expand All @@ -16,9 +16,18 @@ class VideoEncoder:
C is 3 channels (RGB), H is height, and W is width.
Values must be uint8 in the range ``[0, 255]``.
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
device (str or torch.device, optional): The device to use for encoding. Default: "cpu".
If you pass a CUDA device, frames will be encoded on GPU.
Note: The "beta" CUDA backend is not supported for encoding.
"""

def __init__(self, frames: Tensor, *, frame_rate: int):
def __init__(
self,
frames: Tensor,
*,
frame_rate: int,
device: Optional[Union[str, torch_device]] = "cpu",
):
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
if not isinstance(frames, Tensor):
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")
Expand All @@ -29,8 +38,13 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
if frame_rate <= 0:
raise ValueError(f"{frame_rate = } must be > 0.")

# Validate and store device
if isinstance(device, torch_device):
device = str(device)

self._frames = frames
self._frame_rate = frame_rate
self._device = device

def to_file(
self,
Expand All @@ -47,6 +61,7 @@ def to_file(
frames=self._frames,
frame_rate=self._frame_rate,
filename=str(dest),
device=self._device,
)

def to_tensor(
Expand All @@ -66,6 +81,7 @@ def to_tensor(
frames=self._frames,
frame_rate=self._frame_rate,
format=format,
device=self._device,
)

def to_file_like(
Expand All @@ -89,4 +105,5 @@ def to_file_like(
frame_rate=self._frame_rate,
format=format,
file_like=file_like,
device=self._device,
)
Loading
Loading