Skip to content

Commit 492a6bc

Browse files
authored
Better compile test, fixup abstract methods (#525)
1 parent 7d1d7fa commit 492a6bc

File tree

3 files changed

+38
-47
lines changed

3 files changed

+38
-47
lines changed

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def get_frame_at_pts_abstract(
185185
def get_frames_by_pts_abstract(
186186
decoder: torch.Tensor,
187187
*,
188-
stream_index: int,
189188
timestamps: List[float],
190189
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
191190
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
@@ -198,7 +197,7 @@ def get_frames_by_pts_abstract(
198197

199198
@register_fake("torchcodec_ns::get_frame_at_index")
200199
def get_frame_at_index_abstract(
201-
decoder: torch.Tensor, *, stream_index: int, frame_index: int
200+
decoder: torch.Tensor, *, frame_index: int
202201
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
203202
image_size = [get_ctx().new_dynamic_size() for _ in range(3)]
204203
return (
@@ -212,7 +211,6 @@ def get_frame_at_index_abstract(
212211
def get_frames_at_indices_abstract(
213212
decoder: torch.Tensor,
214213
*,
215-
stream_index: int,
216214
frame_indices: List[int],
217215
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
218216
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
@@ -227,7 +225,6 @@ def get_frames_at_indices_abstract(
227225
def get_frames_in_range_abstract(
228226
decoder: torch.Tensor,
229227
*,
230-
stream_index: int,
231228
start: int,
232229
stop: int,
233230
step: Optional[int] = None,
@@ -244,7 +241,6 @@ def get_frames_in_range_abstract(
244241
def get_frames_by_pts_in_range_abstract(
245242
decoder: torch.Tensor,
246243
*,
247-
stream_index: int,
248244
start_seconds: float,
249245
stop_seconds: float,
250246
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -257,9 +253,7 @@ def get_frames_by_pts_in_range_abstract(
257253

258254

259255
@register_fake("torchcodec_ns::_get_key_frame_indices")
260-
def get_key_frame_indices_abstract(
261-
decoder: torch.Tensor, *, stream_index: int
262-
) -> torch.Tensor:
256+
def get_key_frame_indices_abstract(decoder: torch.Tensor) -> torch.Tensor:
263257
return torch.empty([], dtype=torch.int)
264258

265259

@@ -282,7 +276,6 @@ def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) ->
282276
def _test_frame_pts_equality_abstract(
283277
decoder: torch.Tensor,
284278
*,
285-
stream_index: int,
286279
frame_index: int,
287280
pts_seconds_to_test: float,
288281
) -> bool:

test/decoders/test_video_decoder.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextlib
8+
79
import numpy
810
import pytest
911
import torch
@@ -874,6 +876,38 @@ def test_get_key_frame_indices(self, device):
874876
key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0
875877
)
876878

879+
@pytest.mark.parametrize("device", cpu_and_cuda())
880+
def test_compile(self, device):
881+
decoder = VideoDecoder(NASA_VIDEO.path, device=device)
877882

878-
if __name__ == "__main__":
879-
pytest.main()
883+
@contextlib.contextmanager
884+
def restore_capture_scalar_outputs():
885+
try:
886+
original = torch._dynamo.config.capture_scalar_outputs
887+
yield
888+
finally:
889+
torch._dynamo.config.capture_scalar_outputs = original
890+
891+
# TODO: We get a graph break because we call Tensor.item() to turn the
892+
# tensors in FrameBatch into scalars. When we work on compilation and exportability,
893+
# we should investigate.
894+
with restore_capture_scalar_outputs():
895+
torch._dynamo.config.capture_scalar_outputs = True
896+
897+
@torch.compile(fullgraph=True, backend="eager")
898+
def get_some_frames(decoder):
899+
frames = []
900+
frames.append(decoder.get_frame_at(1))
901+
frames.append(decoder.get_frame_at(3))
902+
frames.append(decoder.get_frame_at(5))
903+
return frames
904+
905+
frames = get_some_frames(decoder)
906+
907+
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device)
908+
ref_frame3 = NASA_VIDEO.get_frame_data_by_index(3).to(device)
909+
ref_frame5 = NASA_VIDEO.get_frame_data_by_index(5).to(device)
910+
911+
assert_frames_equal(ref_frame1, frames[0].data)
912+
assert_frames_equal(ref_frame3, frames[1].data)
913+
assert_frames_equal(ref_frame5, frames[2].data)

test/decoders/test_video_decoder_ops.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
os.environ["TORCH_LOGS"] = "output_code"
1010
import json
1111
import subprocess
12-
from typing import Tuple
1312

1413
import numpy as np
1514
import pytest
@@ -48,20 +47,6 @@
4847
INDEX_OF_FRAME_AT_6_SECONDS = 180
4948

5049

51-
class ReferenceDecoder:
52-
def __init__(self, device="cpu"):
53-
self.decoder: torch.Tensor = create_from_file(str(NASA_VIDEO.path))
54-
add_video_stream(self.decoder, device=device)
55-
56-
def get_next_frame(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
57-
assert self.decoder is not None
58-
return get_next_frame(self.decoder)
59-
60-
def seek(self, pts: float):
61-
assert self.decoder is not None
62-
seek_to_pts(self.decoder, pts)
63-
64-
6550
class TestOps:
6651
@pytest.mark.parametrize("device", cpu_and_cuda())
6752
def test_seek_and_next(self, device):
@@ -352,27 +337,6 @@ def get_frame1_and_frame_time6(decoder):
352337
assert_frames_equal(frame0, reference_frame0.to(device))
353338
assert_frames_equal(frame_time6, reference_frame_time6.to(device))
354339

355-
@pytest.mark.parametrize("device", cpu_and_cuda())
356-
def test_class_based_compile_seek_and_next(self, device):
357-
# TODO_OPEN_ISSUE Scott (T180277797): Ditto as above.
358-
@torch.compile(fullgraph=True, backend="eager")
359-
def class_based_get_frame1_and_frame_time6(
360-
decoder: ReferenceDecoder,
361-
) -> Tuple[torch.Tensor, torch.Tensor]:
362-
frame0, _, _ = decoder.get_next_frame()
363-
decoder.seek(6.0)
364-
frame_time6, _, _ = decoder.get_next_frame()
365-
return frame0, frame_time6
366-
367-
decoder = ReferenceDecoder(device=device)
368-
frame0, frame_time6 = class_based_get_frame1_and_frame_time6(decoder)
369-
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
370-
reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index(
371-
INDEX_OF_FRAME_AT_6_SECONDS
372-
)
373-
assert_frames_equal(frame0, reference_frame0.to(device))
374-
assert_frames_equal(frame_time6, reference_frame_time6.to(device))
375-
376340
@pytest.mark.parametrize("device", cpu_and_cuda())
377341
@pytest.mark.parametrize("create_from", ("file", "tensor", "bytes"))
378342
def test_create_decoder(self, create_from, device):

0 commit comments

Comments
 (0)