Skip to content

Commit 5a5441e

Browse files
authored
Complete the transition to Swift Testing (#181)
1 parent 9e1e9f4 commit 5a5441e

File tree

2 files changed

+105
-83
lines changed

2 files changed

+105
-83
lines changed
Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -13,19 +13,21 @@
1313
// limitations under the License.
1414

1515
@testable import HomomorphicEncryption
16-
import XCTest
16+
import Testing
1717

18-
class Array2dTests: XCTestCase {
18+
@Suite
19+
struct Array2dTests {
20+
@Test
1921
func testInit() {
2022
func runTest<T: FixedWidthInteger & Sendable>(_: T.Type) {
2123
let data = [T](1...6)
2224
let array = Array2d(data: data, rowCount: 3, columnCount: 2)
2325

2426
let data2d: [[T]] = [[1, 2], [3, 4], [5, 6]]
25-
XCTAssertEqual(array, Array2d(data: data2d))
27+
#expect(array == Array2d(data: data2d))
2628

27-
XCTAssert(Array2d<T>(data: []).shape == (rowCount: 0, columnCount: 0))
28-
XCTAssert(Array2d<T>(data: [[]]).shape == (rowCount: 0, columnCount: 0))
29+
#expect(Array2d<T>(data: []).shape == (rowCount: 0, columnCount: 0))
30+
#expect(Array2d<T>(data: [[]]).shape == (rowCount: 0, columnCount: 0))
2931
}
3032

3133
runTest(Int.self)
@@ -36,7 +38,8 @@ class Array2dTests: XCTestCase {
3638
runTest(UInt128.self)
3739
}
3840

39-
func testZeroAndZeroize() {
41+
@Test
42+
func zeroAndZeroize() {
4043
func runTest<T: FixedWidthInteger & Sendable>(_: T.Type) {
4144
let data = [T](1...16)
4245
var array = Array2d(data: data, rowCount: 2, columnCount: 8)
@@ -46,8 +49,8 @@ class Array2dTests: XCTestCase {
4649
data: [T](repeating: 0, count: 16),
4750
rowCount: 2,
4851
columnCount: 8)
49-
XCTAssertEqual(array, zero)
50-
XCTAssertEqual(array, Array2d.zero(rowCount: 2, columnCount: 8))
52+
#expect(array == zero)
53+
#expect(array == Array2d.zero(rowCount: 2, columnCount: 8))
5154
}
5255
runTest(Int.self)
5356
runTest(Int32.self)
@@ -57,92 +60,100 @@ class Array2dTests: XCTestCase {
5760
runTest(UInt128.self)
5861
}
5962

60-
func testShape() {
63+
@Test
64+
func shape() {
6165
let data = [Int](0..<16)
6266
let array = Array2d(data: data, rowCount: 2, columnCount: 8)
63-
XCTAssert(array.shape == (rowCount: 2, columnCount: 8))
67+
#expect(array.shape == (rowCount: 2, columnCount: 8))
6468
}
6569

66-
func testIndices4x4() {
70+
@Test
71+
func indices4x4() {
6772
let data = [Int](0..<16)
6873
let array = Array2d(data: data, rowCount: 4, columnCount: 4)
6974

70-
XCTAssertEqual(array.collectValues(indices: array.rowIndices(row: 0)), [0, 1, 2, 3])
71-
XCTAssertEqual(array.collectValues(indices: array.columnIndices(column: 0)), [0, 4, 8, 12])
75+
#expect(array.collectValues(indices: array.rowIndices(row: 0)) == [0, 1, 2, 3])
76+
#expect(array.collectValues(indices: array.columnIndices(column: 0)) == [0, 4, 8, 12])
7277
}
7378

74-
func testIndices2x8() {
79+
@Test
80+
func indices2x8() {
7581
let data = [Int](0..<16)
7682
let array = Array2d(data: data, rowCount: 2, columnCount: 8)
7783

78-
XCTAssertEqual(array.collectValues(indices: array.rowIndices(row: 0)), [0, 1, 2, 3, 4, 5, 6, 7])
79-
XCTAssertEqual(array.collectValues(indices: array.columnIndices(column: 0)), [0, 8])
80-
XCTAssertEqual(array.collectValues(indices: array.columnIndices(column: 7)), [7, 15])
84+
#expect(array.collectValues(indices: array.rowIndices(row: 0)) == [0, 1, 2, 3, 4, 5, 6, 7])
85+
#expect(array.collectValues(indices: array.columnIndices(column: 0)) == [0, 8])
86+
#expect(array.collectValues(indices: array.columnIndices(column: 7)) == [7, 15])
8187
}
8288

83-
func testTransposed() {
89+
@Test
90+
func transposed() {
8491
let data = [Int](0..<16)
8592
let array = Array2d(data: data, rowCount: 2, columnCount: 8)
8693
let transposed = array.transposed()
8794

88-
XCTAssert(array.shape == (2, 8))
89-
XCTAssert(transposed.shape == (8, 2))
95+
#expect(array.shape == (2, 8))
96+
#expect(transposed.shape == (8, 2))
9097

91-
XCTAssertEqual(transposed.collectValues(indices: transposed.rowIndices(row: 0)), [0, 8])
92-
XCTAssertEqual(transposed.collectValues(indices: transposed.rowIndices(row: 7)), [7, 15])
93-
XCTAssertEqual(transposed.collectValues(indices: transposed.columnIndices(column: 0)), [0, 1, 2, 3, 4, 5, 6, 7])
94-
XCTAssertEqual(
95-
transposed.collectValues(indices: transposed.columnIndices(column: 1)),
96-
[8, 9, 10, 11, 12, 13, 14, 15])
98+
#expect(transposed.collectValues(indices: transposed.rowIndices(row: 0)) == [0, 8])
99+
#expect(transposed.collectValues(indices: transposed.rowIndices(row: 7)) == [7, 15])
100+
#expect(transposed.collectValues(indices: transposed.columnIndices(column: 0)) == [0, 1, 2, 3, 4, 5, 6, 7])
101+
#expect(
102+
transposed.collectValues(indices: transposed.columnIndices(column: 1)) ==
103+
[8, 9, 10, 11, 12, 13, 14, 15])
97104
}
98105

99-
func testResizeColumn() {
106+
@Test
107+
func resizeColumn() {
100108
var array = Array2d(data: [Int](0..<6), rowCount: 2, columnCount: 3)
101109

102110
array.resizeColumn(newColumnCount: 5, defaultValue: 99)
103111
let newData: [Int] = [0, 1, 2, 99, 99, 3, 4, 5, 99, 99]
104-
XCTAssertEqual(array, Array2d(data: newData, rowCount: 2, columnCount: 5))
112+
#expect(array == Array2d(data: newData, rowCount: 2, columnCount: 5))
105113

106114
array.resizeColumn(newColumnCount: 3)
107-
XCTAssertEqual(array, Array2d(data: [Int](0..<6), rowCount: 2, columnCount: 3))
115+
#expect(array == Array2d(data: [Int](0..<6), rowCount: 2, columnCount: 3))
108116
}
109117

110-
func testRemoveLastRows() {
118+
@Test
119+
func removeLastRows() {
111120
let data = [Int](0..<32)
112121
var array = Array2d(data: data, rowCount: 4, columnCount: 8)
113122

114123
array.removeLastRows(2)
115-
XCTAssertEqual(array, Array2d(data: [Int](0..<16), rowCount: 2, columnCount: 8))
124+
#expect(array == Array2d(data: [Int](0..<16), rowCount: 2, columnCount: 8))
116125

117126
array.removeLastRows(1)
118-
XCTAssertEqual(array, Array2d(data: [Int](0..<8), rowCount: 1, columnCount: 8))
127+
#expect(array == Array2d(data: [Int](0..<8), rowCount: 1, columnCount: 8))
119128

120129
array.removeLastRows(1)
121-
XCTAssertEqual(array, Array2d(data: [], rowCount: 0, columnCount: 8))
130+
#expect(array == Array2d(data: [], rowCount: 0, columnCount: 8))
122131
}
123132

124-
func testAppendRows() {
133+
@Test
134+
func appendRows() {
125135
let data = [Int](0..<32)
126136
var array = Array2d(data: data, rowCount: 4, columnCount: 8)
127137
array.append(rows: [])
128-
XCTAssertEqual(array, Array2d(data: data, rowCount: 4, columnCount: 8))
138+
#expect(array == Array2d(data: data, rowCount: 4, columnCount: 8))
129139

130140
array.append(rows: [32, 33, 34, 35, 36, 37, 38, 39])
131-
XCTAssertEqual(array, Array2d(data: [Int](0..<40), rowCount: 5, columnCount: 8))
141+
#expect(array == Array2d(data: [Int](0..<40), rowCount: 5, columnCount: 8))
132142

133143
array.append(rows: Array(40..<56))
134-
XCTAssertEqual(array, Array2d(data: [Int](0..<56), rowCount: 7, columnCount: 8))
144+
#expect(array == Array2d(data: [Int](0..<56), rowCount: 7, columnCount: 8))
135145
}
136146

137-
func testMap() {
147+
@Test
148+
func map() {
138149
let data = [Int](0..<32)
139150
let array = Array2d(data: data, rowCount: 4, columnCount: 8)
140151

141152
let arrayPlus1 = array.map { UInt($0) + 1 }
142153
let expected = Array2d(data: [UInt](1..<33), rowCount: 4, columnCount: 8)
143-
XCTAssertEqual(arrayPlus1, expected)
154+
#expect(arrayPlus1 == expected)
144155

145156
let roundtripArray = arrayPlus1.map { Int($0 - 1) }
146-
XCTAssertEqual(roundtripArray, array)
157+
#expect(roundtripArray == array)
147158
}
148159
}

Tests/HomomorphicEncryptionTests/NttTests.swift

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -13,33 +13,36 @@
1313
// limitations under the License.
1414

1515
@testable import HomomorphicEncryption
16-
import XCTest
17-
18-
final class NttTests: XCTestCase {
19-
func testIsPrimitiveRootOfUnity() {
20-
XCTAssertTrue(UInt32(12).isPrimitiveRootOfUnity(degree: 2, modulus: 13))
21-
XCTAssertFalse(UInt32(11).isPrimitiveRootOfUnity(degree: 2, modulus: 13))
22-
XCTAssertFalse(UInt32(12).isPrimitiveRootOfUnity(degree: 4, modulus: 13))
23-
24-
XCTAssertTrue(UInt64(28).isPrimitiveRootOfUnity(degree: 2, modulus: 29))
25-
XCTAssertTrue(UInt64(12).isPrimitiveRootOfUnity(degree: 4, modulus: 29))
26-
XCTAssertFalse(UInt64(12).isPrimitiveRootOfUnity(degree: 2, modulus: 29))
27-
XCTAssertFalse(UInt64(12).isPrimitiveRootOfUnity(degree: 8, modulus: 29))
28-
29-
XCTAssertTrue(UInt64(1_234_565_440).isPrimitiveRootOfUnity(degree: 2, modulus: 1_234_565_441))
30-
XCTAssertTrue(UInt64(960_907_033).isPrimitiveRootOfUnity(degree: 8, modulus: 1_234_565_441))
31-
XCTAssertTrue(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 16, modulus: 1_234_565_441))
32-
XCTAssertFalse(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 32, modulus: 1_234_565_441))
33-
XCTAssertFalse(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 8, modulus: 1_234_565_441))
34-
XCTAssertFalse(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 2, modulus: 1_234_565_441))
16+
import Testing
17+
18+
@Suite
19+
struct NttTests {
20+
@Test
21+
func isPrimitiveRootOfUnity() {
22+
#expect(UInt32(12).isPrimitiveRootOfUnity(degree: 2, modulus: 13))
23+
#expect(!UInt32(11).isPrimitiveRootOfUnity(degree: 2, modulus: 13))
24+
#expect(!UInt32(12).isPrimitiveRootOfUnity(degree: 4, modulus: 13))
25+
26+
#expect(UInt64(28).isPrimitiveRootOfUnity(degree: 2, modulus: 29))
27+
#expect(UInt64(12).isPrimitiveRootOfUnity(degree: 4, modulus: 29))
28+
#expect(!UInt64(12).isPrimitiveRootOfUnity(degree: 2, modulus: 29))
29+
#expect(!UInt64(12).isPrimitiveRootOfUnity(degree: 8, modulus: 29))
30+
31+
#expect(UInt64(1_234_565_440).isPrimitiveRootOfUnity(degree: 2, modulus: 1_234_565_441))
32+
#expect(UInt64(960_907_033).isPrimitiveRootOfUnity(degree: 8, modulus: 1_234_565_441))
33+
#expect(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 16, modulus: 1_234_565_441))
34+
#expect(!UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 32, modulus: 1_234_565_441))
35+
#expect(!UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 8, modulus: 1_234_565_441))
36+
#expect(!UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 2, modulus: 1_234_565_441))
3537
}
3638

37-
func testMinPrimitiveRootOfUnity() {
38-
XCTAssertEqual(UInt32(11).minPrimitiveRootOfUnity(degree: 2), 10)
39-
XCTAssertEqual(UInt32(29).minPrimitiveRootOfUnity(degree: 2), 28)
40-
XCTAssertEqual(UInt32(29).minPrimitiveRootOfUnity(degree: 4), 12)
41-
XCTAssertEqual(UInt64(1_234_565_441).minPrimitiveRootOfUnity(degree: 2), 1_234_565_440)
42-
XCTAssertEqual(UInt64(1_234_565_441).minPrimitiveRootOfUnity(degree: 8), 249_725_733)
39+
@Test
40+
func minPrimitiveRootOfUnity() {
41+
#expect(UInt32(11).minPrimitiveRootOfUnity(degree: 2) == 10)
42+
#expect(UInt32(29).minPrimitiveRootOfUnity(degree: 2) == 28)
43+
#expect(UInt32(29).minPrimitiveRootOfUnity(degree: 4) == 12)
44+
#expect(UInt64(1_234_565_441).minPrimitiveRootOfUnity(degree: 2) == 1_234_565_440)
45+
#expect(UInt64(1_234_565_441).minPrimitiveRootOfUnity(degree: 8) == 249_725_733)
4346
}
4447

4548
private func runNttTest<T: ScalarType>(
@@ -59,15 +62,16 @@ final class NttTests: XCTestCase {
5962
let polyCoeff = PolyRq<T, Coeff>(context: context, data: coeffData)
6063
let polyEval = PolyRq<T, Eval>(context: context, data: evalData)
6164

62-
XCTAssertEqual(try polyCoeff.forwardNtt(), polyEval)
63-
XCTAssertEqual(try polyEval.inverseNtt(), polyCoeff)
64-
XCTAssertEqual(try polyEval.convertToCoeffFormat(), polyCoeff)
65-
XCTAssertEqual(try polyCoeff.convertToCoeffFormat(), polyCoeff)
66-
XCTAssertEqual(try polyEval.convertToEvalFormat(), polyEval)
67-
XCTAssertEqual(try polyCoeff.convertToEvalFormat(), polyEval)
65+
#expect(try polyCoeff.forwardNtt() == polyEval)
66+
#expect(try polyEval.inverseNtt() == polyCoeff)
67+
#expect(try polyEval.convertToCoeffFormat() == polyCoeff)
68+
#expect(try polyCoeff.convertToCoeffFormat() == polyCoeff)
69+
#expect(try polyEval.convertToEvalFormat() == polyEval)
70+
#expect(try polyCoeff.convertToEvalFormat() == polyEval)
6871
}
6972

70-
func testNtt2() throws {
73+
@Test
74+
func ntt2() throws {
7175
try runNttTest(moduli: [UInt32(97)], coeffData: [[0, 0]], evalData: [[0, 0]])
7276
try runNttTest(moduli: [UInt32(97)], coeffData: [[1, 0]], evalData: [[1, 1]])
7377

@@ -76,15 +80,17 @@ final class NttTests: XCTestCase {
7680
try runNttTest(moduli: [UInt32(97), UInt32(113)], coeffData: [[1, 2], [3, 4]], evalData: [[45, 54], [63, 56]])
7781
}
7882

79-
func testNtt4() throws {
83+
@Test
84+
func ntt4() throws {
8085
try runNttTest(moduli: [UInt32(97)], coeffData: [[0, 0, 0, 0]], evalData: [[0, 0, 0, 0]])
8186
try runNttTest(moduli: [UInt32(97)], coeffData: [[1, 0, 0, 0]], evalData: [[1, 1, 1, 1]])
8287
try runNttTest(moduli: [UInt32(97)], coeffData: [[1, 2, 3, 4]], evalData: [[30, 7, 64, 0]])
8388
try runNttTest(moduli: [UInt32(97), UInt32(113)],
8489
coeffData: [[1, 2, 3, 4], [5, 6, 7, 8]], evalData: [[30, 7, 64, 0], [108, 31, 103, 4]])
8590
}
8691

87-
func testNtt8() throws {
92+
@Test
93+
func ntt8() throws {
8894
try runNttTest(
8995
moduli: [UInt32(4_194_353)],
9096
coeffData: [[0, 0, 0, 0, 0, 0, 0, 0]],
@@ -106,7 +112,8 @@ final class NttTests: XCTestCase {
106112
])
107113
}
108114

109-
func testNtt16() throws {
115+
@Test
116+
func ntt16() throws {
110117
// modulus near top of range
111118
try runNttTest(
112119
moduli: [UInt32(536_870_849)],
@@ -155,7 +162,8 @@ final class NttTests: XCTestCase {
155162
]])
156163
}
157164

158-
func testNtt32() throws {
165+
@Test
166+
func ntt32() throws {
159167
let modulus = UInt32(769)
160168

161169
let zeros = [UInt32](repeating: 0, count: 32)
@@ -172,7 +180,8 @@ final class NttTests: XCTestCase {
172180
try runNttTest(moduli: [modulus], coeffData: [coeffData], evalData: [evalData])
173181
}
174182

175-
func testNtt4096() throws {
183+
@Test
184+
func ntt4096() throws {
176185
let modulus = UInt64(557_057)
177186
let degree = 4096
178187
let zeros = [UInt64](repeating: 0, count: degree)
@@ -182,7 +191,8 @@ final class NttTests: XCTestCase {
182191
try runNttTest(moduli: [modulus], coeffData: [oneHot], evalData: [Array(repeating: 1, count: degree)])
183192
}
184193

185-
func testNttRoundtrip() throws {
194+
@Test
195+
func nttRoundtrip() throws {
186196
let degree = 256
187197
// Test large modulus
188198
let moduli = try UInt64.generatePrimes(
@@ -193,10 +203,11 @@ final class NttTests: XCTestCase {
193203
let polyCoeff = PolyRq<_, Coeff>.random(context: context)
194204
let polyEval = try polyCoeff.forwardNtt()
195205
let polyRoundtrip = try polyEval.inverseNtt()
196-
XCTAssertEqual(polyRoundtrip, polyCoeff)
206+
#expect(polyRoundtrip == polyCoeff)
197207
}
198208

199-
func testNttMatchesNaive() throws {
209+
@Test
210+
func nttMatchesNaive() throws {
200211
func naiveMultiplication<T: ScalarType>(_ x: [T], _ y: [T], modulus: T) -> [T] {
201212
precondition(x.count == y.count)
202213
let n = x.count
@@ -236,6 +247,6 @@ final class NttTests: XCTestCase {
236247
let prodNtt = try nttMultiplication(x, y)
237248
let prodNaive = naiveMultiplication(x.data.data, y.data.data, modulus: moduli[0])
238249

239-
XCTAssertEqual(prodNtt.data.data, prodNaive)
250+
#expect(prodNtt.data.data == prodNaive)
240251
}
241252
}

0 commit comments

Comments
 (0)