@@ -13,110 +13,110 @@ final class LogitsWarperTests: XCTestCase {
1313
1414 func testTemperatureLogitsWarper( ) {
1515 let result1 = TemperatureLogitsWarper ( temperature: 0.0 ) ( [ ] , [ ] )
16- XCTAssertTrue ( result1. indexes . isEmpty)
16+ XCTAssertTrue ( result1. indices . isEmpty)
1717 XCTAssertTrue ( result1. logits. isEmpty)
1818
1919 let result2 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ ] , [ ] )
20- XCTAssertTrue ( result2. indexes . isEmpty)
20+ XCTAssertTrue ( result2. indices . isEmpty)
2121 XCTAssertTrue ( result2. logits. isEmpty)
2222
2323 let result3 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
24- XCTAssertEqual ( result3. indexes , [ 0 , 1 ] )
24+ XCTAssertEqual ( result3. indices , [ 0 , 1 ] )
2525 XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
2626
2727 let result4 = TemperatureLogitsWarper ( temperature: 2.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
28- XCTAssertEqual ( result4. indexes , [ 0 , 1 ] )
28+ XCTAssertEqual ( result4. indices , [ 0 , 1 ] )
2929 XCTAssertEqual ( result4. logits, [ 1.0 , 0.5 ] , accuracy: accuracy)
3030
3131 let result5 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
32- XCTAssertEqual ( result5. indexes , [ 0 , 1 ] )
32+ XCTAssertEqual ( result5. indices , [ 0 , 1 ] )
3333 XCTAssertEqual ( result5. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
3434
3535 let result6 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 200 , 100 ] , [ 2.0 , 1.0 ] )
36- XCTAssertEqual ( result6. indexes , [ 200 , 100 ] )
36+ XCTAssertEqual ( result6. indices , [ 200 , 100 ] )
3737 XCTAssertEqual ( result6. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
3838 }
3939
4040 func testTopKLogitsWarper( ) {
4141 let result1 = TopKLogitsWarper ( k: 0 ) ( [ ] , [ ] )
42- XCTAssertTrue ( result1. indexes . isEmpty)
42+ XCTAssertTrue ( result1. indices . isEmpty)
4343 XCTAssertTrue ( result1. logits. isEmpty)
4444
4545 let result2 = TopKLogitsWarper ( k: 3 ) ( [ ] , [ ] )
46- XCTAssertTrue ( result2. indexes . isEmpty)
46+ XCTAssertTrue ( result2. indices . isEmpty)
4747 XCTAssertTrue ( result2. logits. isEmpty)
4848
4949 let result3 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
50- XCTAssertEqual ( result3. indexes , [ 0 , 1 ] )
50+ XCTAssertEqual ( result3. indices , [ 0 , 1 ] )
5151 XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
5252
5353 let result4 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 , 2 ] , [ 2.0 , 1.0 , 3.0 ] )
54- XCTAssertEqual ( result4. indexes , [ 2 , 0 , 1 ] )
54+ XCTAssertEqual ( result4. indices , [ 2 , 0 , 1 ] )
5555 XCTAssertEqual ( result4. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
5656
5757 let result5 = TopKLogitsWarper ( k: 4 ) ( [ 0 , 1 , 2 , 3 , 4 , 5 ] , [ 2.0 , 1.0 , 3.0 , - 1.0 , 123.0 , 0.0 ] )
58- XCTAssertEqual ( result5. indexes , [ 4 , 2 , 0 , 1 ] )
58+ XCTAssertEqual ( result5. indices , [ 4 , 2 , 0 , 1 ] )
5959 XCTAssertEqual ( result5. logits, [ 123.0 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
6060
6161 let result6 = TopKLogitsWarper ( k: 3 ) ( [ 10 , 1 , 52 ] , [ 2.0 , 1.0 , 3.0 ] )
62- XCTAssertEqual ( result6. indexes , [ 52 , 10 , 1 ] )
62+ XCTAssertEqual ( result6. indices , [ 52 , 10 , 1 ] )
6363 XCTAssertEqual ( result6. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
6464 }
6565
6666 func testTopPLogitsWarper( ) {
6767 let result1 = TopPLogitsWarper ( p: 0.99 ) ( [ ] , [ ] )
68- XCTAssertTrue ( result1. indexes . isEmpty)
68+ XCTAssertTrue ( result1. indices . isEmpty)
6969 XCTAssertTrue ( result1. logits. isEmpty)
7070
7171 let logits = ( 0 ..< 10 ) . map { Float ( $0) }
72- let indexes = Array ( logits. indices)
73- let result2 = TopPLogitsWarper ( p: 0.99 ) ( indexes , logits)
74- XCTAssertEqual ( result2. indexes , [ 9 , 8 , 7 , 6 , 5 ] )
72+ let indices = Array ( logits. indices)
73+ let result2 = TopPLogitsWarper ( p: 0.99 ) ( indices , logits)
74+ XCTAssertEqual ( result2. indices , [ 9 , 8 , 7 , 6 , 5 ] )
7575 XCTAssertEqual ( result2. logits, [ 9.0 , 8.0 , 7.0 , 6.0 , 5.0 ] , accuracy: accuracy)
7676
77- let result3 = TopPLogitsWarper ( p: 0.95 ) ( indexes , logits)
78- XCTAssertEqual ( result3. indexes , [ 9 , 8 , 7 ] )
77+ let result3 = TopPLogitsWarper ( p: 0.95 ) ( indices , logits)
78+ XCTAssertEqual ( result3. indices , [ 9 , 8 , 7 ] )
7979 XCTAssertEqual ( result3. logits, [ 9.0 , 8.0 , 7.0 ] , accuracy: accuracy)
8080
81- let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( indexes , logits)
82- XCTAssertEqual ( result4. indexes , [ 9 , 8 ] )
81+ let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( indices , logits)
82+ XCTAssertEqual ( result4. indices , [ 9 , 8 ] )
8383 XCTAssertEqual ( result4. logits, [ 9.0 , 8.0 ] , accuracy: accuracy)
8484
8585 let result5 = TopPLogitsWarper ( p: 0.95 ) ( [ 3 , 1 , 8 ] , [ 0 , 1 , 2 ] )
86- XCTAssertEqual ( result5. indexes , [ 8 , 1 , 3 ] )
86+ XCTAssertEqual ( result5. indices , [ 8 , 1 , 3 ] )
8787 XCTAssertEqual ( result5. logits, [ 2 , 1 , 0 ] , accuracy: accuracy)
8888 }
8989
9090 func testLogitsProcessor( ) {
9191 let processor1 = LogitsProcessor ( logitsWarpers: [ ] )
9292 let result1 = processor1 ( [ ] )
93- XCTAssertTrue ( result1. indexes . isEmpty)
93+ XCTAssertTrue ( result1. indices . isEmpty)
9494 XCTAssertTrue ( result1. logits. isEmpty)
9595
9696 let processor2 = LogitsProcessor ( logitsWarpers: [ ] )
9797 let result2 = processor2 ( [ 2.0 , 1.0 ] )
98- XCTAssertEqual ( result2. indexes , [ 0 , 1 ] )
98+ XCTAssertEqual ( result2. indices , [ 0 , 1 ] )
9999 XCTAssertEqual ( result2. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
100100
101101 let processor3 = LogitsProcessor (
102102 logitsWarpers: [ TopKLogitsWarper ( k: 3 ) ]
103103 )
104104 let result3 = processor3 ( [ 2.0 , 1.0 , 3.0 , - 5.0 ] )
105- XCTAssertEqual ( result3. indexes , [ 2 , 0 , 1 ] )
105+ XCTAssertEqual ( result3. indices , [ 2 , 0 , 1 ] )
106106 XCTAssertEqual ( result3. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
107107
108108 let processor4 = LogitsProcessor (
109109 logitsWarpers: [ TopKLogitsWarper ( k: 3 ) , TopPLogitsWarper ( p: 0.99 ) ]
110110 )
111111 let result4 = processor4 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 23.0 , 12.5 ] )
112- XCTAssertEqual ( result4. indexes , [ 5 ] )
112+ XCTAssertEqual ( result4. indices , [ 5 ] )
113113 XCTAssertEqual ( result4. logits, [ 12.5 ] , accuracy: accuracy)
114114
115115 let processor5 = LogitsProcessor (
116116 logitsWarpers: [ TopKLogitsWarper ( k: 4 ) , TopPLogitsWarper ( p: 0.99 ) ]
117117 )
118118 let result5 = processor5 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 3.0 , 4.5 ] )
119- XCTAssertEqual ( result5. indexes , [ 5 , 2 , 0 , 1 ] )
119+ XCTAssertEqual ( result5. indices , [ 5 , 2 , 0 , 1 ] )
120120 XCTAssertEqual ( result5. logits, [ 4.5 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
121121 }
122122}
0 commit comments