Skip to content

Commit b597922

Browse files
mollyxuMolly Xu
andauthored
Suport device=None (#1025)
Co-authored-by: Molly Xu <mollyxu@fb.com>
1 parent 8e615e3 commit b597922

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ class VideoDecoder:
5555
decoding which is best if you are running a single instance of ``VideoDecoder``.
5656
Passing 0 lets FFmpeg decide on the number of threads.
5757
Default: 1.
58-
device (str or torch.device, optional): The device to use for decoding. Default: "cpu".
58+
device (str or torch.device, optional): The device to use for decoding.
59+
If ``None`` (default), uses the current default device.
5960
If you pass a CUDA device, we recommend trying the "beta" CUDA
6061
backend which is faster! See :func:`~torchcodec.decoders.set_cuda_backend`.
6162
seek_mode (str, optional): Determines if frame access will be "exact" or
@@ -102,7 +103,7 @@ def __init__(
102103
stream_index: Optional[int] = None,
103104
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
104105
num_ffmpeg_threads: int = 1,
105-
device: Optional[Union[str, torch_device]] = "cpu",
106+
device: Optional[Union[str, torch_device]] = None,
106107
seek_mode: Literal["exact", "approximate"] = "exact",
107108
custom_frame_mappings: Optional[
108109
Union[str, bytes, io.RawIOBase, io.BufferedReader]
@@ -143,7 +144,9 @@ def __init__(
143144
if num_ffmpeg_threads is None:
144145
raise ValueError(f"{num_ffmpeg_threads = } should be an int.")
145146

146-
if isinstance(device, torch_device):
147+
if device is None:
148+
device = str(torch.get_default_device())
149+
elif isinstance(device, torch_device):
147150
device = str(device)
148151

149152
device_variant = _get_cuda_backend()

test/test_decoders.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,31 @@ def test_device_instance(self):
388388
decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu"))
389389
assert isinstance(decoder.metadata, VideoStreamMetadata)
390390

391+
@pytest.mark.parametrize(
392+
"device_str",
393+
[
394+
"cpu",
395+
pytest.param("cuda", marks=pytest.mark.needs_cuda),
396+
],
397+
)
398+
def test_device_none_default_device(self, device_str):
399+
# VideoDecoder defaults to device=None, which should respect both
400+
# torch.device() context manager and torch.set_default_device().
401+
402+
# Test with context manager
403+
with torch.device(device_str):
404+
decoder = VideoDecoder(NASA_VIDEO.path)
405+
assert decoder[0].device.type == device_str
406+
407+
# Test with set_default_device
408+
original_device = torch.get_default_device()
409+
try:
410+
torch.set_default_device(device_str)
411+
decoder = VideoDecoder(NASA_VIDEO.path)
412+
assert decoder[0].device.type == device_str
413+
finally:
414+
torch.set_default_device(original_device)
415+
391416
@pytest.mark.parametrize("device", all_supported_devices())
392417
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
393418
def test_getitem_fails(self, device, seek_mode):

0 commit comments

Comments
 (0)