@@ -597,8 +597,114 @@ public final class Connection {
597597 if functions [ function] == nil { self . functions [ function] = [ : ] }
598598 functions [ function] ? [ argc] = box
599599 }
600+
601+ /// Creates or redefines a custom SQL aggregate.
602+ ///
603+ /// - Parameters:
604+ ///
605+ /// - aggregate: The name of the aggregate to create or redefine.
606+ ///
607+ /// - argumentCount: The number of arguments that the aggregate takes. If
608+ /// `nil`, the aggregate may take any number of arguments.
609+ ///
610+ /// Default: `nil`
611+ ///
612+ /// - deterministic: Whether or not the aggregate is deterministic (_i.e._
613+ /// the aggregate always returns the same result for a given input).
614+ ///
615+ /// Default: `false`
616+ ///
617+ /// - step: A block of code to run for each row of an aggregation group.
618+ /// The block is called with an array of raw SQL values mapped to the
619+ /// aggregate’s parameters, and an UnsafeMutablePointer to a state
620+ /// variable.
621+ ///
622+ /// - final: A block of code to run after each row of an aggregation group
623+ /// is processed. The block is called with an UnsafeMutablePointer to a
624+ /// state variable, and should return a raw SQL value (or nil).
625+ ///
626+ /// - state: A block of code to run to produce a fresh state variable for
627+ /// each aggregation group. The block should return an
628+ /// UnsafeMutablePointer to the fresh state variable.
629+ public func createAggregation< T> ( _ aggregate: String , argumentCount: UInt ? = nil , deterministic: Bool = false , step: @escaping ( [ Binding ? ] , UnsafeMutablePointer < T > ) -> ( ) , final: @escaping ( UnsafeMutablePointer < T > ) -> Binding ? , state: @escaping ( ) -> UnsafeMutablePointer < T > ) {
630+ let argc = argumentCount. map { Int ( $0) } ?? - 1
631+ let box : Aggregate = { ( stepFlag: Int , context: OpaquePointer ? , argc: Int32 , argv: UnsafeMutablePointer < OpaquePointer ? > ? ) in
632+ let ptr = sqlite3_aggregate_context ( context, 64 ) ! // needs to be at least as large as uintptr_t; better way to do this?
633+ let p = ptr. assumingMemoryBound ( to: UnsafeMutableRawPointer . self)
634+ if stepFlag > 0 {
635+ let arguments : [ Binding ? ] = ( 0 ..< Int ( argc) ) . map { idx in
636+ let value = argv![ idx]
637+ switch sqlite3_value_type ( value) {
638+ case SQLITE_BLOB:
639+ return Blob ( bytes: sqlite3_value_blob ( value) , length: Int ( sqlite3_value_bytes ( value) ) )
640+ case SQLITE_FLOAT:
641+ return sqlite3_value_double ( value)
642+ case SQLITE_INTEGER:
643+ return sqlite3_value_int64 ( value)
644+ case SQLITE_NULL:
645+ return nil
646+ case SQLITE_TEXT:
647+ return String ( cString: UnsafePointer ( sqlite3_value_text ( value) ) )
648+ case let type:
649+ fatalError ( " unsupported value type: \( type) " )
650+ }
651+ }
652+
653+ if ptr. assumingMemoryBound ( to: Int64 . self) . pointee == 0 {
654+ let v = state ( )
655+ p. pointee = UnsafeMutableRawPointer ( mutating: v)
656+ }
657+ step ( arguments, p. pointee. assumingMemoryBound ( to: T . self) )
658+ } else {
659+ let result = final ( p. pointee. assumingMemoryBound ( to: T . self) )
660+ if let result = result as? Blob {
661+ sqlite3_result_blob ( context, result. bytes, Int32 ( result. bytes. count) , nil )
662+ } else if let result = result as? Double {
663+ sqlite3_result_double ( context, result)
664+ } else if let result = result as? Int64 {
665+ sqlite3_result_int64 ( context, result)
666+ } else if let result = result as? String {
667+ sqlite3_result_text ( context, result, Int32 ( result. count) , SQLITE_TRANSIENT)
668+ } else if result == nil {
669+ sqlite3_result_null ( context)
670+ } else {
671+ fatalError ( " unsupported result type: \( String ( describing: result) ) " )
672+ }
673+ }
674+ }
675+
676+ var flags = SQLITE_UTF8
677+ #if !os(Linux)
678+ if deterministic {
679+ flags |= SQLITE_DETERMINISTIC
680+ }
681+ #endif
682+
683+ sqlite3_create_function_v2 (
684+ handle,
685+ aggregate,
686+ Int32 ( argc) ,
687+ flags,
688+ unsafeBitCast ( box, to: UnsafeMutableRawPointer . self) ,
689+ nil ,
690+ { context, argc, value in
691+ let function = unsafeBitCast ( sqlite3_user_data ( context) , to: Aggregate . self)
692+ function ( 1 , context, argc, value)
693+ } ,
694+ { context in
695+ let function = unsafeBitCast ( sqlite3_user_data ( context) , to: Aggregate . self)
696+ function ( 0 , context, 0 , nil )
697+ } ,
698+ nil
699+ )
700+ if aggregations [ aggregate] == nil { self . aggregations [ aggregate] = [ : ] }
701+ aggregations [ aggregate] ? [ argc] = box
702+ }
703+
704+ fileprivate typealias Aggregate = @convention ( block) ( Int , OpaquePointer ? , Int32 , UnsafeMutablePointer < OpaquePointer ? > ? ) -> Void
600705 fileprivate typealias Function = @convention ( block) ( OpaquePointer ? , Int32 , UnsafeMutablePointer < OpaquePointer ? > ? ) -> Void
601706 fileprivate var functions = [ String: [ Int: Function] ] ( )
707+ fileprivate var aggregations = [ String: [ Int: Aggregate] ] ( )
602708
603709 /// Defines a new collating sequence.
604710 ///
0 commit comments