@@ -43,15 +43,15 @@ def test_slice_without_cu_num_generated_tokens(self):
4343 cu_num_generated_tokens = None ,
4444 )
4545
46- sliced = logprobsLists .slice (1 , 3 )
46+ sliced = logprobsLists .slice_request (1 , num_positions = 2 )
4747 assert sliced .logprob_token_ids == [[2 ], [3 ]]
4848 assert sliced .logprobs == [[0.2 ], [0.3 ]]
4949 assert sliced .sampled_token_ranks == [2 , 3 ]
5050 assert sliced .cu_num_generated_tokens is None
5151
5252 def test_slice_from_start (self ):
5353 """Test slicing from the start position"""
54- sliced = self .logprobsLists .slice (0 , 2 )
54+ sliced = self .logprobsLists .slice_request (0 , num_positions = 5 )
5555 assert len (sliced .logprob_token_ids ) == 5
5656 assert sliced .logprob_token_ids == [
5757 [1 , 2 ],
@@ -60,11 +60,11 @@ def test_slice_from_start(self):
6060 [7 , 8 ],
6161 [9 , 10 ],
6262 ]
63- assert sliced .cu_num_generated_tokens == [ 0 , 2 , 5 ]
63+ assert sliced .cu_num_generated_tokens is None
6464
6565 def test_slice_from_middle (self ):
6666 """Test slicing from the middle position"""
67- sliced = self .logprobsLists .slice (1 , 3 )
67+ sliced = self .logprobsLists .slice_request (1 , num_positions = 7 )
6868 assert len (sliced .logprob_token_ids ) == 7
6969 assert sliced .logprob_token_ids == [
7070 [5 , 6 ],
@@ -75,27 +75,25 @@ def test_slice_from_middle(self):
7575 [15 , 16 ],
7676 [17 , 18 ],
7777 ]
78- assert sliced .cu_num_generated_tokens == [ 0 , 3 , 7 ]
78+ assert sliced .cu_num_generated_tokens is None
7979
8080 def test_slice_single_request (self ):
8181 """Test slicing a single request"""
82- sliced = self .logprobsLists .slice (1 , 2 )
82+ sliced = self .logprobsLists .slice_request (1 , num_positions = 3 )
8383 assert len (sliced .logprob_token_ids ) == 3
8484 assert sliced .logprob_token_ids == [[5 , 6 ], [7 , 8 ], [9 , 10 ]]
85- assert sliced .cu_num_generated_tokens == [ 0 , 3 ]
85+ assert sliced .cu_num_generated_tokens is None
8686
8787 def test_slice_last_request (self ):
8888 """Test slicing the last request"""
89- sliced = self .logprobsLists .slice (2 , 3 )
89+ sliced = self .logprobsLists .slice_request (2 , num_positions = 4 )
9090 assert len (sliced .logprob_token_ids ) == 4
9191 assert sliced .logprob_token_ids == [[11 , 12 ], [13 , 14 ], [15 , 16 ], [17 , 18 ]]
92- assert sliced .cu_num_generated_tokens == [ 0 , 4 ]
92+ assert sliced .cu_num_generated_tokens is None
9393
9494 def test_slice_all_requests (self ):
9595 """Test slicing all requests (full slice)"""
96- sliced = self .logprobsLists .slice (0 , 3 )
96+ sliced = self .logprobsLists .slice_request (0 , num_positions = 9 )
9797 assert len (sliced .logprob_token_ids ) == 9 # All tokens
9898 assert sliced .logprob_token_ids == self .logprobsLists .logprob_token_ids
99- assert (
100- sliced .cu_num_generated_tokens == self .logprobsLists .cu_num_generated_tokens
101- )
99+ assert sliced .cu_num_generated_tokens is None
0 commit comments