@@ -47,6 +47,9 @@ def decode_first_n_frames(self, video_file, n):
def decode_first_n_frames_description(self, n) -> str:
    """Return the human-readable label for the "first n frames" benchmark.

    Args:
        n: Number of leading frames the benchmark decodes; interpolated
            into the label as-is.

    Returns:
        The string ``"first {n} frames"``.
    """
    return f"first {n} frames"
4949
def init_decode_and_resize(self):
    """Prepare the decoder for ``decode_and_resize``; no-op by default.

    Decoders that need extra setup for resizing (for example importing
    an optional transforms module) override this hook. The base
    implementation intentionally does nothing, so decoders without such
    setup can be driven through the same call sequence.
    """
52+
5053 @abc .abstractmethod
5154 def decode_and_resize (self , video_file , pts_list , height , width , device ):
5255 pass
@@ -107,9 +110,12 @@ def __init__(self, backend):
107110 self ._backend = backend
108111 self ._print_each_iteration_time = False
109112 import torchvision # noqa: F401
110- from torchvision .transforms import v2 as transforms_v2
111113
112114 self .torchvision = torchvision
115+
def init_decode_and_resize(self):
    """Import ``torchvision.transforms.v2`` and keep it on the instance.

    The import is performed here rather than in ``__init__`` so that the
    transforms module is only loaded when this method is called. The
    module is stored as ``self.transforms_v2``.
    """
    # Function-scope import: deferred until this method is called.
    from torchvision.transforms import v2 as transforms_v2

    self.transforms_v2 = transforms_v2
114120
115121 def decode_frames (self , video_file , pts_list ):
@@ -267,6 +273,7 @@ def __init__(
267273 self ._color_conversion_library = color_conversion_library
268274 self ._device = device
269275
276+ def init_decode_and_resize (self ):
270277 from torchvision .transforms import v2 as transforms_v2
271278
272279 self .transforms_v2 = transforms_v2
@@ -379,6 +386,7 @@ def __init__(
379386 self ._seek_mode = seek_mode
380387 self ._stream_index = int (stream_index ) if stream_index else None
381388
389+ def init_decode_and_resize (self ):
382390 from torchvision .transforms import v2 as transforms_v2
383391
384392 self .transforms_v2 = transforms_v2
@@ -443,6 +451,7 @@ def __init__(
443451 self ._device = device
444452 self ._seek_mode = seek_mode
445453
454+ def init_decode_and_resize (self ):
446455 from torchvision .transforms import v2 as transforms_v2
447456
448457 self .transforms_v2 = transforms_v2
@@ -546,10 +555,12 @@ def __init__(self, stream_index: str | None = None):
546555
547556 self .torchaudio = torchaudio
548557
558+ self ._stream_index = int (stream_index ) if stream_index else None
559+
def init_decode_and_resize(self):
    """Import ``torchvision.transforms.v2`` and keep it on the instance.

    The import is performed here rather than in ``__init__`` so that the
    transforms module is only loaded when this method is called. The
    module is stored as ``self.transforms_v2``.
    """
    # Function-scope import: deferred until this method is called.
    from torchvision.transforms import v2 as transforms_v2

    self.transforms_v2 = transforms_v2
552- self ._stream_index = int (stream_index ) if stream_index else None
553564
554565 def decode_frames (self , video_file , pts_list ):
555566 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
@@ -883,6 +894,7 @@ def run_benchmarks(
883894
884895 if dataloader_parameters :
885896 bp = dataloader_parameters .batch_parameters
897+ decoder .init_decode_and_resize ()
886898 description = (
887899 f"concurrency { bp .num_threads } "
888900 f"batch { bp .batch_size } "
0 commit comments