@@ -646,13 +646,18 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
646646
647647 @pytest .mark .parametrize ("device" , all_supported_devices ())
648648 @pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
649- def test_get_frames_played_at (self , device , seek_mode ):
649+ @pytest .mark .parametrize ("input_type" , ("list" , "tensor" ))
650+ def test_get_frames_played_at (self , device , seek_mode , input_type ):
650651 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
651652 device , _ = unsplit_device_str (device )
652653
653654 # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
654655 # index 35. We use those indices as reference to test against.
655- seconds = [0.84 , 1.17 , 0.85 ]
656+ if input_type == "list" :
657+ seconds = [0.84 , 1.17 , 0.85 ]
658+ else : # tensor
659+ seconds = torch .tensor ([0.84 , 1.17 , 0.85 ])
660+
656661 reference_indices = [25 , 35 , 25 ]
657662 frames = decoder .get_frames_played_at (seconds )
658663
@@ -694,7 +699,9 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
694699 with pytest .raises (RuntimeError , match = "must be less than" ):
695700 decoder .get_frames_played_at ([14 ])
696701
697- with pytest .raises (RuntimeError , match = "Expected a value of type" ):
702+ with pytest .raises (
703+ ValueError , match = "Couldn't convert timestamps input to a tensor"
704+ ):
698705 decoder .get_frames_played_at (["bad" ])
699706
700707 @pytest .mark .parametrize ("device" , all_supported_devices ())
0 commit comments