2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
@@ -62,7 +62,7 @@ jobs:
run: python -m pip install --upgrade pip
- name: Install dependencies and FFmpeg
run: |
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge
ffmpeg -version
- name: Build and install torchcodec
1 change: 1 addition & 0 deletions mypy.ini
@@ -4,3 +4,4 @@ files = src/torchcodec
show_error_codes = True
pretty = True
allow_redefinition = True
follow_untyped_imports = True
2 changes: 1 addition & 1 deletion src/torchcodec/__init__.py
@@ -7,7 +7,7 @@
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
# but that results in circular import.
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
from . import decoders, encoders, samplers # noqa
from . import decoders, encoders, samplers, transforms # noqa

try:
# Note that version.py is generated during install.
88 changes: 87 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Sequence, Tuple, Union

import torch
from torch import device as torch_device, Tensor
@@ -19,6 +19,7 @@
create_decoder,
ERROR_REPORTING_INSTRUCTIONS,
)
from torchcodec.transforms import DecoderTransform, Resize


class VideoDecoder:
@@ -66,6 +67,10 @@ class VideoDecoder:
probably is. Default: "exact".
Read more about this parameter in:
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
transforms (sequence of transform objects, optional): Sequence of transforms to be
applied to the decoded frames by the decoder itself, in order. Accepts both
``torchcodec.transforms.DecoderTransform`` and ``torchvision.transforms.v2.Transform``
objects. All transforms are applied in the output pixel format and colorspace.
custom_frame_mappings (str, bytes, or file-like object, optional):
Mapping of frames to their metadata, typically generated via ffprobe.
This enables accurate frame seeking without requiring a full video scan.
@@ -104,6 +109,7 @@ def __init__(
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
transforms: Optional[Sequence[DecoderTransform]] = None,
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
] = None,
@@ -148,13 +154,16 @@ def __init__(

device_variant = _get_cuda_backend()

transform_specs = _make_transform_specs(transforms)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
device_variant=device_variant,
transform_specs=transform_specs,
custom_frame_mappings=custom_frame_mappings_data,
)

@@ -432,6 +441,83 @@ def _get_and_validate_stream_metadata(
)


def _convert_to_decoder_native_transforms(
transforms: Sequence[DecoderTransform],
) -> List[DecoderTransform]:
"""Convert a sequence of transforms that may contain TorchVision transform
objects into a list of only TorchCodec transform objects.

Args:
transforms: Sequence of transform objects. The objects can be one of two
types:
1. torchcodec.transforms.DecoderTransform
2. torchvision.transforms.v2.Transform
Our type annotation only mentions the first type so that we don't
have a hard dependency on TorchVision.

Returns:
List of DecoderTransform objects.
"""
try:
from torchvision.transforms import v2

tv_available = True
except ImportError:
tv_available = False

converted_transforms = []
for transform in transforms:
if not isinstance(transform, DecoderTransform):
if not tv_available:
raise ValueError(
f"The supplied transform, {transform}, is not a TorchCodec "
" DecoderTransform. TorchCodec also accept TorchVision "
"v2 transforms, but TorchVision is not installed."
)
if isinstance(transform, v2.Resize):
if len(transform.size) != 2:
raise ValueError(
"TorchVision Resize transform must have a (height, width) "
f"pair for the size, got {transform.size}."
)
converted_transforms.append(Resize(size=transform.size))
else:
raise ValueError(
f"Unsupported transform: {transform}. Transforms must be "
"either a TorchCodec DecoderTransform or a TorchVision "
"v2 transform."
)
else:
converted_transforms.append(transform)

return converted_transforms


def _make_transform_specs(
transforms: Optional[Sequence[DecoderTransform]],
) -> str:
"""Given a sequence of transforms, turn those into the specification string
the core API expects.

Args:
transforms: Optional sequence of transform objects. The objects can be
one of two types:
1. torchcodec.transforms.DecoderTransform
2. torchvision.transforms.v2.Transform
Our type annotation only mentions the first type so that we don't
have a hard dependency on TorchVision.

Returns:
String of transforms in the format the core API expects: transform
specifications separated by semicolons.
"""
if transforms is None:
return ""

transforms = _convert_to_decoder_native_transforms(transforms)
return ";".join([t.make_params() for t in transforms])


def _read_custom_frame_mappings(
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
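For context while reviewing, here is a minimal usage sketch of the new `transforms` parameter. The file name is hypothetical, and the sketch assumes this branch plus TorchVision are installed:

```python
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import Resize

# Decoder-native transform: the resize happens inside the decoder, so only
# the already-resized (3, 270, 480) frame is materialized on the Python side.
decoder = VideoDecoder("video.mp4", transforms=[Resize(size=(270, 480))])
frame = decoder[0]

# TorchVision v2 transforms are converted to their decoder-native
# equivalents, so this is equivalent to the call above.
from torchvision.transforms import v2

decoder = VideoDecoder("video.mp4", transforms=[v2.Resize(size=(270, 480))])
```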
7 changes: 7 additions & 0 deletions src/torchcodec/transforms/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._decoder_transforms import DecoderTransform, Resize # noqa
58 changes: 58 additions & 0 deletions src/torchcodec/transforms/_decoder_transforms.py
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence


@dataclass
class DecoderTransform(ABC):
"""Base class for all decoder transforms.

A DecoderTransform is a transform that is applied by the decoder before
returning the decoded frame. The implementation does not live in TorchCodec
itself, but in the underlying decoder. Applying DecoderTransforms to frames
should be both faster and more memory efficient than receiving normally
decoded frames and applying the same kind of transform.

Most DecoderTransforms have a complementary transform in TorchVision,
specifically in torchvision.transforms.v2. For such transforms, we ensure
that:

1. Default behaviors are the same.
2. The parameters for the DecoderTransform are a subset of the
TorchVision transform.
3. Parameters with the same name control the same behavior and accept a
subset of the same types.
4. The differences between the frames returned by a DecoderTransform and
the complementary TorchVision transform are small.

All DecoderTransforms are applied in the output pixel format and colorspace.
"""

@abstractmethod
def make_params(self) -> str:
pass


@dataclass
class Resize(DecoderTransform):
"""Resize the decoded frame to a given size.

Complementary TorchVision transform: torchvision.transforms.v2.Resize.
Interpolation is always bilinear. Anti-aliasing is always on.

Args:
size (sequence of int): Desired output size. Must be a sequence of
the form (height, width).
"""

size: Sequence[int]

def make_params(self) -> str:
assert len(self.size) == 2
return f"resize, {self.size[0]}, {self.size[1]}"