Skip to content

Commit 2b82d15

Browse files
committed
Add prototype implementation for Connection.createAggregation.
1 parent ed8f603 commit 2b82d15

File tree

3 files changed

+187
-0
lines changed

3 files changed

+187
-0
lines changed

SQLite.xcodeproj/project.pbxproj

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@
8181
19A17FB80B94E882050AA908 /* FoundationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1794CC4D7827E997E32A7 /* FoundationTests.swift */; };
8282
19A17FDA323BAFDEC627E76F /* fixtures in Resources */ = {isa = PBXBuildFile; fileRef = 19A17E2695737FAB5D6086E3 /* fixtures */; };
8383
19A17FF4A10B44D3937C8CAC /* Errors.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1710E73A46D5AC721CDA9 /* Errors.swift */; };
84+
3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
85+
3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
86+
3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
8487
3D67B3E61DB2469200A4F4C6 /* libsqlite3.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */; };
8588
3D67B3E71DB246BA00A4F4C6 /* Blob.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEE1C3F06E900AE3E12 /* Blob.swift */; };
8689
3D67B3E81DB246BA00A4F4C6 /* Connection.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEF1C3F06E900AE3E12 /* Connection.swift */; };
@@ -225,6 +228,7 @@
225228
19A17B93B48B5560E6E51791 /* Fixtures.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Fixtures.swift; sourceTree = "<group>"; };
226229
19A17BA55DABB480F9020C8A /* DateAndTimeFunctions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DateAndTimeFunctions.swift; sourceTree = "<group>"; };
227230
19A17E2695737FAB5D6086E3 /* fixtures */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = folder; path = fixtures; sourceTree = "<group>"; };
231+
3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomAggregationTests.swift; sourceTree = "<group>"; };
228232
3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = Platforms/WatchOS.platform/Developer/SDKs/WatchOS3.0.sdk/usr/lib/libsqlite3.tbd; sourceTree = DEVELOPER_DIR; };
229233
49EB68C31F7B3CB400D89D40 /* Coding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Coding.swift; sourceTree = "<group>"; };
230234
A121AC451CA35C79005A31D1 /* SQLite.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = SQLite.framework; sourceTree = BUILT_PRODUCTS_DIR; };
@@ -401,6 +405,7 @@
401405
EE247B1D1C3F137700AE3E12 /* ConnectionTests.swift */,
402406
EE247B1E1C3F137700AE3E12 /* CoreFunctionsTests.swift */,
403407
EE247B1F1C3F137700AE3E12 /* CustomFunctionsTests.swift */,
408+
3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */,
404409
EE247B201C3F137700AE3E12 /* ExpressionTests.swift */,
405410
EE247B211C3F137700AE3E12 /* FTS4Tests.swift */,
406411
EE247B2A1C3F141E00AE3E12 /* OperatorsTests.swift */,
@@ -834,6 +839,7 @@
834839
03A65E921C6BB3030062603F /* SetterTests.swift in Sources */,
835840
03A65E891C6BB3030062603F /* ConnectionTests.swift in Sources */,
836841
03A65E8A1C6BB3030062603F /* CoreFunctionsTests.swift in Sources */,
842+
3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */,
837843
03A65E931C6BB3030062603F /* StatementTests.swift in Sources */,
838844
03A65E911C6BB3030062603F /* SchemaTests.swift in Sources */,
839845
03A65E8D1C6BB3030062603F /* FTS4Tests.swift in Sources */,
@@ -922,6 +928,7 @@
922928
EE247B271C3F137700AE3E12 /* CustomFunctionsTests.swift in Sources */,
923929
EE247B341C3F142E00AE3E12 /* StatementTests.swift in Sources */,
924930
EE247B301C3F141E00AE3E12 /* RTreeTests.swift in Sources */,
931+
3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */,
925932
EE247B231C3F137700AE3E12 /* BlobTests.swift in Sources */,
926933
EE247B351C3F142E00AE3E12 /* ValueTests.swift in Sources */,
927934
EE247B2F1C3F141E00AE3E12 /* QueryTests.swift in Sources */,
@@ -980,6 +987,7 @@
980987
EE247B5F1C3F3FC700AE3E12 /* StatementTests.swift in Sources */,
981988
EE247B5C1C3F3FC700AE3E12 /* RTreeTests.swift in Sources */,
982989
EE247B571C3F3FC700AE3E12 /* CustomFunctionsTests.swift in Sources */,
990+
3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */,
983991
EE247B601C3F3FC700AE3E12 /* ValueTests.swift in Sources */,
984992
EE247B551C3F3FC700AE3E12 /* ConnectionTests.swift in Sources */,
985993
EE247B611C3F3FC700AE3E12 /* TestHelpers.swift in Sources */,

