5555 SINE_MONO_S32 ,
5656 SINE_MONO_S32_44100 ,
5757 SINE_MONO_S32_8000 ,
58+ unsplit_device_str ,
5859)
5960
6061torch ._dynamo .config .capture_dynamic_output_shape_ops = True
@@ -66,7 +67,8 @@ class TestVideoDecoderOps:
6667 @pytest .mark .parametrize ("device" , all_supported_devices ())
6768 def test_seek_and_next (self , device ):
6869 decoder = create_from_file (str (NASA_VIDEO .path ))
69- add_video_stream (decoder , device = device )
70+ device , device_variant = unsplit_device_str (device )
71+ add_video_stream (decoder , device = device , device_variant = device_variant )
7072 frame0 , _ , _ = get_next_frame (decoder )
7173 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
7274 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -83,7 +85,8 @@ def test_seek_and_next(self, device):
8385 @pytest .mark .parametrize ("device" , all_supported_devices ())
8486 def test_seek_to_negative_pts (self , device ):
8587 decoder = create_from_file (str (NASA_VIDEO .path ))
86- add_video_stream (decoder , device = device )
88+ device , device_variant = unsplit_device_str (device )
89+ add_video_stream (decoder , device = device , device_variant = device_variant )
8790 frame0 , _ , _ = get_next_frame (decoder )
8891 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
8992 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -95,7 +98,8 @@ def test_seek_to_negative_pts(self, device):
9598 @pytest .mark .parametrize ("device" , all_supported_devices ())
9699 def test_get_frame_at_pts (self , device ):
97100 decoder = create_from_file (str (NASA_VIDEO .path ))
98- add_video_stream (decoder , device = device )
101+ device , device_variant = unsplit_device_str (device )
102+ add_video_stream (decoder , device = device , device_variant = device_variant )
99103 # This frame has pts=6.006 and duration=0.033367, so it should be visible
100104 # at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
101105 frame6 , _ , _ = get_frame_at_pts (decoder , 6.006 )
@@ -119,7 +123,8 @@ def test_get_frame_at_pts(self, device):
119123 @pytest .mark .parametrize ("device" , all_supported_devices ())
120124 def test_get_frame_at_index (self , device ):
121125 decoder = create_from_file (str (NASA_VIDEO .path ))
122- add_video_stream (decoder , device = device )
126+ device , device_variant = unsplit_device_str (device )
127+ add_video_stream (decoder , device = device , device_variant = device_variant )
123128 frame0 , _ , _ = get_frame_at_index (decoder , frame_index = 0 )
124129 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
125130 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -137,7 +142,8 @@ def test_get_frame_at_index(self, device):
137142 @pytest .mark .parametrize ("device" , all_supported_devices ())
138143 def test_get_frame_with_info_at_index (self , device ):
139144 decoder = create_from_file (str (NASA_VIDEO .path ))
140- add_video_stream (decoder , device = device )
145+ device , device_variant = unsplit_device_str (device )
146+ add_video_stream (decoder , device = device , device_variant = device_variant )
141147 frame6 , pts , duration = get_frame_at_index (decoder , frame_index = 180 )
142148 reference_frame6 = NASA_VIDEO .get_frame_data_by_index (
143149 INDEX_OF_FRAME_AT_6_SECONDS
@@ -149,7 +155,8 @@ def test_get_frame_with_info_at_index(self, device):
149155 @pytest .mark .parametrize ("device" , all_supported_devices ())
150156 def test_get_frames_at_indices (self , device ):
151157 decoder = create_from_file (str (NASA_VIDEO .path ))
152- add_video_stream (decoder , device = device )
158+ device , device_variant = unsplit_device_str (device )
159+ add_video_stream (decoder , device = device , device_variant = device_variant )
153160 frames0and180 , * _ = get_frames_at_indices (decoder , frame_indices = [0 , 180 ])
154161 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
155162 reference_frame180 = NASA_VIDEO .get_frame_data_by_index (
@@ -161,7 +168,8 @@ def test_get_frames_at_indices(self, device):
161168 @pytest .mark .parametrize ("device" , all_supported_devices ())
162169 def test_get_frames_at_indices_unsorted_indices (self , device ):
163170 decoder = create_from_file (str (NASA_VIDEO .path ))
164- _add_video_stream (decoder , device = device )
171+ device , device_variant = unsplit_device_str (device )
172+ add_video_stream (decoder , device = device , device_variant = device_variant )
165173
166174 frame_indices = [2 , 0 , 1 , 0 , 2 ]
167175
@@ -188,7 +196,8 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
188196 @pytest .mark .parametrize ("device" , all_supported_devices ())
189197 def test_get_frames_at_indices_negative_indices (self , device ):
190198 decoder = create_from_file (str (NASA_VIDEO .path ))
191- add_video_stream (decoder , device = device )
199+ device , device_variant = unsplit_device_str (device )
200+ add_video_stream (decoder , device = device , device_variant = device_variant )
192201 frames389and387and1 , * _ = get_frames_at_indices (
193202 decoder , frame_indices = [- 1 , - 3 , - 389 ]
194203 )
@@ -202,7 +211,8 @@ def test_get_frames_at_indices_negative_indices(self, device):
202211 @pytest .mark .parametrize ("device" , all_supported_devices ())
203212 def test_get_frames_at_indices_fail_on_invalid_negative_indices (self , device ):
204213 decoder = create_from_file (str (NASA_VIDEO .path ))
205- add_video_stream (decoder , device = device )
214+ device , device_variant = unsplit_device_str (device )
215+ add_video_stream (decoder , device = device , device_variant = device_variant )
206216 with pytest .raises (
207217 IndexError ,
208218 match = "negative indices must have an absolute value less than the number of frames" ,
@@ -214,7 +224,8 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
214224 @pytest .mark .parametrize ("device" , all_supported_devices ())
215225 def test_get_frames_by_pts (self , device ):
216226 decoder = create_from_file (str (NASA_VIDEO .path ))
217- _add_video_stream (decoder , device = device )
227+ device , device_variant = unsplit_device_str (device )
228+ add_video_stream (decoder , device = device , device_variant = device_variant )
218229
219230 # Note: 13.01 should give the last video frame for the NASA video
220231 timestamps = [2 , 0 , 1 , 0 + 1e-3 , 13.01 , 2 + 1e-3 ]
@@ -246,7 +257,8 @@ def test_pts_apis_against_index_ref(self, device):
246257 # APIs exactly where those frames are supposed to start. We assert that
247258 # we get the expected frame.
248259 decoder = create_from_file (str (NASA_VIDEO .path ))
249- add_video_stream (decoder , device = device )
260+ device , device_variant = unsplit_device_str (device )
261+ add_video_stream (decoder , device = device , device_variant = device_variant )
250262
251263 metadata = get_json_metadata (decoder )
252264 metadata_dict = json .loads (metadata )
@@ -297,7 +309,8 @@ def test_pts_apis_against_index_ref(self, device):
297309 @pytest .mark .parametrize ("device" , all_supported_devices ())
298310 def test_get_frames_in_range (self , device ):
299311 decoder = create_from_file (str (NASA_VIDEO .path ))
300- add_video_stream (decoder , device = device )
312+ device , device_variant = unsplit_device_str (device )
313+ add_video_stream (decoder , device = device , device_variant = device_variant )
301314
302315 # ensure that the degenerate case of a range of size 1 works
303316 ref_frame0 = NASA_VIDEO .get_frame_data_by_range (0 , 1 )
@@ -336,8 +349,11 @@ def test_get_frames_in_range(self, device):
336349
337350 @pytest .mark .parametrize ("device" , all_supported_devices ())
338351 def test_throws_exception_at_eof (self , device ):
352+ if device == "cuda:0:beta" :
353+ pytest .skip ("TODONVDEC P0: this hangs forever, fix this!!" )
339354 decoder = create_from_file (str (NASA_VIDEO .path ))
340- add_video_stream (decoder , device = device )
355+ device , device_variant = unsplit_device_str (device )
356+ add_video_stream (decoder , device = device , device_variant = device_variant )
341357
342358 seek_to_pts (decoder , 12.979633 )
343359 last_frame , _ , _ = get_next_frame (decoder )
@@ -351,8 +367,11 @@ def test_throws_exception_at_eof(self, device):
351367
352368 @pytest .mark .parametrize ("device" , all_supported_devices ())
353369 def test_throws_exception_if_seek_too_far (self , device ):
370+ if device == "cuda:0:beta" :
371+ pytest .skip ("TODONVDEC P0: this hangs forever, fix this!!" )
354372 decoder = create_from_file (str (NASA_VIDEO .path ))
355- add_video_stream (decoder , device = device )
373+ device , device_variant = unsplit_device_str (device )
374+ add_video_stream (decoder , device = device , device_variant = device_variant )
356375 # pts=12.979633 is the last frame in the video.
357376 seek_to_pts (decoder , 12.979633 + 1.0e-4 )
358377 with pytest .raises (IndexError , match = "no more frames" ):
@@ -363,9 +382,11 @@ def test_compile_seek_and_next(self, device):
363382 # TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now
364383 # compilation fails because it can't handle tensors of size unknown at
365384 # compile-time.
385+ device , device_variant = unsplit_device_str (device )
386+
366387 @torch .compile (fullgraph = True , backend = "eager" )
367388 def get_frame1_and_frame_time6 (decoder ):
368- add_video_stream (decoder , device = device )
389+ add_video_stream (decoder , device = device , device_variant = device_variant )
369390 frame0 , _ , _ = get_next_frame (decoder )
370391 seek_to_pts (decoder , 6.0 )
371392 frame_time6 , _ , _ = get_next_frame (decoder )
@@ -408,7 +429,8 @@ def test_create_decoder(self, create_from, device):
408429 else :
409430 raise ValueError ("Oops, double check the parametrization of this test!" )
410431
411- add_video_stream (decoder , device = device )
432+ device , device_variant = unsplit_device_str (device )
433+ add_video_stream (decoder , device = device , device_variant = device_variant )
412434 frame0 , _ , _ = get_next_frame (decoder )
413435 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
414436 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -536,9 +558,11 @@ def test_seek_mode_custom_frame_mappings(self, device):
536558 decoder = create_from_file (
537559 str (NASA_VIDEO .path ), seek_mode = "custom_frame_mappings"
538560 )
561+ device , device_variant = unsplit_device_str (device )
539562 add_video_stream (
540563 decoder ,
541564 device = device ,
565+ device_variant = device_variant ,
542566 stream_index = stream_index ,
543567 custom_frame_mappings = NASA_VIDEO .get_custom_frame_mappings (
544568 stream_index = stream_index
@@ -1067,7 +1091,8 @@ def seek(self, offset: int, whence: int) -> int:
10671091 open (NASA_VIDEO .path , mode = "rb" , buffering = buffering )
10681092 )
10691093 decoder = create_from_file_like (file_counter , "approximate" )
1070- add_video_stream (decoder , device = device )
1094+ device , device_variant = unsplit_device_str (device )
1095+ add_video_stream (decoder , device = device , device_variant = device_variant )
10711096
10721097 frame0 , * _ = get_next_frame (decoder )
10731098 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
0 commit comments