55# LICENSE file in the root directory of this source tree.
66
77import io
8+ import json
89import numbers
910from pathlib import Path
1011from typing import Literal , Optional , Tuple , Union
@@ -62,7 +63,25 @@ class VideoDecoder:
6263 probably is. Default: "exact".
6364 Read more about this parameter in:
6465 :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
65-
66+ custom_frame_mappings (str, bytes, or file-like object, optional):
67+ Mapping of frames to their metadata, typically generated via ffprobe.
68+ This enables accurate frame seeking without requiring a full video scan.
69+ Do not set seek_mode when custom_frame_mappings is provided.
70+ Expected JSON format:
71+
72+ .. code-block:: json
73+
74+ {
75+ "frames": [
76+ {
77+ "pts": 0,
78+ "duration": 1001,
79+ "key_frame": 1
80+ }
81+ ]
82+ }
83+
84+ Alternative field names "pkt_pts" and "pkt_duration" are also supported.
6685
6786 Attributes:
6887 metadata (VideoStreamMetadata): Metadata of the video stream.
@@ -80,6 +99,9 @@ def __init__(
8099 num_ffmpeg_threads : int = 1 ,
81100 device : Optional [Union [str , torch_device ]] = "cpu" ,
82101 seek_mode : Literal ["exact" , "approximate" ] = "exact" ,
102+ custom_frame_mappings : Optional [
103+ Union [str , bytes , io .RawIOBase , io .BufferedReader ]
104+ ] = None ,
83105 ):
84106 torch ._C ._log_api_usage_once ("torchcodec.decoders.VideoDecoder" )
85107 allowed_seek_modes = ("exact" , "approximate" )
@@ -89,6 +111,21 @@ def __init__(
89111 f"Supported values are { ', ' .join (allowed_seek_modes )} ."
90112 )
91113
114+ # Validate seek_mode and custom_frame_mappings are not mismatched
115+ if custom_frame_mappings is not None and seek_mode == "approximate" :
116+ raise ValueError (
117+ "custom_frame_mappings is incompatible with seek_mode='approximate'. "
118+ "Use seek_mode='custom_frame_mappings' or leave it unspecified to automatically use custom frame mappings."
119+ )
120+
121+ # Auto-select custom_frame_mappings seek_mode and process data when mappings are provided
122+ custom_frame_mappings_data = None
123+ if custom_frame_mappings is not None :
124+ seek_mode = "custom_frame_mappings" # type: ignore[assignment]
125+ custom_frame_mappings_data = _read_custom_frame_mappings (
126+ custom_frame_mappings
127+ )
128+
92129 self ._decoder = create_decoder (source = source , seek_mode = seek_mode )
93130
94131 allowed_dimension_orders = ("NCHW" , "NHWC" )
@@ -110,6 +147,7 @@ def __init__(
110147 dimension_order = dimension_order ,
111148 num_threads = num_ffmpeg_threads ,
112149 device = device ,
150+ custom_frame_mappings = custom_frame_mappings_data ,
113151 )
114152
115153 (
@@ -379,3 +417,57 @@ def _get_and_validate_stream_metadata(
379417 end_stream_seconds ,
380418 num_frames ,
381419 )
420+
421+
422+ def _read_custom_frame_mappings (
423+ custom_frame_mappings : Union [str , bytes , io .RawIOBase , io .BufferedReader ]
424+ ) -> tuple [Tensor , Tensor , Tensor ]:
425+ """Parse custom frame mappings from JSON data and extract frame metadata.
426+
427+ Args:
428+ custom_frame_mappings: JSON data containing frame metadata, provided as:
429+ - A JSON string (str, bytes)
430+ - A file-like object with a read() method
431+
432+ Returns:
433+ A tuple of three tensors:
434+ - all_frames (Tensor): Presentation timestamps (PTS) for each frame
435+ - is_key_frame (Tensor): Boolean tensor indicating which frames are key frames
436+ - duration (Tensor): Duration of each frame
437+ """
438+ try :
439+ input_data = (
440+ json .load (custom_frame_mappings )
441+ if hasattr (custom_frame_mappings , "read" )
442+ else json .loads (custom_frame_mappings )
443+ )
444+ except json .JSONDecodeError as e :
445+ raise ValueError (
446+ f"Invalid custom frame mappings: { e } . It should be a valid JSON string or a file-like object."
447+ ) from e
448+
449+ if not input_data or "frames" not in input_data :
450+ raise ValueError (
451+ "Invalid custom frame mappings. The input is empty or missing the required 'frames' key."
452+ )
453+
454+ first_frame = input_data ["frames" ][0 ]
455+ pts_key = next ((key for key in ("pts" , "pkt_pts" ) if key in first_frame ), None )
456+ duration_key = next (
457+ (key for key in ("duration" , "pkt_duration" ) if key in first_frame ), None
458+ )
459+ key_frame_present = "key_frame" in first_frame
460+
461+ if not pts_key or not duration_key or not key_frame_present :
462+ raise ValueError (
463+ "Invalid custom frame mappings. The 'pts'/'pkt_pts', 'duration'/'pkt_duration', and 'key_frame' keys are required in the frame metadata."
464+ )
465+
466+ frame_data = [
467+ (float (frame [pts_key ]), frame ["key_frame" ], float (frame [duration_key ]))
468+ for frame in input_data ["frames" ]
469+ ]
470+ all_frames , is_key_frame , duration = map (torch .tensor , zip (* frame_data ))
471+ if not (len (all_frames ) == len (is_key_frame ) == len (duration )):
472+ raise ValueError ("Mismatched lengths in frame index data" )
473+ return all_frames , is_key_frame , duration
0 commit comments