Sources/SQLite/Core/Connection.swift

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,114 @@ public final class Connection {
597597
if functions[function] == nil { self.functions[function] = [:] }
598598
functions[function]?[argc] = box
599599
}
600+
601+
/// Creates or redefines a custom SQL aggregate.
602+
///
603+
/// - Parameters:
604+
///
605+
/// - aggregate: The name of the aggregate to create or redefine.
606+
///
607+
/// - argumentCount: The number of arguments that the aggregate takes. If
608+
/// `nil`, the aggregate may take any number of arguments.
609+
///
610+
/// Default: `nil`
611+
///
612+
/// - deterministic: Whether or not the aggregate is deterministic (_i.e._
613+
/// the aggregate always returns the same result for a given input).
614+
///
615+
/// Default: `false`
616+
///
617+
/// - step: A block of code to run for each row of an aggregation group.
618+
/// The block is called with an array of raw SQL values mapped to the
619+
/// aggregate’s parameters, and an UnsafeMutablePointer to a state
620+
/// variable.
621+
///
622+
/// - final: A block of code to run after each row of an aggregation group
623+
/// is processed. The block is called with an UnsafeMutablePointer to a
624+
/// state variable, and should return a raw SQL value (or nil).
625+
///
626+
/// - state: A block of code to run to produce a fresh state variable for
627+
/// each aggregation group. The block should return an
628+
/// UnsafeMutablePointer to the fresh state variable.
629+
public func createAggregation<T>(_ aggregate: String, argumentCount: UInt? = nil, deterministic: Bool = false, step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> (), final: @escaping (UnsafeMutablePointer<T>) -> Binding?, state: @escaping () -> UnsafeMutablePointer<T>) {
630+
let argc = argumentCount.map { Int($0) } ?? -1
631+
let box : Aggregate = { (stepFlag: Int, context: OpaquePointer?, argc: Int32, argv: UnsafeMutablePointer<OpaquePointer?>?) in
632+
let ptr = sqlite3_aggregate_context(context, 64)! // needs to be at least as large as uintptr_t; better way to do this?
633+
let p = ptr.assumingMemoryBound(to: UnsafeMutableRawPointer.self)
634+
if stepFlag > 0 {
635+
let arguments: [Binding?] = (0..<Int(argc)).map { idx in
636+
let value = argv![idx]
637+
switch sqlite3_value_type(value) {
638+
case SQLITE_BLOB:
639+
return Blob(bytes: sqlite3_value_blob(value), length: Int(sqlite3_value_bytes(value)))
640+
case SQLITE_FLOAT:
641+
return sqlite3_value_double(value)
642+
case SQLITE_INTEGER:
643+
return sqlite3_value_int64(value)
644+
case SQLITE_NULL:
645+
return nil
646+
case SQLITE_TEXT:
647+
return String(cString: UnsafePointer(sqlite3_value_text(value)))
648+
case let type:
649+
fatalError("unsupported value type: \(type)")
650+
}
651+
}
652+
653+
if ptr.assumingMemoryBound(to: Int64.self).pointee == 0 {
654+
let v = state()
655+
p.pointee = UnsafeMutableRawPointer(mutating: v)
656+
}
657+
step(arguments, p.pointee.assumingMemoryBound(to: T.self))
658+
} else {
659+
let result = final(p.pointee.assumingMemoryBound(to: T.self))
660+
if let result = result as? Blob {
661+
sqlite3_result_blob(context, result.bytes, Int32(result.bytes.count), nil)
662+
} else if let result = result as? Double {
663+
sqlite3_result_double(context, result)
664+
} else if let result = result as? Int64 {
665+
sqlite3_result_int64(context, result)
666+
} else if let result = result as? String {
667+
sqlite3_result_text(context, result, Int32(result.count), SQLITE_TRANSIENT)
668+
} else if result == nil {
669+
sqlite3_result_null(context)
670+
} else {
671+
fatalError("unsupported result type: \(String(describing: result))")
672+
}
673+
}
674+
}
675+
676+
var flags = SQLITE_UTF8
677+
#if !os(Linux)
678+
if deterministic {
679+
flags |= SQLITE_DETERMINISTIC
680+
}
681+
#endif
682+
683+
sqlite3_create_function_v2(
684+
handle,
685+
aggregate,
686+
Int32(argc),
687+
flags,
688+
unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
689+
nil,
690+
{ context, argc, value in
691+
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
692+
function(1, context, argc, value)
693+
},
694+
{ context in
695+
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
696+
function(0, context, 0, nil)
697+
},
698+
nil
699+
)
700+
if aggregations[aggregate] == nil { self.aggregations[aggregate] = [:] }
701+
aggregations[aggregate]?[argc] = box
702+
}
703+
704+
fileprivate typealias Aggregate = @convention(block) (Int, OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
600705
fileprivate typealias Function = @convention(block) (OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
601706
fileprivate var functions = [String: [Int: Function]]()
707+
fileprivate var aggregations = [String: [Int: Aggregate]]()
602708

603709
/// Defines a new collating sequence.
604710
///
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import XCTest
2+
import Foundation
3+
import Dispatch
4+
@testable import SQLite
5+
6+
#if SQLITE_SWIFT_STANDALONE
7+
import sqlite3
8+
#elseif SQLITE_SWIFT_SQLCIPHER
9+
import SQLCipher
10+
#elseif os(Linux)
11+
import CSQLite
12+
#else
13+
import SQLite3
14+
#endif
15+
16+
class CustomAggregationTests : SQLiteTestCase {
17+
override func setUp() {
18+
super.setUp()
19+
CreateUsersTable()
20+
try! InsertUser("Alice", age: 30, admin: true)
21+
try! InsertUser("Bob", age: 25, admin: true)
22+
try! InsertUser("Eve", age: 28, admin: false)
23+
}
24+
25+
func testCustomSum() {
26+
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
27+
if let v = bindings[0] as? Int64 {
28+
state.pointee += v
29+
}
30+
}
31+
32+
let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
33+
let v = state.pointee
34+
let p = UnsafeMutableBufferPointer(start: state, count: 1)
35+
p.deallocate()
36+
return v
37+
}
38+
let _ = db.createAggregation("mySUM", step: step, final: final) {
39+
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
40+
v[0] = 0
41+
return v.baseAddress!
42+
}
43+
let result = try! db.prepare("SELECT mySUM(age) AS s FROM users")
44+
let i = result.columnNames.index(of: "s")!
45+
for row in result {
46+
let value = row[i] as? Int64
47+
XCTAssertEqual(83, value)
48+
}
49+
}
50+
51+
func testCustomSumGrouping() {
52+
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
53+
if let v = bindings[0] as? Int64 {
54+
state.pointee += v
55+
}
56+
}
57+
let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
58+
let v = state.pointee
59+
let p = UnsafeMutableBufferPointer(start: state, count: 1)
60+
p.deallocate()
61+
return v
62+
}
63+
let _ = db.createAggregation("mySUM", step: step, final: final) {
64+
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
65+
v[0] = 0
66+
return v.baseAddress!
67+
}
68+
let result = try! db.prepare("SELECT mySUM(age) AS s FROM users GROUP BY admin ORDER BY s")
69+
let i = result.columnNames.index(of: "s")!
70+
let values = result.compactMap { $0[i] as? Int64 }
71+
XCTAssertTrue(values.elementsEqual([28, 55]))
72+
}
73+
}

0 commit comments

Comments
 (0)