Skip to content

Commit 1e06ea5

Browse files
author
Daniel Flores
committed
video encoder python file
1 parent 75a3325 commit 1e06ea5

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from ._audio_encoder import AudioEncoder # noqa
2+
from ._video_encoder import VideoEncoder # noqa
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from pathlib import Path
2+
from typing import Union
3+
4+
import torch
5+
from torch import Tensor
6+
7+
from torchcodec import _core
8+
9+
10+
class VideoEncoder:
11+
"""A video encoder.
12+
13+
Args:
14+
frames (``torch.Tensor``): The frames to encode. This must be a 4D
15+
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
16+
C is 3 channels (RGB), H is height, and W is width.
17+
A 3D tensor of shape ``(C, H, W)`` is also accepted as a single RGB frame.
18+
Values must be uint8 in the range ``[0, 255]``.
19+
frame_rate (int): The frame rate to use when encoding the
20+
**input** ``frames``.
21+
"""
22+
23+
def __init__(self, frames: Tensor, *, frame_rate: int):
24+
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
25+
if not isinstance(frames, Tensor):
26+
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")
27+
if frames.ndim == 3:
28+
# make it 4D and assume single RGB frame, CHW -> NCHW
29+
frames = torch.unsqueeze(frames, 0)
30+
if frames.ndim != 4:
31+
raise ValueError(f"Expected 3D or 4D frames, got {frames.shape = }.")
32+
if frames.dtype != torch.uint8:
33+
raise ValueError(f"Expected uint8 frames, got {frames.dtype = }.")
34+
if frame_rate <= 0:
35+
raise ValueError(f"{frame_rate = } must be > 0.")
36+
37+
self._frames = frames
38+
self._frame_rate = frame_rate
39+
40+
def to_file(
41+
self,
42+
dest: Union[str, Path],
43+
) -> None:
44+
"""Encode frames into a file.
45+
46+
Args:
47+
dest (str or ``pathlib.Path``): The path to the output file, e.g.
48+
``video.mp4``. The extension of the file determines the video
49+
format and container.
50+
"""
51+
_core.encode_video_to_file(
52+
frames=self._frames,
53+
frame_rate=self._frame_rate,
54+
filename=str(dest),
55+
)
56+
57+
def to_tensor(
58+
self,
59+
format: str,
60+
) -> Tensor:
61+
"""Encode frames into raw bytes, as a 1D uint8 Tensor.
62+
63+
Args:
64+
format (str): The format of the encoded frames, e.g. "mp4", "mov",
65+
"mkv", "avi", "webm", "flv", or "gif"
66+
67+
Returns:
68+
Tensor: The raw encoded bytes as 4D uint8 Tensor.
69+
"""
70+
return _core.encode_video_to_tensor(
71+
frames=self._frames,
72+
frame_rate=self._frame_rate,
73+
format=format,
74+
)
75+
76+
def to_file_like(
77+
self,
78+
file_like,
79+
format: str,
80+
) -> None:
81+
"""Encode frames into a file-like object.
82+
83+
Args:
84+
file_like: A file-like object that supports ``write()`` and
85+
``seek()`` methods, such as io.BytesIO(), an open file in binary
86+
write mode, etc. Methods must have the following signature:
87+
``write(data: bytes) -> int`` and ``seek(offset: int, whence:
88+
int = 0) -> int``.
89+
format (str): The format of the encoded frames, e.g. "mp4", "mov",
90+
"mkv", "avi", "webm", "flv", or "gif".
91+
"""
92+
_core.encode_video_to_file_like(
93+
frames=self._frames,
94+
frame_rate=self._frame_rate,
95+
format=format,
96+
file_like=file_like,
97+
)

0 commit comments

Comments
 (0)