5555 SINE_MONO_S32 ,
5656 SINE_MONO_S32_44100 ,
5757 SINE_MONO_S32_8000 ,
58+ unsplit_device_str ,
5859)
5960
6061torch ._dynamo .config .capture_dynamic_output_shape_ops = True
@@ -66,7 +67,8 @@ class TestVideoDecoderOps:
6667 @pytest .mark .parametrize ("device" , all_supported_devices ())
6768 def test_seek_and_next (self , device ):
6869 decoder = create_from_file (str (NASA_VIDEO .path ))
69- add_video_stream (decoder , device = device )
70+ device , device_variant = unsplit_device_str (device )
71+ add_video_stream (decoder , device = device , device_variant = device_variant )
7072 frame0 , _ , _ = get_next_frame (decoder )
7173 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
7274 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -83,7 +85,8 @@ def test_seek_and_next(self, device):
8385 @pytest .mark .parametrize ("device" , all_supported_devices ())
8486 def test_seek_to_negative_pts (self , device ):
8587 decoder = create_from_file (str (NASA_VIDEO .path ))
86- add_video_stream (decoder , device = device )
88+ device , device_variant = unsplit_device_str (device )
89+ add_video_stream (decoder , device = device , device_variant = device_variant )
8790 frame0 , _ , _ = get_next_frame (decoder )
8891 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
8992 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -95,7 +98,8 @@ def test_seek_to_negative_pts(self, device):
9598 @pytest .mark .parametrize ("device" , all_supported_devices ())
9699 def test_get_frame_at_pts (self , device ):
97100 decoder = create_from_file (str (NASA_VIDEO .path ))
98- add_video_stream (decoder , device = device )
101+ device , device_variant = unsplit_device_str (device )
102+ add_video_stream (decoder , device = device , device_variant = device_variant )
99103 # This frame has pts=6.006 and duration=0.033367, so it should be visible
100104 # at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
101105 frame6 , _ , _ = get_frame_at_pts (decoder , 6.006 )
@@ -119,7 +123,8 @@ def test_get_frame_at_pts(self, device):
119123 @pytest .mark .parametrize ("device" , all_supported_devices ())
120124 def test_get_frame_at_index (self , device ):
121125 decoder = create_from_file (str (NASA_VIDEO .path ))
122- add_video_stream (decoder , device = device )
126+ device , device_variant = unsplit_device_str (device )
127+ add_video_stream (decoder , device = device , device_variant = device_variant )
123128 frame0 , _ , _ = get_frame_at_index (decoder , frame_index = 0 )
124129 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
125130 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -137,7 +142,8 @@ def test_get_frame_at_index(self, device):
137142 @pytest .mark .parametrize ("device" , all_supported_devices ())
138143 def test_get_frame_with_info_at_index (self , device ):
139144 decoder = create_from_file (str (NASA_VIDEO .path ))
140- add_video_stream (decoder , device = device )
145+ device , device_variant = unsplit_device_str (device )
146+ add_video_stream (decoder , device = device , device_variant = device_variant )
141147 frame6 , pts , duration = get_frame_at_index (decoder , frame_index = 180 )
142148 reference_frame6 = NASA_VIDEO .get_frame_data_by_index (
143149 INDEX_OF_FRAME_AT_6_SECONDS
@@ -149,7 +155,8 @@ def test_get_frame_with_info_at_index(self, device):
149155 @pytest .mark .parametrize ("device" , all_supported_devices ())
150156 def test_get_frames_at_indices (self , device ):
151157 decoder = create_from_file (str (NASA_VIDEO .path ))
152- add_video_stream (decoder , device = device )
158+ device , device_variant = unsplit_device_str (device )
159+ add_video_stream (decoder , device = device , device_variant = device_variant )
153160 frames0and180 , * _ = get_frames_at_indices (decoder , frame_indices = [0 , 180 ])
154161 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
155162 reference_frame180 = NASA_VIDEO .get_frame_data_by_index (
@@ -161,7 +168,8 @@ def test_get_frames_at_indices(self, device):
161168 @pytest .mark .parametrize ("device" , all_supported_devices ())
162169 def test_get_frames_at_indices_unsorted_indices (self , device ):
163170 decoder = create_from_file (str (NASA_VIDEO .path ))
164- _add_video_stream (decoder , device = device )
171+ device , device_variant = unsplit_device_str (device )
172+ add_video_stream (decoder , device = device , device_variant = device_variant )
165173
166174 frame_indices = [2 , 0 , 1 , 0 , 2 ]
167175
@@ -188,7 +196,8 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
188196 @pytest .mark .parametrize ("device" , all_supported_devices ())
189197 def test_get_frames_at_indices_negative_indices (self , device ):
190198 decoder = create_from_file (str (NASA_VIDEO .path ))
191- add_video_stream (decoder , device = device )
199+ device , device_variant = unsplit_device_str (device )
200+ add_video_stream (decoder , device = device , device_variant = device_variant )
192201 frames389and387and1 , * _ = get_frames_at_indices (
193202 decoder , frame_indices = [- 1 , - 3 , - 389 ]
194203 )
@@ -202,7 +211,8 @@ def test_get_frames_at_indices_negative_indices(self, device):
202211 @pytest .mark .parametrize ("device" , all_supported_devices ())
203212 def test_get_frames_at_indices_fail_on_invalid_negative_indices (self , device ):
204213 decoder = create_from_file (str (NASA_VIDEO .path ))
205- add_video_stream (decoder , device = device )
214+ device , device_variant = unsplit_device_str (device )
215+ add_video_stream (decoder , device = device , device_variant = device_variant )
206216 with pytest .raises (
207217 IndexError ,
208218 match = "negative indices must have an absolute value less than the number of frames" ,
@@ -214,7 +224,8 @@ def test_get_frames_at_indices_fail_on_invalid_negative_indices(self, device):
214224 @pytest .mark .parametrize ("device" , all_supported_devices ())
215225 def test_get_frames_by_pts (self , device ):
216226 decoder = create_from_file (str (NASA_VIDEO .path ))
217- _add_video_stream (decoder , device = device )
227+ device , device_variant = unsplit_device_str (device )
228+ add_video_stream (decoder , device = device , device_variant = device_variant )
218229
219230 # Note: 13.01 should give the last video frame for the NASA video
220231 timestamps = [2 , 0 , 1 , 0 + 1e-3 , 13.01 , 2 + 1e-3 ]
@@ -246,7 +257,8 @@ def test_pts_apis_against_index_ref(self, device):
246257 # APIs exactly where those frames are supposed to start. We assert that
247258 # we get the expected frame.
248259 decoder = create_from_file (str (NASA_VIDEO .path ))
249- add_video_stream (decoder , device = device )
260+ device , device_variant = unsplit_device_str (device )
261+ add_video_stream (decoder , device = device , device_variant = device_variant )
250262
251263 metadata = get_json_metadata (decoder )
252264 metadata_dict = json .loads (metadata )
@@ -297,7 +309,8 @@ def test_pts_apis_against_index_ref(self, device):
297309 @pytest .mark .parametrize ("device" , all_supported_devices ())
298310 def test_get_frames_in_range (self , device ):
299311 decoder = create_from_file (str (NASA_VIDEO .path ))
300- add_video_stream (decoder , device = device )
312+ device , device_variant = unsplit_device_str (device )
313+ add_video_stream (decoder , device = device , device_variant = device_variant )
301314
302315 # ensure that the degenerate case of a range of size 1 works
303316 ref_frame0 = NASA_VIDEO .get_frame_data_by_range (0 , 1 )
@@ -337,7 +350,8 @@ def test_get_frames_in_range(self, device):
337350 @pytest .mark .parametrize ("device" , all_supported_devices ())
338351 def test_throws_exception_at_eof (self , device ):
339352 decoder = create_from_file (str (NASA_VIDEO .path ))
340- add_video_stream (decoder , device = device )
353+ device , device_variant = unsplit_device_str (device )
354+ add_video_stream (decoder , device = device , device_variant = device_variant )
341355
342356 seek_to_pts (decoder , 12.979633 )
343357 last_frame , _ , _ = get_next_frame (decoder )
@@ -352,7 +366,8 @@ def test_throws_exception_at_eof(self, device):
352366 @pytest .mark .parametrize ("device" , all_supported_devices ())
353367 def test_throws_exception_if_seek_too_far (self , device ):
354368 decoder = create_from_file (str (NASA_VIDEO .path ))
355- add_video_stream (decoder , device = device )
369+ device , device_variant = unsplit_device_str (device )
370+ add_video_stream (decoder , device = device , device_variant = device_variant )
356371 # pts=12.979633 is the last frame in the video.
357372 seek_to_pts (decoder , 12.979633 + 1.0e-4 )
358373 with pytest .raises (IndexError , match = "no more frames" ):
@@ -363,9 +378,11 @@ def test_compile_seek_and_next(self, device):
363378 # TODO_OPEN_ISSUE Scott (T180277797): Get this to work with the inductor stack. Right now
364379 # compilation fails because it can't handle tensors of size unknown at
365380 # compile-time.
381+ device , device_variant = unsplit_device_str (device )
382+
366383 @torch .compile (fullgraph = True , backend = "eager" )
367384 def get_frame1_and_frame_time6 (decoder ):
368- add_video_stream (decoder , device = device )
385+ add_video_stream (decoder , device = device , device_variant = device_variant )
369386 frame0 , _ , _ = get_next_frame (decoder )
370387 seek_to_pts (decoder , 6.0 )
371388 frame_time6 , _ , _ = get_next_frame (decoder )
@@ -408,7 +425,8 @@ def test_create_decoder(self, create_from, device):
408425 else :
409426 raise ValueError ("Oops, double check the parametrization of this test!" )
410427
411- add_video_stream (decoder , device = device )
428+ device , device_variant = unsplit_device_str (device )
429+ add_video_stream (decoder , device = device , device_variant = device_variant )
412430 frame0 , _ , _ = get_next_frame (decoder )
413431 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
414432 assert_frames_equal (frame0 , reference_frame0 .to (device ))
@@ -536,9 +554,11 @@ def test_seek_mode_custom_frame_mappings(self, device):
536554 decoder = create_from_file (
537555 str (NASA_VIDEO .path ), seek_mode = "custom_frame_mappings"
538556 )
557+ device , device_variant = unsplit_device_str (device )
539558 add_video_stream (
540559 decoder ,
541560 device = device ,
561+ device_variant = device_variant ,
542562 stream_index = stream_index ,
543563 custom_frame_mappings = NASA_VIDEO .get_custom_frame_mappings (
544564 stream_index = stream_index
@@ -1067,7 +1087,8 @@ def seek(self, offset: int, whence: int) -> int:
10671087 open (NASA_VIDEO .path , mode = "rb" , buffering = buffering )
10681088 )
10691089 decoder = create_from_file_like (file_counter , "approximate" )
1070- add_video_stream (decoder , device = device )
1090+ device , device_variant = unsplit_device_str (device )
1091+ add_video_stream (decoder , device = device , device_variant = device_variant )
10711092
10721093 frame0 , * _ = get_next_frame (decoder )
10731094 reference_frame0 = NASA_VIDEO .get_frame_data_by_index (0 )
0 commit comments