Skip to content

Commit 1e6bfbf

Browse files
committed
Added convenience overloads for Connection.createAggregation.
1 parent 2b82d15 commit 1e6bfbf

File tree

3 files changed

+141
-11
lines changed

3 files changed

+141
-11
lines changed

Sources/SQLite/Core/Connection.swift

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,15 @@ public final class Connection {
626626
/// - state: A block of code to run to produce a fresh state variable for
627627
/// each aggregation group. The block should return an
628628
/// UnsafeMutablePointer to the fresh state variable.
629-
public func createAggregation<T>(_ aggregate: String, argumentCount: UInt? = nil, deterministic: Bool = false, step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> (), final: @escaping (UnsafeMutablePointer<T>) -> Binding?, state: @escaping () -> UnsafeMutablePointer<T>) {
629+
public func createAggregation<T>(
630+
_ aggregate: String,
631+
argumentCount: UInt? = nil,
632+
deterministic: Bool = false,
633+
step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> (),
634+
final: @escaping (UnsafeMutablePointer<T>) -> Binding?,
635+
state: @escaping () -> UnsafeMutablePointer<T>) {
636+
637+
630638
let argc = argumentCount.map { Int($0) } ?? -1
631639
let box : Aggregate = { (stepFlag: Int, context: OpaquePointer?, argc: Int32, argv: UnsafeMutablePointer<OpaquePointer?>?) in
632640
let ptr = sqlite3_aggregate_context(context, 64)! // needs to be at least as large as uintptr_t; better way to do this?
@@ -690,17 +698,17 @@ public final class Connection {
690698
{ context, argc, value in
691699
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
692700
function(1, context, argc, value)
693-
},
701+
},
694702
{ context in
695703
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
696704
function(0, context, 0, nil)
697-
},
705+
},
698706
nil
699707
)
700708
if aggregations[aggregate] == nil { self.aggregations[aggregate] = [:] }
701709
aggregations[aggregate]?[argc] = box
702710
}
703-
711+
704712
fileprivate typealias Aggregate = @convention(block) (Int, OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
705713
fileprivate typealias Function = @convention(block) (OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
706714
fileprivate var functions = [String: [Int: Function]]()

Sources/SQLite/Typed/CustomFunctions.swift

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,69 @@ public extension Connection {
133133
}
134134
}
135135

136+
// MARK: -
137+
138+
public func createAggregation<T: AnyObject>(
139+
_ aggregate: String,
140+
argumentCount: UInt? = nil,
141+
deterministic: Bool = false,
142+
initialValue: T,
143+
reduce: @escaping (T, [Binding?]) -> T,
144+
result: @escaping (T) -> Binding?
145+
) {
146+
147+
let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> () = { (bindings, ptr) in
148+
let p = ptr.pointee.assumingMemoryBound(to: T.self)
149+
let current = Unmanaged<T>.fromOpaque(p).takeRetainedValue()
150+
let next = reduce(current, bindings)
151+
ptr.pointee = Unmanaged.passRetained(next).toOpaque()
152+
}
153+
154+
let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { (ptr) in
155+
let p = ptr.pointee.assumingMemoryBound(to: T.self)
156+
let obj = Unmanaged<T>.fromOpaque(p).takeRetainedValue()
157+
let value = result(obj)
158+
ptr.deallocate()
159+
return value
160+
}
161+
162+
let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
163+
let p = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
164+
p.pointee = Unmanaged.passRetained(initialValue).toOpaque()
165+
return p
166+
}
167+
168+
createAggregation(aggregate, step: step, final: final, state: state)
169+
}
170+
171+
public func createAggregation<T>(
172+
_ aggregate: String,
173+
argumentCount: UInt? = nil,
174+
deterministic: Bool = false,
175+
initialValue: T,
176+
reduce: @escaping (T, [Binding?]) -> T,
177+
result: @escaping (T) -> Binding?
178+
) {
179+
180+
let step: ([Binding?], UnsafeMutablePointer<T>) -> () = { (bindings, p) in
181+
let current = p.pointee
182+
let next = reduce(current, bindings)
183+
p.pointee = next
184+
}
185+
186+
let final: (UnsafeMutablePointer<T>) -> Binding? = { (p) in
187+
let v = result(p.pointee)
188+
p.deallocate()
189+
return v
190+
}
191+
192+
let state: () -> UnsafeMutablePointer<T> = {
193+
let p = UnsafeMutablePointer<T>.allocate(capacity: 1)
194+
p.pointee = initialValue
195+
return p
196+
}
197+
198+
createAggregation(aggregate, step: step, final: final, state: state)
199+
}
200+
136201
}

