@@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode):
328328 )
329329 assert_frames_equal (ref386_389 , slice386_389 )
330330
331+ # slices with upper bound greater than len(decoder) are supported
332+ slice387_389 = decoder [- 3 :10000 ].to (device )
333+ assert slice387_389 .shape == torch .Size (
334+ [
335+ 3 ,
336+ NASA_VIDEO .num_color_channels ,
337+ NASA_VIDEO .height ,
338+ NASA_VIDEO .width ,
339+ ]
340+ )
341+ ref387_389 = NASA_VIDEO .get_frame_data_by_range (387 , 390 ).to (device )
342+ assert_frames_equal (ref387_389 , slice387_389 )
343+
331344 # an empty range is valid!
332345 empty_frame = decoder [5 :5 ]
333346 assert_frames_equal (empty_frame , NASA_VIDEO .empty_chw_tensor .to (device ))
@@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode):
437450 expected_frame_info .duration_seconds , rel = 1e-3
438451 )
439452
453+ # test negative frame index
454+ frame_minus1 = decoder .get_frame_at (- 1 )
455+ ref_frame_minus1 = NASA_VIDEO .get_frame_data_by_index (389 ).to (device )
456+ assert_frames_equal (ref_frame_minus1 , frame_minus1 .data )
457+
440458 # test numpy.int64
441459 frame9 = decoder .get_frame_at (numpy .int64 (9 ))
442460 assert_frames_equal (ref_frame9 , frame9 .data )
@@ -470,7 +488,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
470488 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
471489
472490 with pytest .raises (IndexError , match = "out of bounds" ):
473- frame = decoder .get_frame_at (- 1 ) # noqa
491+ frame = decoder .get_frame_at (- 10000 ) # noqa
474492
475493 with pytest .raises (IndexError , match = "out of bounds" ):
476494 frame = decoder .get_frame_at (10000 ) # noqa
@@ -480,7 +498,8 @@ def test_get_frame_at_fails(self, device, seek_mode):
480498 def test_get_frames_at (self , device , seek_mode ):
481499 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
482500
483- frames = decoder .get_frames_at ([35 , 25 ])
501+ # test positive and negative frame index
502+ frames = decoder .get_frames_at ([35 , 25 , - 1 , - 2 ])
484503
485504 assert isinstance (frames , FrameBatch )
486505
@@ -490,12 +509,20 @@ def test_get_frames_at(self, device, seek_mode):
490509 assert_frames_equal (
491510 frames [1 ].data , NASA_VIDEO .get_frame_data_by_index (25 ).to (device )
492511 )
512+ assert_frames_equal (
513+ frames [2 ].data , NASA_VIDEO .get_frame_data_by_index (389 ).to (device )
514+ )
515+ assert_frames_equal (
516+ frames [3 ].data , NASA_VIDEO .get_frame_data_by_index (388 ).to (device )
517+ )
493518
494519 assert frames .pts_seconds .device .type == "cpu"
495520 expected_pts_seconds = torch .tensor (
496521 [
497522 NASA_VIDEO .get_frame_info (35 ).pts_seconds ,
498523 NASA_VIDEO .get_frame_info (25 ).pts_seconds ,
524+ NASA_VIDEO .get_frame_info (389 ).pts_seconds ,
525+ NASA_VIDEO .get_frame_info (388 ).pts_seconds ,
499526 ],
500527 dtype = torch .float64 ,
501528 )
@@ -508,6 +535,8 @@ def test_get_frames_at(self, device, seek_mode):
508535 [
509536 NASA_VIDEO .get_frame_info (35 ).duration_seconds ,
510537 NASA_VIDEO .get_frame_info (25 ).duration_seconds ,
538+ NASA_VIDEO .get_frame_info (389 ).duration_seconds ,
539+ NASA_VIDEO .get_frame_info (388 ).duration_seconds ,
511540 ],
512541 dtype = torch .float64 ,
513542 )
@@ -520,8 +549,11 @@ def test_get_frames_at(self, device, seek_mode):
520549 def test_get_frames_at_fails (self , device , seek_mode ):
521550 decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
522551
523- with pytest .raises (RuntimeError , match = "Invalid frame index=-1" ):
524- decoder .get_frames_at ([- 1 ])
552+ expected_converted_index = - 10000 + len (decoder )
553+ with pytest .raises (
554+ RuntimeError , match = f"Invalid frame index={ expected_converted_index } "
555+ ):
556+ decoder .get_frames_at ([- 10000 ])
525557
526558 with pytest .raises (RuntimeError , match = "Invalid frame index=390" ):
527559 decoder .get_frames_at ([390 ])
0 commit comments