Skip to content

Commit f69d6ca

Browse files
committed
refactor getCodec to handle decoders and encoders
1 parent 7b9fda4 commit f69d6ca

File tree

6 files changed

+44
-15
lines changed

6 files changed

+44
-15
lines changed

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ class CpuDeviceInterface : public DeviceInterface {
1818

1919
virtual ~CpuDeviceInterface() {}
2020

21-
std::optional<const AVCodec*> findCodec(
22-
[[maybe_unused]] const AVCodecID& codecId) override {
23-
return std::nullopt;
24-
}
2521

2622
virtual void initialize(
2723
const AVStream* avStream,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,42 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
329329
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
330330
}
331331

332+
namespace {
333+
// Helper function to check if a codec supports CUDA hardware acceleration
334+
bool codecSupportsCudaHardware(const AVCodec* codec) {
335+
const AVCodecHWConfig* config = nullptr;
336+
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) {
337+
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
338+
return true;
339+
}
340+
}
341+
return false;
342+
}
343+
} // namespace
344+
332345
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
333346
// we have to do this because of an FFmpeg bug where hardware decoding is not
334347
// appropriately set, so we just go off and find the matching codec for the CUDA
335348
// device
336-
std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
349+
350+
std::optional<const AVCodec*> CudaDeviceInterface::findEncoder(
351+
const AVCodecID& codecId) {
352+
void* i = nullptr;
353+
const AVCodec* codec = nullptr;
354+
while ((codec = av_codec_iterate(&i)) != nullptr) {
355+
if (codec->id != codecId || !av_codec_is_encoder(codec)) {
356+
continue;
357+
}
358+
359+
if (codecSupportsCudaHardware(codec)) {
360+
return codec;
361+
}
362+
}
363+
364+
return std::nullopt;
365+
}
366+
367+
std::optional<const AVCodec*> CudaDeviceInterface::findDecoder(
337368
const AVCodecID& codecId) {
338369
void* i = nullptr;
339370
const AVCodec* codec = nullptr;
@@ -342,12 +373,8 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
342373
continue;
343374
}
344375

345-
const AVCodecHWConfig* config = nullptr;
346-
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
347-
++j) {
348-
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
349-
return codec;
350-
}
376+
if (codecSupportsCudaHardware(codec)) {
377+
return codec;
351378
}
352379
}
353380

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class CudaDeviceInterface : public DeviceInterface {
1818

1919
virtual ~CudaDeviceInterface();
2020

21-
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
21+
std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId) override;
22+
std::optional<const AVCodec*> findDecoder(const AVCodecID& codecId) override;
2223

2324
void initialize(
2425
const AVStream* avStream,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ class DeviceInterface {
4646
return device_;
4747
};
4848

49-
virtual std::optional<const AVCodec*> findCodec(
49+
virtual std::optional<const AVCodec*> findEncoder(
50+
[[maybe_unused]] const AVCodecID& codecId) {
51+
return std::nullopt;
52+
};
53+
54+
virtual std::optional<const AVCodec*> findDecoder(
5055
[[maybe_unused]] const AVCodecID& codecId) {
5156
return std::nullopt;
5257
};

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ void VideoEncoder::initializeEncoder(
628628
// Try to find a hardware-accelerated encoder if not using CPU
629629
if (videoStreamOptions.device.type() != torch::kCPU) {
630630
auto hardwareCodec =
631-
deviceInterface_->findCodec(avFormatContext_->oformat->video_codec);
631+
deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec);
632632
if (hardwareCodec.has_value()) {
633633
avCodec = hardwareCodec.value();
634634
}

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ void SingleStreamDecoder::addStream(
435435
// addStream() which is supposed to be generic
436436
if (mediaType == AVMEDIA_TYPE_VIDEO) {
437437
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
438-
deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
438+
deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id)
439439
.value_or(avCodec));
440440
}
441441

0 commit comments

Comments
 (0)