@@ -12,65 +12,79 @@ final class LogitsWarperTests: XCTestCase {
1212 private let accuracy : Float = 0.00001
1313
1414 func testTemperatureLogitsWarper( ) {
15- let result1 = TemperatureLogitsWarper ( temperature: 0.0 ) ( [ ] )
15+ let result1 = TemperatureLogitsWarper ( temperature: 0.0 ) ( [ ] , [ ] )
1616 XCTAssertTrue ( result1. indexes. isEmpty)
1717 XCTAssertTrue ( result1. logits. isEmpty)
1818
19- let result2 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ ] )
19+ let result2 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ ] , [ ] )
2020 XCTAssertTrue ( result2. indexes. isEmpty)
2121 XCTAssertTrue ( result2. logits. isEmpty)
2222
23- let result3 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ 2.0 , 1.0 ] )
23+ let result3 = TemperatureLogitsWarper ( temperature: 1.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
2424 XCTAssertEqual ( result3. indexes, [ 0 , 1 ] )
2525 XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
2626
27- let result4 = TemperatureLogitsWarper ( temperature: 2.0 ) ( [ 2.0 , 1.0 ] )
27+ let result4 = TemperatureLogitsWarper ( temperature: 2.0 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
2828 XCTAssertEqual ( result4. indexes, [ 0 , 1 ] )
2929 XCTAssertEqual ( result4. logits, [ 1.0 , 0.5 ] , accuracy: accuracy)
3030
31- let result5 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 2.0 , 1.0 ] )
31+ let result5 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
3232 XCTAssertEqual ( result5. indexes, [ 0 , 1 ] )
3333 XCTAssertEqual ( result5. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
34+
35+ let result6 = TemperatureLogitsWarper ( temperature: 0.5 ) ( [ 200 , 100 ] , [ 2.0 , 1.0 ] )
36+ XCTAssertEqual ( result6. indexes, [ 200 , 100 ] )
37+ XCTAssertEqual ( result6. logits, [ 4.0 , 2.0 ] , accuracy: accuracy)
3438 }
3539
3640 func testTopKLogitsWarper( ) {
37- let result1 = TopKLogitsWarper ( k: 0 ) ( [ ] )
41+ let result1 = TopKLogitsWarper ( k: 0 ) ( [ ] , [ ] )
3842 XCTAssertTrue ( result1. indexes. isEmpty)
3943 XCTAssertTrue ( result1. logits. isEmpty)
4044
41- let result2 = TopKLogitsWarper ( k: 3 ) ( [ ] )
45+ let result2 = TopKLogitsWarper ( k: 3 ) ( [ ] , [ ] )
4246 XCTAssertTrue ( result2. indexes. isEmpty)
4347 XCTAssertTrue ( result2. logits. isEmpty)
4448
45- let result3 = TopKLogitsWarper ( k: 3 ) ( [ 2.0 , 1.0 ] )
49+ let result3 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 ] , [ 2.0 , 1.0 ] )
4650 XCTAssertEqual ( result3. indexes, [ 0 , 1 ] )
4751 XCTAssertEqual ( result3. logits, [ 2.0 , 1.0 ] , accuracy: accuracy)
4852
49- let result4 = TopKLogitsWarper ( k: 3 ) ( [ 2.0 , 1.0 , 3.0 ] )
53+ let result4 = TopKLogitsWarper ( k: 3 ) ( [ 0 , 1 , 2 ] , [ 2.0 , 1.0 , 3.0 ] )
5054 XCTAssertEqual ( result4. indexes, [ 2 , 0 , 1 ] )
5155 XCTAssertEqual ( result4. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
5256
53- let result5 = TopKLogitsWarper ( k: 4 ) ( [ 2.0 , 1.0 , 3.0 , - 1.0 , 123.0 , 0.0 ] )
57+ let result5 = TopKLogitsWarper ( k: 4 ) ( [ 0 , 1 , 2 , 3 , 4 , 5 ] , [ 2.0 , 1.0 , 3.0 , - 1.0 , 123.0 , 0.0 ] )
5458 XCTAssertEqual ( result5. indexes, [ 4 , 2 , 0 , 1 ] )
5559 XCTAssertEqual ( result5. logits, [ 123.0 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
60+
61+ let result6 = TopKLogitsWarper ( k: 3 ) ( [ 10 , 1 , 52 ] , [ 2.0 , 1.0 , 3.0 ] )
62+ XCTAssertEqual ( result6. indexes, [ 52 , 10 , 1 ] )
63+ XCTAssertEqual ( result6. logits, [ 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
5664 }
5765
5866 func testTopPLogitsWarper( ) {
59- let result1 = TopPLogitsWarper ( p: 0.99 ) ( [ ] )
67+ let result1 = TopPLogitsWarper ( p: 0.99 ) ( [ ] , [ ] )
6068 XCTAssertTrue ( result1. indexes. isEmpty)
6169 XCTAssertTrue ( result1. logits. isEmpty)
6270
63- let result2 = TopPLogitsWarper ( p: 0.99 ) ( ( 0 ..< 10 ) . map { Float ( $0) } )
71+ let logits = ( 0 ..< 10 ) . map { Float ( $0) }
72+ let indexes = Array ( logits. indices)
73+ let result2 = TopPLogitsWarper ( p: 0.99 ) ( indexes, logits)
6474 XCTAssertEqual ( result2. indexes, [ 9 , 8 , 7 , 6 , 5 ] )
6575 XCTAssertEqual ( result2. logits, [ 9.0 , 8.0 , 7.0 , 6.0 , 5.0 ] , accuracy: accuracy)
6676
67- let result3 = TopPLogitsWarper ( p: 0.95 ) ( ( 0 ..< 10 ) . map { Float ( $0 ) } )
77+ let result3 = TopPLogitsWarper ( p: 0.95 ) ( indexes , logits )
6878 XCTAssertEqual ( result3. indexes, [ 9 , 8 , 7 ] )
6979 XCTAssertEqual ( result3. logits, [ 9.0 , 8.0 , 7.0 ] , accuracy: accuracy)
7080
71- let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( ( 0 ..< 10 ) . map { Float ( $0 ) } )
81+ let result4 = TopPLogitsWarper ( p: 0.6321493 ) ( indexes , logits )
7282 XCTAssertEqual ( result4. indexes, [ 9 , 8 ] )
7383 XCTAssertEqual ( result4. logits, [ 9.0 , 8.0 ] , accuracy: accuracy)
84+
85+ let result5 = TopPLogitsWarper ( p: 0.95 ) ( [ 3 , 1 , 8 ] , [ 0 , 1 , 2 ] )
86+ XCTAssertEqual ( result5. indexes, [ 8 , 1 , 3 ] )
87+ XCTAssertEqual ( result5. logits, [ 2 , 1 , 0 ] , accuracy: accuracy)
7488 }
7589
7690 func testLogitsProcessor( ) {
@@ -95,7 +109,14 @@ final class LogitsWarperTests: XCTestCase {
95109 logitsWarpers: [ TopKLogitsWarper ( k: 3 ) , TopPLogitsWarper ( p: 0.99 ) ]
96110 )
97111 let result4 = processor4 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 23.0 , 12.5 ] )
98- XCTAssertEqual ( result4. indexes, [ 0 ] )
112+ XCTAssertEqual ( result4. indexes, [ 5 ] )
99113 XCTAssertEqual ( result4. logits, [ 12.5 ] , accuracy: accuracy)
114+
115+ let processor5 = LogitsProcessor (
116+ logitsWarpers: [ TopKLogitsWarper ( k: 4 ) , TopPLogitsWarper ( p: 0.99 ) ]
117+ )
118+ let result5 = processor5 ( [ 2.0 , 1.0 , 3.0 , - 5.0 , - 3.0 , 4.5 ] )
119+ XCTAssertEqual ( result5. indexes, [ 5 , 2 , 0 , 1 ] )
120+ XCTAssertEqual ( result5. logits, [ 4.5 , 3.0 , 2.0 , 1.0 ] , accuracy: accuracy)
100121 }
101122}
0 commit comments