@@ -133,6 +133,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
133133 return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
134134}
135135
136+ cudaVideoCodec validateCodecSupport (AVCodecID codecId) {
137+ switch (codecId) {
138+ case AV_CODEC_ID_H264:
139+ return cudaVideoCodec_H264;
140+ case AV_CODEC_ID_HEVC:
141+ return cudaVideoCodec_HEVC;
142+ // TODONVDEC P0: support more codecs
143+ // case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
144+ // case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
145+ // case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
146+ // case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
147+ // case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
148+ default : {
149+ TORCH_CHECK (false , " Unsupported codec type: " , avcodec_get_name (codecId));
150+ }
151+ }
152+ }
153+
136154} // namespace
137155
138156BetaCudaDeviceInterface::BetaCudaDeviceInterface (const torch::Device& device)
@@ -158,29 +176,62 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
158176 }
159177}
160178
161- void BetaCudaDeviceInterface::initializeInterface (AVStream* avStream) {
162- torch::Tensor dummyTensorForCudaInitialization = torch::empty (
163- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
179+ void BetaCudaDeviceInterface::initializeBSF (
180+ const AVCodecParameters* codecPar,
181+ const UniqueDecodingAVFormatContext& avFormatCtx) {
182+ // Setup bit stream filters (BSF):
183+ // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
184+ // This is only needed for some formats, like H264 or HEVC.
164185
165- TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
166- timeBase_ = avStream->time_base ;
186+ TORCH_CHECK (codecPar != nullptr , " codecPar cannot be null" );
187+ TORCH_CHECK (avFormatCtx != nullptr , " AVFormatContext cannot be null" );
188+ TORCH_CHECK (
189+ avFormatCtx->iformat != nullptr ,
190+ " AVFormatContext->iformat cannot be null" );
191+ std::string filterName;
192+
193+ // Matching logic is taken from DALI
194+ switch (codecPar->codec_id ) {
195+ case AV_CODEC_ID_H264: {
196+ const std::string formatName = avFormatCtx->iformat ->long_name
197+ ? avFormatCtx->iformat ->long_name
198+ : " " ;
199+
200+ if (formatName == " QuickTime / MOV" ||
201+ formatName == " FLV (Flash Video)" ||
202+ formatName == " Matroska / WebM" || formatName == " raw H.264 video" ) {
203+ filterName = " h264_mp4toannexb" ;
204+ }
205+ break ;
206+ }
167207
168- const AVCodecParameters* codecpar = avStream->codecpar ;
169- TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
208+ case AV_CODEC_ID_HEVC: {
209+ const std::string formatName = avFormatCtx->iformat ->long_name
210+ ? avFormatCtx->iformat ->long_name
211+ : " " ;
170212
171- TORCH_CHECK (
172- // TODONVDEC P0 support more
173- avStream->codecpar ->codec_id == AV_CODEC_ID_H264,
174- " Can only do H264 for now" );
213+ if (formatName == " QuickTime / MOV" ||
214+ formatName == " FLV (Flash Video)" ||
215+ formatName == " Matroska / WebM" || formatName == " raw HEVC video" ) {
216+ filterName = " hevc_mp4toannexb" ;
217+ }
218+ break ;
219+ }
175220
176- // Setup bit stream filters (BSF):
177- // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
178- // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
179- // now we apply BSF unconditionally, but it should be optional and dependent
180- // on codec and container.
181- const AVBitStreamFilter* avBSF = av_bsf_get_by_name (" h264_mp4toannexb" );
221+ default :
222+ // No bitstream filter needed for other codecs
223+ // TODONVDEC P1 MPEG4 will need one!
224+ break ;
225+ }
226+
227+ if (filterName.empty ()) {
228+ // Only initialize BSF if we actually need one
229+ return ;
230+ }
231+
232+ const AVBitStreamFilter* avBSF = av_bsf_get_by_name (filterName.c_str ());
182233 TORCH_CHECK (
183- avBSF != nullptr , " Failed to find h264_mp4toannexb bitstream filter" );
234+ avBSF != nullptr , " Failed to find bitstream filter: " , filterName );
184235
185236 AVBSFContext* avBSFContext = nullptr ;
186237 int retVal = av_bsf_alloc (avBSF, &avBSFContext);
@@ -191,7 +242,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
191242
192243 bitstreamFilter_.reset (avBSFContext);
193244
194- retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecpar );
245+ retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecPar );
195246 TORCH_CHECK (
196247 retVal >= AVSUCCESS,
197248 " Failed to copy codec parameters: " ,
@@ -202,10 +253,25 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
202253 retVal == AVSUCCESS,
203254 " Failed to initialize bitstream filter: " ,
204255 getFFMPEGErrorStringFromErrorCode (retVal));
256+ }
257+
258+ void BetaCudaDeviceInterface::initializeInterface (
259+ const AVStream* avStream,
260+ const UniqueDecodingAVFormatContext& avFormatCtx) {
261+ torch::Tensor dummyTensorForCudaInitialization = torch::empty (
262+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
263+
264+ TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
265+ timeBase_ = avStream->time_base ;
266+
267+ const AVCodecParameters* codecPar = avStream->codecpar ;
268+ TORCH_CHECK (codecPar != nullptr , " CodecParameters cannot be null" );
269+
270+ initializeBSF (codecPar, avFormatCtx);
205271
206272 // Create parser. Default values that aren't obvious are taken from DALI.
207273 CUVIDPARSERPARAMS parserParams = {};
208- parserParams.CodecType = cudaVideoCodec_H264 ;
274+ parserParams.CodecType = validateCodecSupport (codecPar-> codec_id ) ;
209275 parserParams.ulMaxNumDecodeSurfaces = 8 ;
210276 parserParams.ulMaxDisplayDelay = 0 ;
211277 // Callback setup, all are triggered by the parser within a call
0 commit comments