Skip to content

Commit abc77f1

Browse files
authored
Port PNNS tests to _TestUtilities (#184)
1 parent ebf8869 commit abc77f1

18 files changed

+1615
-1461
lines changed

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ let package = Package(
127127
name: "_TestUtilities",
128128
dependencies: [
129129
"HomomorphicEncryption",
130+
"PrivateNearestNeighborSearch",
130131
.product(name: "Numerics", package: "swift-numerics"),
131132
],
132133
path: "Sources/TestUtilities",

Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import ModularArithmetic
1818
/// Stores a matrix of scalars as ciphertexts.
1919
public struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendable {
2020
/// Dimensions of the matrix.
21-
@usableFromInline let dimensions: MatrixDimensions
21+
@usableFromInline package let dimensions: MatrixDimensions
2222

2323
/// Dimensions of the scalar matrix in a SIMD-encoded plaintext.
2424
@usableFromInline let simdDimensions: SimdEncodingDimensions
@@ -87,7 +87,7 @@ public struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
8787
}
8888

8989
@inlinable
90-
func decrypt(using secretKey: SecretKey<Scheme>) throws -> PlaintextMatrix<Scheme, Coeff> {
90+
package func decrypt(using secretKey: SecretKey<Scheme>) throws -> PlaintextMatrix<Scheme, Coeff> {
9191
let plaintexts = try ciphertexts.map { ciphertext in try ciphertext.decrypt(using: secretKey) }
9292
return try PlaintextMatrix(dimensions: dimensions, packing: packing, plaintexts: plaintexts)
9393
}
@@ -164,8 +164,8 @@ extension CiphertextMatrix {
164164
/// - Returns: The evaluation key configuration.
165165
/// - Throws: Error upon failure to generate the evaluation key configuration.
166166
@inlinable
167-
static func extractDenseRowConfig(for encryptionParameters: EncryptionParameters<Scheme>,
168-
dimensions: MatrixDimensions) throws -> EvaluationKeyConfig
167+
package static func extractDenseRowConfig(for encryptionParameters: EncryptionParameters<Scheme>,
168+
dimensions: MatrixDimensions) throws -> EvaluationKeyConfig
169169
{
170170
if dimensions.rowCount == 1 {
171171
// extractDenseRow is a No-op, so no evaluation key required
@@ -190,7 +190,7 @@ extension CiphertextMatrix {
190190
/// - Returns: A ciphertext matrix in `.denseRow` format with 1 row
191191
/// - Throws: Error upon failure to extract the row.
192192
@inlinable
193-
func extractDenseRow(rowIndex: Int, evaluationKey: EvaluationKey<Scheme>) throws -> Self
193+
package func extractDenseRow(rowIndex: Int, evaluationKey: EvaluationKey<Scheme>) throws -> Self
194194
where Format == Scheme.CanonicalCiphertextFormat
195195
{
196196
precondition((0..<dimensions.rowCount).contains(rowIndex))

Sources/PrivateNearestNeighborSearch/MatrixMultiplication.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -117,7 +117,7 @@ extension PlaintextMatrix {
117117
/// - Returns: Encrypted dense-column packed vector containing dot products.
118118
/// - Throws: Error upon failure to compute the inner product.
119119
@inlinable
120-
func mulTranspose(
120+
package func mulTranspose(
121121
vector ciphertextVector: CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>,
122122
using evaluationKey: EvaluationKey<Scheme>) throws -> [Scheme.CanonicalCiphertext]
123123
{
@@ -215,7 +215,7 @@ extension PlaintextMatrix {
215215
/// - Returns: Encrypted dense-column packed matrix.
216216
/// - Throws: Error upon failure to compute the product.
217217
@inlinable
218-
func mulTranspose(
218+
package func mulTranspose(
219219
matrix ciphertextMatrix: CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>,
220220
using evaluationKey: EvaluationKey<Scheme>) throws
221221
-> CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>

Sources/PrivateNearestNeighborSearch/PlaintextMatrix.swift

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -80,13 +80,13 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
8080
@usableFromInline let simdDimensions: SimdEncodingDimensions
8181

8282
/// Plaintext packing with which the data is stored.
83-
@usableFromInline let packing: MatrixPacking
83+
@usableFromInline package let packing: MatrixPacking
8484

8585
/// Plaintexts encoding the scalars.
86-
@usableFromInline let plaintexts: [Plaintext<Scheme, Format>]
86+
@usableFromInline package let plaintexts: [Plaintext<Scheme, Format>]
8787

8888
/// The parameter context.
89-
@usableFromInline var context: Context<Scheme> {
89+
@usableFromInline package var context: Context<Scheme> {
9090
precondition(!plaintexts.isEmpty, "Plaintext array cannot be empty")
9191
return plaintexts[0].context
9292
}
@@ -101,10 +101,10 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
101101
@usableFromInline var count: Int { dimensions.count }
102102

103103
/// Number of rows in the stored data.
104-
@usableFromInline var rowCount: Int { dimensions.rowCount }
104+
@usableFromInline package var rowCount: Int { dimensions.rowCount }
105105

106106
/// Number of columns in the stored data.
107-
@usableFromInline var columnCount: Int { dimensions.columnCount }
107+
@usableFromInline package var columnCount: Int { dimensions.columnCount }
108108

109109
/// Creates a new plaintext matrix.
110110
/// - Parameters:
@@ -113,7 +113,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
113113
/// - plaintexts: Plaintexts encoding the data; must not be empty.
114114
/// - Throws: Error upon failure to initialize the plaintext matrix.
115115
@inlinable
116-
init(dimensions: MatrixDimensions, packing: MatrixPacking, plaintexts: [Plaintext<Scheme, Format>]) throws {
116+
package init(dimensions: MatrixDimensions, packing: MatrixPacking, plaintexts: [Plaintext<Scheme, Format>]) throws {
117117
guard !plaintexts.isEmpty else {
118118
throw PnnsError.emptyPlaintextArray
119119
}
@@ -184,13 +184,12 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
184184
/// - reduce: If true, values are reduced into the correct range before encoding.
185185
/// - Throws: Error upon failure to create the plaitnext matrix.
186186
@inlinable
187-
init(
187+
package init(
188188
context: Context<Scheme>,
189189
dimensions: MatrixDimensions,
190190
packing: MatrixPacking,
191191
values: [Scheme.Scalar],
192-
reduce: Bool = false) throws
193-
where Format == Coeff
192+
reduce: Bool = false) throws where Format == Coeff
194193
{
195194
guard values.count == dimensions.count, !values.isEmpty else {
196195
throw PnnsError.wrongEncodingValuesCount(got: values.count, expected: values.count)
@@ -475,7 +474,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
475474
/// - Returns: The stored data values in row-major format.
476475
/// - Throws: Error upon failure to unpack the matrix.
477476
@inlinable
478-
func unpack() throws -> [Scheme.Scalar] where Format == Coeff {
477+
package func unpack() throws -> [Scheme.Scalar] where Format == Coeff {
479478
switch packing {
480479
case .denseColumn:
481480
return try unpackDenseColumn()
@@ -490,7 +489,7 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
490489
/// - Returns: The stored data values in row-major format.
491490
/// - Throws: Error upon failure to unpack the matrix.
492491
@inlinable
493-
func unpack() throws -> [Scheme.SignedScalar] where Format == Coeff {
492+
package func unpack() throws -> [Scheme.SignedScalar] where Format == Coeff {
494493
let unsigned: [Scheme.Scalar] = try unpack()
495494
return unsigned.map { unsigned in
496495
unsigned.remainderToCentered(modulus: context.plaintextModulus)
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// Copyright 2025 Apple Inc. and the Swift Homomorphic Encryption project authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import HomomorphicEncryption
16+
import PrivateNearestNeighborSearch
17+
import Testing
18+
19+
@inlinable
20+
func increasingData<T: ScalarType>(dimensions: MatrixDimensions, modulus: T) -> [[T]] {
21+
(0..<dimensions.rowCount).map { rowIndex in
22+
(0..<dimensions.columnCount).map { columnIndex in
23+
let value = 1 + T(rowIndex * dimensions.columnCount + columnIndex)
24+
return value % modulus
25+
}
26+
}
27+
}
28+
29+
@inlinable
30+
func randomData<T: ScalarType>(dimensions: MatrixDimensions, modulus: T) -> [[T]] {
31+
(0..<dimensions.rowCount).map { _ in
32+
(0..<dimensions.columnCount).map { _ in T.random(in: 0..<modulus) }
33+
}
34+
}
35+
36+
extension PrivateNearestNeighborSearchUtil {
37+
/// Tests for `CiphertextMatrix`.
38+
public enum CiphertextMatrixTests {
39+
/// Testing encryption/decryption round-trip.
40+
@inlinable
41+
public static func encryptDecryptRoundTrip<Scheme: HeScheme>(for _: Scheme.Type) throws {
42+
let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5
43+
let encryptionParameters = try EncryptionParameters<Scheme>(from: rlweParams)
44+
#expect(encryptionParameters.supportsSimdEncoding)
45+
let context = try Context<Scheme>(encryptionParameters: encryptionParameters)
46+
let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4)
47+
let encodeValues: [[Scheme.Scalar]] = increasingData(
48+
dimensions: dimensions,
49+
modulus: context.plaintextModulus)
50+
let plaintextMatrix = try PlaintextMatrix<Scheme, Coeff>(
51+
context: context,
52+
dimensions: dimensions,
53+
packing: .denseRow,
54+
values: encodeValues.flatMap { $0 })
55+
let secretKey = try context.generateSecretKey()
56+
var ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey)
57+
let plaintextMatrixRoundTrip = try ciphertextMatrix.decrypt(using: secretKey)
58+
#expect(plaintextMatrixRoundTrip == plaintextMatrix)
59+
60+
// modSwitchDownToSingle
61+
do {
62+
try ciphertextMatrix.modSwitchDownToSingle()
63+
let plaintextMatrixRoundTrip = try ciphertextMatrix.decrypt(using: secretKey)
64+
#expect(plaintextMatrixRoundTrip == plaintextMatrix)
65+
}
66+
}
67+
68+
/// Testing convert to Coeff/Eval format roundtrip.
69+
@inlinable
70+
public static func convertFormatRoundTrip<Scheme: HeScheme>(for _: Scheme.Type) throws {
71+
let rlweParams = PredefinedRlweParameters.insecure_n_8_logq_5x18_logt_5
72+
let encryptionParameters = try EncryptionParameters<Scheme>(from: rlweParams)
73+
#expect(encryptionParameters.supportsSimdEncoding)
74+
let context = try Context<Scheme>(encryptionParameters: encryptionParameters)
75+
let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4)
76+
let encodeValues: [[Scheme.Scalar]] = increasingData(
77+
dimensions: dimensions,
78+
modulus: context.plaintextModulus)
79+
let plaintextMatrix = try PlaintextMatrix<Scheme, Coeff>(
80+
context: context,
81+
dimensions: dimensions,
82+
packing: .denseRow,
83+
values: encodeValues.flatMap { $0 })
84+
let secretKey = try context.generateSecretKey()
85+
let ciphertextCoeffMatrix: CiphertextMatrix = try plaintextMatrix.encrypt(using: secretKey)
86+
let ciphertextEvalMatrix = try ciphertextCoeffMatrix.convertToEvalFormat()
87+
let ciphertextMatrixRoundTrip = try ciphertextEvalMatrix.convertToCoeffFormat()
88+
let decoded = try ciphertextMatrixRoundTrip.decrypt(using: secretKey)
89+
#expect(plaintextMatrix == decoded)
90+
}
91+
92+
/// Testing `extractDenseRow`.
93+
@inlinable
94+
public static func extractDenseRow<Scheme: HeScheme>(for _: Scheme.Type) throws {
95+
let degree = 16
96+
let plaintextModulus = try Scheme.Scalar.generatePrimes(
97+
significantBitCounts: [9],
98+
preferringSmall: true,
99+
nttDegree: degree)[0]
100+
let coefficientModuli = try Scheme.Scalar.generatePrimes(
101+
significantBitCounts: Array(repeating: Scheme.Scalar.bitWidth - 4, count: 2),
102+
preferringSmall: false,
103+
nttDegree: degree)
104+
let encryptionParameters = try EncryptionParameters<Scheme>(
105+
polyDegree: degree,
106+
plaintextModulus: plaintextModulus,
107+
coefficientModuli: coefficientModuli,
108+
errorStdDev: .stdDev32,
109+
securityLevel: .unchecked)
110+
#expect(encryptionParameters.supportsSimdEncoding)
111+
let context = try Context<Scheme>(encryptionParameters: encryptionParameters)
112+
113+
for rowCount in 1..<(2 * degree) {
114+
for columnCount in 1..<degree / 2 {
115+
let dimensions = try MatrixDimensions((rowCount, columnCount))
116+
let encodeValues: [[Scheme.Scalar]] = increasingData(
117+
dimensions: dimensions,
118+
modulus: plaintextModulus)
119+
120+
let plaintextMatrix = try PlaintextMatrix<Scheme, Coeff>(
121+
context: context,
122+
dimensions: dimensions,
123+
packing: .denseRow,
124+
values: encodeValues.flatMap { $0 })
125+
let secretKey = try context.generateSecretKey()
126+
let ciphertextMatrix: CiphertextMatrix = try plaintextMatrix.encrypt(using: secretKey)
127+
128+
let evaluationKeyConfig = try CiphertextMatrix<Scheme, Coeff>.extractDenseRowConfig(
129+
for: encryptionParameters,
130+
dimensions: dimensions)
131+
let evaluationKey = try context.generateEvaluationKey(
132+
config: evaluationKeyConfig,
133+
using: secretKey)
134+
135+
for rowIndex in 0..<rowCount {
136+
let extractedRow = try ciphertextMatrix.extractDenseRow(
137+
rowIndex: rowIndex,
138+
evaluationKey: evaluationKey)
139+
140+
let expectedDimensions = try MatrixDimensions(
141+
rowCount: 1,
142+
columnCount: columnCount)
143+
#expect(extractedRow.dimensions == expectedDimensions)
144+
145+
// Check unpacking
146+
let decrypted = try extractedRow.decrypt(using: secretKey)
147+
let unpacked: [Scheme.Scalar] = try decrypted.unpack()
148+
#expect(unpacked == encodeValues[rowIndex])
149+
150+
// Check encoded values
151+
var row = encodeValues[rowIndex]
152+
row += Array(repeating: 0, count: row.count.nextPowerOfTwo - row.count)
153+
let expectedRow = Array(repeating: row, count: degree / row.count).flatMap { $0 }
154+
let decoded: [Scheme.Scalar] = try decrypted.plaintexts[0].decode(format: .simd)
155+
#expect(decoded == expectedRow)
156+
}
157+
}
158+
}
159+
}
160+
}
161+
}

0 commit comments

Comments
 (0)