Skip to content

Commit c0ddf78

Browse files
authored
[TorchCodec] Late-init TorchVision imports
Differential Revision: D80824234 Pull Request resolved: #842
1 parent 9672b81 commit c0ddf78

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def decode_first_n_frames(self, video_file, n):
4747
def decode_first_n_frames_description(self, n) -> str:
4848
return f"first {n} frames"
4949

50+
def init_decode_and_resize(self):
51+
pass
52+
5053
@abc.abstractmethod
5154
def decode_and_resize(self, video_file, pts_list, height, width, device):
5255
pass
@@ -107,9 +110,12 @@ def __init__(self, backend):
107110
self._backend = backend
108111
self._print_each_iteration_time = False
109112
import torchvision # noqa: F401
110-
from torchvision.transforms import v2 as transforms_v2
111113

112114
self.torchvision = torchvision
115+
116+
def init_decode_and_resize(self):
117+
from torchvision.transforms import v2 as transforms_v2
118+
113119
self.transforms_v2 = transforms_v2
114120

115121
def decode_frames(self, video_file, pts_list):
@@ -267,6 +273,7 @@ def __init__(
267273
self._color_conversion_library = color_conversion_library
268274
self._device = device
269275

276+
def init_decode_and_resize(self):
270277
from torchvision.transforms import v2 as transforms_v2
271278

272279
self.transforms_v2 = transforms_v2
@@ -379,6 +386,7 @@ def __init__(
379386
self._seek_mode = seek_mode
380387
self._stream_index = int(stream_index) if stream_index else None
381388

389+
def init_decode_and_resize(self):
382390
from torchvision.transforms import v2 as transforms_v2
383391

384392
self.transforms_v2 = transforms_v2
@@ -443,6 +451,7 @@ def __init__(
443451
self._device = device
444452
self._seek_mode = seek_mode
445453

454+
def init_decode_and_resize(self):
446455
from torchvision.transforms import v2 as transforms_v2
447456

448457
self.transforms_v2 = transforms_v2
@@ -546,10 +555,12 @@ def __init__(self, stream_index: str | None = None):
546555

547556
self.torchaudio = torchaudio
548557

558+
self._stream_index = int(stream_index) if stream_index else None
559+
560+
def init_decode_and_resize(self):
549561
from torchvision.transforms import v2 as transforms_v2
550562

551563
self.transforms_v2 = transforms_v2
552-
self._stream_index = int(stream_index) if stream_index else None
553564

554565
def decode_frames(self, video_file, pts_list):
555566
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
@@ -883,6 +894,7 @@ def run_benchmarks(
883894

884895
if dataloader_parameters:
885896
bp = dataloader_parameters.batch_parameters
897+
decoder.init_decode_and_resize()
886898
description = (
887899
f"concurrency {bp.num_threads}"
888900
f"batch {bp.batch_size}"

0 commit comments

Comments
 (0)