|
6 | 6 |
|
7 | 7 | #if canImport(CoreML) |
8 | 8 | import CoreML |
9 | | -import XCTest |
10 | | - |
11 | 9 | @testable import Generation |
| 10 | +import Testing |
12 | 11 |
|
13 | | -final class LogitsWarperTests: XCTestCase { |
| 12 | +@Suite("Logits Warper Tests") |
| 13 | +struct LogitsWarperTests { |
14 | 14 | private let accuracy: Float = 0.00001 |
15 | 15 |
|
16 | | - func testTemperatureLogitsWarper() { |
| 16 | + @Test("Temperature logits warper functionality") |
| 17 | + func temperatureLogitsWarper() { |
17 | 18 | let result1 = TemperatureLogitsWarper(temperature: 0.0)([], []) |
18 | | - XCTAssertTrue(result1.indices.isEmpty) |
19 | | - XCTAssertTrue(result1.logits.isEmpty) |
| 19 | + #expect(result1.indices.isEmpty) |
| 20 | + #expect(result1.logits.isEmpty) |
20 | 21 |
|
21 | 22 | let result2 = TemperatureLogitsWarper(temperature: 1.0)([], []) |
22 | | - XCTAssertTrue(result2.indices.isEmpty) |
23 | | - XCTAssertTrue(result2.logits.isEmpty) |
| 23 | + #expect(result2.indices.isEmpty) |
| 24 | + #expect(result2.logits.isEmpty) |
24 | 25 |
|
25 | 26 | let result3 = TemperatureLogitsWarper(temperature: 1.0)([0, 1], [2.0, 1.0]) |
26 | | - XCTAssertEqual(result3.indices, [0, 1]) |
27 | | - XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy) |
| 27 | + #expect(result3.indices == [0, 1]) |
| 28 | + #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy)) |
28 | 29 |
|
29 | 30 | let result4 = TemperatureLogitsWarper(temperature: 2.0)([0, 1], [2.0, 1.0]) |
30 | | - XCTAssertEqual(result4.indices, [0, 1]) |
31 | | - XCTAssertEqual(result4.logits, [1.0, 0.5], accuracy: accuracy) |
| 31 | + #expect(result4.indices == [0, 1]) |
| 32 | + #expect(isClose(result4.logits, [1.0, 0.5], accuracy: accuracy)) |
32 | 33 |
|
33 | 34 | let result5 = TemperatureLogitsWarper(temperature: 0.5)([0, 1], [2.0, 1.0]) |
34 | | - XCTAssertEqual(result5.indices, [0, 1]) |
35 | | - XCTAssertEqual(result5.logits, [4.0, 2.0], accuracy: accuracy) |
| 35 | + #expect(result5.indices == [0, 1]) |
| 36 | + #expect(isClose(result5.logits, [4.0, 2.0], accuracy: accuracy)) |
36 | 37 |
|
37 | 38 | let result6 = TemperatureLogitsWarper(temperature: 0.5)([200, 100], [2.0, 1.0]) |
38 | | - XCTAssertEqual(result6.indices, [200, 100]) |
39 | | - XCTAssertEqual(result6.logits, [4.0, 2.0], accuracy: accuracy) |
| 39 | + #expect(result6.indices == [200, 100]) |
| 40 | + #expect(isClose(result6.logits, [4.0, 2.0], accuracy: accuracy)) |
40 | 41 | } |
41 | 42 |
|
42 | | - func testTopKLogitsWarper() { |
| 43 | + @Test("Top-K logits warper functionality") |
| 44 | + func topKLogitsWarper() { |
43 | 45 | let result1 = TopKLogitsWarper(k: 0)([], []) |
44 | | - XCTAssertTrue(result1.indices.isEmpty) |
45 | | - XCTAssertTrue(result1.logits.isEmpty) |
| 46 | + #expect(result1.indices.isEmpty) |
| 47 | + #expect(result1.logits.isEmpty) |
46 | 48 |
|
47 | 49 | let result2 = TopKLogitsWarper(k: 3)([], []) |
48 | | - XCTAssertTrue(result2.indices.isEmpty) |
49 | | - XCTAssertTrue(result2.logits.isEmpty) |
| 50 | + #expect(result2.indices.isEmpty) |
| 51 | + #expect(result2.logits.isEmpty) |
50 | 52 |
|
51 | 53 | let result3 = TopKLogitsWarper(k: 3)([0, 1], [2.0, 1.0]) |
52 | | - XCTAssertEqual(result3.indices, [0, 1]) |
53 | | - XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy) |
| 54 | + #expect(result3.indices == [0, 1]) |
| 55 | + #expect(isClose(result3.logits, [2.0, 1.0], accuracy: accuracy)) |
54 | 56 |
|
55 | 57 | let result4 = TopKLogitsWarper(k: 3)([0, 1, 2], [2.0, 1.0, 3.0]) |
56 | | - XCTAssertEqual(result4.indices, [2, 0, 1]) |
57 | | - XCTAssertEqual(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy) |
| 58 | + #expect(result4.indices == [2, 0, 1]) |
| 59 | + #expect(isClose(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) |
58 | 60 |
|
59 | 61 | let result5 = TopKLogitsWarper(k: 4)([0, 1, 2, 3, 4, 5], [2.0, 1.0, 3.0, -1.0, 123.0, 0.0]) |
60 | | - XCTAssertEqual(result5.indices, [4, 2, 0, 1]) |
61 | | - XCTAssertEqual(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy) |
| 62 | + #expect(result5.indices == [4, 2, 0, 1]) |
| 63 | + #expect(isClose(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)) |
62 | 64 |
|
63 | 65 | let result6 = TopKLogitsWarper(k: 3)([10, 1, 52], [2.0, 1.0, 3.0]) |
64 | | - XCTAssertEqual(result6.indices, [52, 10, 1]) |
65 | | - XCTAssertEqual(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy) |
| 66 | + #expect(result6.indices == [52, 10, 1]) |
| 67 | + #expect(isClose(result6.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) |
66 | 68 | } |
67 | 69 |
|
68 | | - func testTopPLogitsWarper() { |
| 70 | + @Test("Top-P logits warper functionality") |
| 71 | + func topPLogitsWarper() { |
69 | 72 | let result1 = TopPLogitsWarper(p: 0.99)([], []) |
70 | | - XCTAssertTrue(result1.indices.isEmpty) |
71 | | - XCTAssertTrue(result1.logits.isEmpty) |
| 73 | + #expect(result1.indices.isEmpty) |
| 74 | + #expect(result1.logits.isEmpty) |
72 | 75 |
|
73 | 76 | let logits = (0..<10).map { Float($0) } |
74 | 77 | let indices = Array(logits.indices) |
75 | 78 | let result2 = TopPLogitsWarper(p: 0.99)(indices, logits) |
76 | | - XCTAssertEqual(result2.indices, [9, 8, 7, 6, 5]) |
77 | | - XCTAssertEqual(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy) |
| 79 | + #expect(result2.indices == [9, 8, 7, 6, 5]) |
| 80 | + #expect(isClose(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)) |
78 | 81 |
|
79 | 82 | let result3 = TopPLogitsWarper(p: 0.95)(indices, logits) |
80 | | - XCTAssertEqual(result3.indices, [9, 8, 7]) |
81 | | - XCTAssertEqual(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy) |
| 83 | + #expect(result3.indices == [9, 8, 7]) |
| 84 | + #expect(isClose(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy)) |
82 | 85 |
|
83 | 86 | let result4 = TopPLogitsWarper(p: 0.6321493)(indices, logits) |
84 | | - XCTAssertEqual(result4.indices, [9, 8]) |
85 | | - XCTAssertEqual(result4.logits, [9.0, 8.0], accuracy: accuracy) |
| 87 | + #expect(result4.indices == [9, 8]) |
| 88 | + #expect(isClose(result4.logits, [9.0, 8.0], accuracy: accuracy)) |
86 | 89 |
|
87 | 90 | let result5 = TopPLogitsWarper(p: 0.95)([3, 1, 8], [0, 1, 2]) |
88 | | - XCTAssertEqual(result5.indices, [8, 1, 3]) |
89 | | - XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy) |
| 91 | + #expect(result5.indices == [8, 1, 3]) |
| 92 | + #expect(isClose(result5.logits, [2, 1, 0], accuracy: accuracy)) |
90 | 93 | } |
91 | 94 |
|
92 | | - func testRepetitionPenaltyWarper() { |
| 95 | + @Test("Repetition penalty warper functionality") |
| 96 | + func repetitionPenaltyWarper() { |
93 | 97 | let indices = Array(0..<10) |
94 | 98 | let logits = indices.map { Float($0) } |
95 | 99 |
|
96 | 100 | let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits) |
97 | | - XCTAssertEqual(result1.indices, indices) |
98 | | - XCTAssertEqual(result1.logits, logits, accuracy: accuracy) |
| 101 | + #expect(result1.indices == indices) |
| 102 | + #expect(isClose(result1.logits, logits, accuracy: accuracy)) |
99 | 103 |
|
100 | 104 | let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits) |
101 | | - XCTAssertEqual(result2.indices, indices) |
| 105 | + #expect(result2.indices == indices) |
102 | 106 | let logits2 = indices.map { Float($0) / 3.75 } |
103 | | - XCTAssertEqual(result2.logits, logits2, accuracy: accuracy) |
| 107 | + #expect(isClose(result2.logits, logits2, accuracy: accuracy)) |
104 | 108 |
|
105 | 109 | let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119]) |
106 | | - XCTAssertEqual(result3.indices, [0, 1, 2]) |
107 | | - XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4) |
| 110 | + #expect(result3.indices == [0, 1, 2]) |
| 111 | + #expect(isClose(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)) |
108 | 112 |
|
109 | 113 | let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158]) |
110 | | - XCTAssertEqual(result4.indices, [2, 3, 4]) |
111 | | - XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4) |
| 114 | + #expect(result4.indices == [2, 3, 4]) |
| 115 | + #expect(isClose(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)) |
112 | 116 |
|
113 | 117 | let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966]) |
114 | | - XCTAssertEqual(result5.indices, [0, 1, 2]) |
115 | | - XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4) |
| 118 | + #expect(result5.indices == [0, 1, 2]) |
| 119 | + #expect(isClose(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)) |
116 | 120 |
|
117 | 121 | let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755]) |
118 | | - XCTAssertEqual(result6.indices, [3, 1, 2]) |
119 | | - XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4) |
| 122 | + #expect(result6.indices == [3, 1, 2]) |
| 123 | + #expect(isClose(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)) |
120 | 124 | } |
121 | 125 |
|
122 | | - func testLogitsProcessor() { |
| 126 | + @Test("Logits processor functionality") |
| 127 | + func logitsProcessor() { |
123 | 128 | let processor1 = LogitsProcessor(logitsWarpers: []) |
124 | 129 | let result1 = processor1([]) |
125 | | - XCTAssertTrue(result1.indices.isEmpty) |
126 | | - XCTAssertTrue(result1.logits.isEmpty) |
| 130 | + #expect(result1.indices.isEmpty) |
| 131 | + #expect(result1.logits.isEmpty) |
127 | 132 |
|
128 | 133 | let processor2 = LogitsProcessor(logitsWarpers: []) |
129 | 134 | let result2 = processor2([2.0, 1.0]) |
130 | | - XCTAssertEqual(result2.indices, [0, 1]) |
131 | | - XCTAssertEqual(result2.logits, [2.0, 1.0], accuracy: accuracy) |
| 135 | + #expect(result2.indices == [0, 1]) |
| 136 | + #expect(isClose(result2.logits, [2.0, 1.0], accuracy: accuracy)) |
132 | 137 |
|
133 | 138 | let processor3 = LogitsProcessor( |
134 | 139 | logitsWarpers: [TopKLogitsWarper(k: 3)] |
135 | 140 | ) |
136 | 141 | let result3 = processor3([2.0, 1.0, 3.0, -5.0]) |
137 | | - XCTAssertEqual(result3.indices, [2, 0, 1]) |
138 | | - XCTAssertEqual(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy) |
| 142 | + #expect(result3.indices == [2, 0, 1]) |
| 143 | + #expect(isClose(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy)) |
139 | 144 |
|
140 | 145 | let processor4 = LogitsProcessor( |
141 | 146 | logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)] |
142 | 147 | ) |
143 | 148 | let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5]) |
144 | | - XCTAssertEqual(result4.indices, [5]) |
145 | | - XCTAssertEqual(result4.logits, [12.5], accuracy: accuracy) |
| 149 | + #expect(result4.indices == [5]) |
| 150 | + #expect(isClose(result4.logits, [12.5], accuracy: accuracy)) |
146 | 151 |
|
147 | 152 | let processor5 = LogitsProcessor( |
148 | 153 | logitsWarpers: [TopKLogitsWarper(k: 4), TopPLogitsWarper(p: 0.99)] |
149 | 154 | ) |
150 | 155 | let result5 = processor5([2.0, 1.0, 3.0, -5.0, -3.0, 4.5]) |
151 | | - XCTAssertEqual(result5.indices, [5, 2, 0, 1]) |
152 | | - XCTAssertEqual(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy) |
| 156 | + #expect(result5.indices == [5, 2, 0, 1]) |
| 157 | + #expect(isClose(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy)) |
153 | 158 | } |
154 | 159 | } |
155 | 160 | #endif // canImport(CoreML) |
0 commit comments