Tests/SQLiteTests/CustomAggregationTests.swift

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,33 @@ class CustomAggregationTests : SQLiteTestCase {
2222
try! InsertUser("Eve", age: 28, admin: false)
2323
}
2424

25-
func testCustomSum() {
25+
func testUnsafeCustomSum() {
2626
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
2727
if let v = bindings[0] as? Int64 {
2828
state.pointee += v
2929
}
3030
}
31-
31+
3232
let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
3333
let v = state.pointee
3434
let p = UnsafeMutableBufferPointer(start: state, count: 1)
3535
p.deallocate()
3636
return v
3737
}
38-
let _ = db.createAggregation("mySUM", step: step, final: final) {
38+
let _ = db.createAggregation("mySUM1", step: step, final: final) {
3939
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
4040
v[0] = 0
4141
return v.baseAddress!
4242
}
43-
let result = try! db.prepare("SELECT mySUM(age) AS s FROM users")
43+
let result = try! db.prepare("SELECT mySUM1(age) AS s FROM users")
4444
let i = result.columnNames.index(of: "s")!
4545
for row in result {
4646
let value = row[i] as? Int64
4747
XCTAssertEqual(83, value)
4848
}
4949
}
5050

51-
func testCustomSumGrouping() {
51+
func testUnsafeCustomSumGrouping() {
5252
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
5353
if let v = bindings[0] as? Int64 {
5454
state.pointee += v
@@ -60,14 +60,71 @@ class CustomAggregationTests : SQLiteTestCase {
6060
p.deallocate()
6161
return v
6262
}
63-
let _ = db.createAggregation("mySUM", step: step, final: final) {
63+
let _ = db.createAggregation("mySUM2", step: step, final: final) {
6464
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
6565
v[0] = 0
6666
return v.baseAddress!
6767
}
68-
let result = try! db.prepare("SELECT mySUM(age) AS s FROM users GROUP BY admin ORDER BY s")
68+
let result = try! db.prepare("SELECT mySUM2(age) AS s FROM users GROUP BY admin ORDER BY s")
6969
let i = result.columnNames.index(of: "s")!
7070
let values = result.compactMap { $0[i] as? Int64 }
7171
XCTAssertTrue(values.elementsEqual([28, 55]))
7272
}
73+
74+
func testCustomSum() {
75+
let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in
76+
let v = (bindings[0] as? Int64) ?? 0
77+
return last + v
78+
}
79+
let _ = db.createAggregation("myReduceSUM1", initialValue: Int64(2000), reduce: reduce, result: { $0 })
80+
let result = try! db.prepare("SELECT myReduceSUM1(age) AS s FROM users")
81+
let i = result.columnNames.index(of: "s")!
82+
for row in result {
83+
let value = row[i] as? Int64
84+
XCTAssertEqual(2083, value)
85+
}
86+
}
87+
88+
func testCustomSumGrouping() {
89+
let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in
90+
let v = (bindings[0] as? Int64) ?? 0
91+
return last + v
92+
}
93+
let _ = db.createAggregation("myReduceSUM2", initialValue: Int64(3000), reduce: reduce, result: { $0 })
94+
let result = try! db.prepare("SELECT myReduceSUM2(age) AS s FROM users GROUP BY admin ORDER BY s")
95+
let i = result.columnNames.index(of: "s")!
96+
let values = result.compactMap { $0[i] as? Int64 }
97+
XCTAssertTrue(values.elementsEqual([3028, 3055]))
98+
}
99+
100+
func testCustomObjectSum() {
101+
{
102+
let initial = TestObject(value: 1000)
103+
let reduce : (TestObject, [Binding?]) -> TestObject = { (last, bindings) in
104+
let v = (bindings[0] as? Int64) ?? 0
105+
return TestObject(value: last.value + v)
106+
}
107+
let _ = db.createAggregation("myReduceSUMX", initialValue: initial, reduce: reduce, result: { $0.value })
108+
// end this scope to ensure that the initial value is retained
109+
// by the createAggregation call.
110+
}()
111+
let result = try! db.prepare("SELECT myReduceSUMX(age) AS s FROM users")
112+
let i = result.columnNames.index(of: "s")!
113+
for row in result {
114+
let value = row[i] as? Int64
115+
XCTAssertEqual(1083, value)
116+
}
117+
}
118+
}
119+
120+
/// This class is used to test that aggregation state variables
121+
/// can be reference types and are properly memory managed when
122+
/// crossing the Swift<->C boundary multiple times.
123+
class TestObject {
124+
var value: Int64
125+
init(value: Int64) {
126+
self.value = value
127+
}
128+
deinit {
129+
}
73130
}

0 commit comments

Comments
 (0)