diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index c156a833c..57a4916a4 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -62,7 +62,7 @@ jobs:
         run: python -m pip install --upgrade pip
       - name: Install dependencies and FFmpeg
         run: |
-          python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+          python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
           conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge
           ffmpeg -version
       - name: Build and install torchcodec
diff --git a/mypy.ini b/mypy.ini
index bd0ee6ac8..f018ba4f8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -4,3 +4,4 @@ files = src/torchcodec
 show_error_codes = True
 pretty = True
 allow_redefinition = True
+follow_untyped_imports = True
diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py
index c0f4c2b6d..9ea601fde 100644
--- a/src/torchcodec/__init__.py
+++ b/src/torchcodec/__init__.py
@@ -7,7 +7,7 @@
 # Note: usort wants to put Frame and FrameBatch after decoders and samplers,
 # but that results in circular import.
 from ._frame import AudioSamples, Frame, FrameBatch  # usort:skip # noqa
-from . import decoders, encoders, samplers  # noqa
+from . import decoders, encoders, samplers, transforms  # noqa
 
 try:
     # Note that version.py is generated during install.
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 130927c2e..af8b9e99d 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import Literal, Optional, Tuple, Union
+from typing import List, Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import device as torch_device, Tensor
@@ -19,6 +19,7 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
+from torchcodec.transforms import DecoderTransform, Resize
 
 
 class VideoDecoder:
@@ -66,6 +67,10 @@ class VideoDecoder:
             probably is. Default: "exact".
             Read more about this parameter in:
             :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
+        transforms (sequence of transform objects, optional): Sequence of transforms to be
+            applied to the decoded frames by the decoder itself, in order. Accepts both
+            ``torchcodec.transforms.DecoderTransform`` and ``torchvision.transforms.v2.Transform``
+            objects. All transforms are applied in the output pixel format and colorspace.
         custom_frame_mappings (str, bytes, or file-like object, optional): Mapping of frames
             to their metadata, typically generated via ffprobe. This enables accurate frame
             seeking without requiring a full video scan.
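For context, a minimal sketch of how the new transforms argument is intended to be used from user code (the file name and target size below are illustrative, not taken from this PR):

    from torchcodec.decoders import VideoDecoder
    from torchcodec.transforms import Resize
    from torchvision.transforms import v2

    # Decoder-native resize: frames come back already resized to (height, width).
    decoder = VideoDecoder("video.mp4", transforms=[Resize(size=(270, 480))])
    frame = decoder[0]  # tensor of shape (C, 270, 480)

    # Equivalent construction using the TorchVision v2 transform; the decoder
    # converts it to the TorchCodec Resize internally.
    decoder_tv = VideoDecoder("video.mp4", transforms=[v2.Resize(size=(270, 480))])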
@@ -104,6 +109,7 @@ def __init__(
         num_ffmpeg_threads: int = 1,
         device: Optional[Union[str, torch_device]] = "cpu",
         seek_mode: Literal["exact", "approximate"] = "exact",
+        transforms: Optional[Sequence[DecoderTransform]] = None,
         custom_frame_mappings: Optional[
             Union[str, bytes, io.RawIOBase, io.BufferedReader]
         ] = None,
@@ -148,6 +154,8 @@ def __init__(
 
         device_variant = _get_cuda_backend()
 
+        transform_specs = _make_transform_specs(transforms)
+
         core.add_video_stream(
             self._decoder,
             stream_index=stream_index,
@@ -155,6 +163,7 @@ def __init__(
             num_threads=num_ffmpeg_threads,
             device=device,
             device_variant=device_variant,
+            transform_specs=transform_specs,
             custom_frame_mappings=custom_frame_mappings_data,
         )
 
@@ -432,6 +441,83 @@ def _get_and_validate_stream_metadata(
     )
 
 
+def _convert_to_decoder_native_transforms(
+    transforms: Sequence[DecoderTransform],
+) -> List[DecoderTransform]:
+    """Convert a sequence of transforms that may contain TorchVision transform
+    objects into a list of only TorchCodec transform objects.
+
+    Args:
+        transforms: Sequence of transform objects. The objects can be one of two
+            types:
+                1. torchcodec.transforms.DecoderTransform
+                2. torchvision.transforms.v2.Transform
+            Our type annotation only mentions the first type so that we don't
+            have a hard dependency on TorchVision.
+
+    Returns:
+        List of DecoderTransform objects.
+    """
+    try:
+        from torchvision.transforms import v2
+
+        tv_available = True
+    except ImportError:
+        tv_available = False
+
+    converted_transforms = []
+    for transform in transforms:
+        if not isinstance(transform, DecoderTransform):
+            if not tv_available:
+                raise ValueError(
+                    f"The supplied transform, {transform}, is not a TorchCodec "
+                    "DecoderTransform. TorchCodec also accepts TorchVision "
+                    "v2 transforms, but TorchVision is not installed."
+                )
+            if isinstance(transform, v2.Resize):
+                if len(transform.size) != 2:
+                    raise ValueError(
+                        "TorchVision Resize transform must have a (height, width) "
+                        f"pair for the size, got {transform.size}."
+                    )
+                converted_transforms.append(Resize(size=transform.size))
+            else:
+                raise ValueError(
+                    f"Unsupported transform: {transform}. Transforms must be "
+                    "either a TorchCodec DecoderTransform or a TorchVision "
+                    "v2 transform."
+                )
+        else:
+            converted_transforms.append(transform)
+
+    return converted_transforms
+
+
+def _make_transform_specs(
+    transforms: Optional[Sequence[DecoderTransform]],
+) -> str:
+    """Given a sequence of transforms, turn those into the specification string
+    the core API expects.
+
+    Args:
+        transforms: Optional sequence of transform objects. The objects can be
+            one of two types:
+                1. torchcodec.transforms.DecoderTransform
+                2. torchvision.transforms.v2.Transform
+            Our type annotation only mentions the first type so that we don't
+            have a hard dependency on TorchVision.
+
+    Returns:
+        String of transforms in the format the core API expects: transform
+        specifications separated by semicolons.
+    """
+    if transforms is None:
+        return ""
+
+    transforms = _convert_to_decoder_native_transforms(transforms)
+    return ";".join([t.make_params() for t in transforms])
+
+
 def _read_custom_frame_mappings(
     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
 ) -> tuple[Tensor, Tensor, Tensor]:
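To illustrate the helpers above, a rough sketch of their behavior (sizes are made up; both functions are private and not part of the public API):

    from torchcodec.transforms import Resize
    from torchvision.transforms import v2

    # TorchVision v2.Resize objects are converted into TorchCodec Resize objects,
    # then serialized into the spec string passed to core.add_video_stream().
    _make_transform_specs([Resize(size=(270, 480))])     # -> "resize, 270, 480"
    _make_transform_specs([v2.Resize(size=(270, 480))])  # -> "resize, 270, 480"
    _make_transform_specs(None)                          # -> ""

    # Any other transform type raises ValueError("Unsupported transform: ...").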
diff --git a/src/torchcodec/transforms/__init__.py b/src/torchcodec/transforms/__init__.py
new file mode 100644
index 000000000..9f4a92f81
--- /dev/null
+++ b/src/torchcodec/transforms/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._decoder_transforms import DecoderTransform, Resize  # noqa
diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py
new file mode 100644
index 000000000..ca889abd1
--- /dev/null
+++ b/src/torchcodec/transforms/_decoder_transforms.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Sequence
+
+
+@dataclass
+class DecoderTransform(ABC):
+    """Base class for all decoder transforms.
+
+    A DecoderTransform is a transform that is applied by the decoder before
+    returning the decoded frame. The implementation does not live in TorchCodec
+    itself, but in the underlying decoder. Applying DecoderTransforms to frames
+    should be both faster and more memory efficient than receiving normally
+    decoded frames and applying the same kind of transform.
+
+    Most DecoderTransforms have a complementary transform in TorchVision,
+    specifically in torchvision.transforms.v2. For such transforms, we ensure
+    that:
+
+        1. Default behaviors are the same.
+        2. The parameters for the DecoderTransform are a subset of the
+           TorchVision transform.
+        3. Parameters with the same name control the same behavior and accept a
+           subset of the same types.
+        4. The difference between the frames returned by a DecoderTransform and
+           the complementary TorchVision transform is small.
+
+    All DecoderTransforms are applied in the output pixel format and colorspace.
+    """
+
+    @abstractmethod
+    def make_params(self) -> str:
+        pass
+
+
+@dataclass
+class Resize(DecoderTransform):
+    """Resize the decoded frame to a given size.
+
+    Complementary TorchVision transform: torchvision.transforms.v2.Resize.
+    Interpolation is always bilinear. Anti-aliasing is always on.
+
+    Args:
+        size (sequence of int): Desired output size. Must be a sequence of
+            the form (height, width).
+    """
+
+    size: Sequence[int]
+
+    def make_params(self) -> str:
+        assert len(self.size) == 2
+        return f"resize, {self.size[0]}, {self.size[1]}"
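To make the make_params contract concrete, the Resize transform defined above serializes itself as shown below (the size is illustrative); _make_transform_specs in _video_decoder.py joins such strings with ";" into the single spec string handed to the core API:

    Resize(size=(270, 480)).make_params()  # -> "resize, 270, 480"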
diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py
index 370849726..edc3cfe15 100644
--- a/test/test_transform_ops.py
+++ b/test/test_transform_ops.py
@@ -13,6 +13,7 @@
 
 import pytest
 import torch
+import torchcodec
 
 from torchcodec._core import (
     _add_video_stream,
@@ -21,6 +22,7 @@
     get_frame_at_index,
     get_json_metadata,
 )
+from torchcodec.decoders import VideoDecoder
 from torchvision.transforms import v2
 
 
@@ -34,7 +36,89 @@
     TEST_SRC_2_720P,
 )
 
-torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+class TestPublicVideoDecoderTransformOps:
+    @pytest.mark.parametrize(
+        "height_scaling_factor, width_scaling_factor",
+        ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
+    )
+    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
+    def test_resize_torchvision(
+        self, video, height_scaling_factor, width_scaling_factor
+    ):
+        height = int(video.get_height() * height_scaling_factor)
+        width = int(video.get_width() * width_scaling_factor)
+
+        # We're using both the TorchCodec object and the TorchVision object to
+        # ensure that they specify exactly the same thing.
+        decoder_resize = VideoDecoder(
+            video.path, transforms=[torchcodec.transforms.Resize(size=(height, width))]
+        )
+        decoder_resize_tv = VideoDecoder(
+            video.path, transforms=[v2.Resize(size=(height, width))]
+        )
+
+        decoder_full = VideoDecoder(video.path)
+
+        num_frames = len(decoder_resize)
+        assert num_frames == len(decoder_full)
+
+        for frame_index in [
+            0,
+            int(num_frames * 0.1),
+            int(num_frames * 0.2),
+            int(num_frames * 0.3),
+            int(num_frames * 0.4),
+            int(num_frames * 0.5),
+            int(num_frames * 0.75),
+            int(num_frames * 0.90),
+            num_frames - 1,
+        ]:
+            frame_resize_tv = decoder_resize_tv[frame_index]
+            frame_resize = decoder_resize[frame_index]
+            assert_frames_equal(frame_resize_tv, frame_resize)
+
+            frame_full = decoder_full[frame_index]
+
+            frame_tv = v2.functional.resize(frame_full, size=(height, width))
+            frame_tv_no_antialias = v2.functional.resize(
+                frame_full, size=(height, width), antialias=False
+            )
+
+            expected_shape = (video.get_num_color_channels(), height, width)
+            assert frame_resize.shape == expected_shape
+            assert frame_tv.shape == expected_shape
+            assert frame_tv_no_antialias.shape == expected_shape
+
+            assert_tensor_close_on_at_least(
+                frame_resize, frame_tv, percentage=99.8, atol=1
+            )
+            torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)
+
+            if height_scaling_factor < 1 or width_scaling_factor < 1:
+                # Antialias only relevant when down-scaling!
+                with pytest.raises(AssertionError, match="Expected at least"):
+                    assert_tensor_close_on_at_least(
+                        frame_resize, frame_tv_no_antialias, percentage=99, atol=1
+                    )
+                with pytest.raises(AssertionError, match="Tensor-likes are not close"):
+                    torch.testing.assert_close(
+                        frame_resize, frame_tv_no_antialias, rtol=0, atol=6
+                    )
+
+    def test_resize_fails(self):
+        with pytest.raises(
+            ValueError,
+            match=r"must have a \(height, width\) pair for the size",
+        ):
+            VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])
+
+    def test_transform_fails(self):
+        with pytest.raises(
+            ValueError,
+            match="Unsupported transform",
+        ):
+            VideoDecoder(NASA_VIDEO.path, transforms=[v2.RandomHorizontalFlip(p=1.0)])
 
 
 class TestCoreVideoDecoderTransformOps:
@@ -172,68 +256,6 @@ def test_transform_fails(self):
         ):
             add_video_stream(decoder, transform_specs="invalid, 1, 2")
 
-    @pytest.mark.parametrize(
-        "height_scaling_factor, width_scaling_factor",
-        ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
-    )
-    @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
-    def test_resize_torchvision(
-        self, video, height_scaling_factor, width_scaling_factor
-    ):
-        num_frames = self.get_num_frames_core_ops(video)
-
-        height = int(video.get_height() * height_scaling_factor)
-        width = int(video.get_width() * width_scaling_factor)
-        resize_spec = f"resize, {height}, {width}"
-
-        decoder_resize = create_from_file(str(video.path))
-        add_video_stream(decoder_resize, transform_specs=resize_spec)
-
-        decoder_full = create_from_file(str(video.path))
-        add_video_stream(decoder_full)
-
-        for frame_index in [
-            0,
-            int(num_frames * 0.1),
-            int(num_frames * 0.2),
-            int(num_frames * 0.3),
-            int(num_frames * 0.4),
-            int(num_frames * 0.5),
-            int(num_frames * 0.75),
-            int(num_frames * 0.90),
-            num_frames - 1,
-        ]:
-            expected_shape = (video.get_num_color_channels(), height, width)
-            frame_resize, *_ = get_frame_at_index(
-                decoder_resize, frame_index=frame_index
-            )
-
-            frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index)
-            frame_tv = v2.functional.resize(frame_full, size=(height, width))
-            frame_tv_no_antialias = v2.functional.resize(
-                frame_full, size=(height, width), antialias=False
-            )
-
-            assert frame_resize.shape == expected_shape
-            assert frame_tv.shape == expected_shape
-            assert frame_tv_no_antialias.shape == expected_shape
-
-            assert_tensor_close_on_at_least(
-                frame_resize, frame_tv, percentage=99.8, atol=1
-            )
-            torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)
-
-            if height_scaling_factor < 1 or width_scaling_factor < 1:
-                # Antialias only relevant when down-scaling!
-                with pytest.raises(AssertionError, match="Expected at least"):
-                    assert_tensor_close_on_at_least(
-                        frame_resize, frame_tv_no_antialias, percentage=99, atol=1
-                    )
-                with pytest.raises(AssertionError, match="Tensor-likes are not close"):
-                    torch.testing.assert_close(
-                        frame_resize, frame_tv_no_antialias, rtol=0, atol=6
-                    )
-
     def test_resize_ffmpeg(self):
         height = 135
         width = 240