Skip to content

Commit 3a14459

Browse files
committed
Add support for custom aggregate functions
1 parent abf9bc3 commit 3a14459

File tree

1 file changed

+132
-1
lines changed

1 file changed

+132
-1
lines changed

hdr/sqlite_modern_cpp.h

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,47 @@ namespace sqlite {
317317
};
318318

319319
namespace sql_function_binder {
320+
template<
321+
typename ContextType,
322+
std::size_t Count,
323+
typename Functions
324+
>
325+
inline void step(
326+
sqlite3_context* db,
327+
int count,
328+
sqlite3_value** vals
329+
);
330+
331+
template<
332+
std::size_t Count,
333+
typename Functions,
334+
typename... Values
335+
>
336+
inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step(
337+
sqlite3_context* db,
338+
int count,
339+
sqlite3_value** vals,
340+
Values&&... values
341+
);
342+
343+
template<
344+
std::size_t Count,
345+
typename Functions,
346+
typename... Values
347+
>
348+
inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step(
349+
sqlite3_context* db,
350+
int,
351+
sqlite3_value**,
352+
Values&&... values
353+
);
354+
355+
template<
356+
typename ContextType,
357+
typename Functions
358+
>
359+
inline void final(sqlite3_context* db);
360+
320361
template<
321362
std::size_t Count,
322363
typename Function,
@@ -401,6 +442,22 @@ namespace sqlite {
401442
});
402443
}
403444

445+
template <typename StepFunction, typename FinalFunction>
446+
void define(const std::string &name, StepFunction&& step, FinalFunction&& final) {
447+
typedef utility::function_traits<StepFunction> traits;
448+
using ContextType = typename std::remove_reference<typename traits::template argument<0>>::type;
449+
450+
auto funcPtr = new auto(std::make_pair(std::forward<StepFunction>(step), std::forward<FinalFunction>(final)));
451+
if(int result = sqlite3_create_function_v2(
452+
_db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr,
453+
sql_function_binder::step<ContextType, traits::arity, typename std::remove_reference<decltype(*funcPtr)>::type>,
454+
sql_function_binder::final<ContextType, typename std::remove_reference<decltype(*funcPtr)>::type>,
455+
[](void* ptr){
456+
delete static_cast<decltype(funcPtr)>(ptr);
457+
}))
458+
exceptions::throw_sqlite_error(result);
459+
}
460+
404461
};
405462

406463
template<std::size_t Count>
@@ -792,6 +849,80 @@ namespace sqlite {
792849
template<typename T> database_binder& operator << (database_binder&& db, const T& val) { return db << val; }
793850

794851
namespace sql_function_binder {
852+
template<class T>
853+
struct AggregateCtxt {
854+
T obj;
855+
bool constructed = true;
856+
};
857+
858+
template<
859+
typename ContextType,
860+
std::size_t Count,
861+
typename Functions
862+
>
863+
inline void step(
864+
sqlite3_context* db,
865+
int count,
866+
sqlite3_value** vals
867+
) {
868+
auto ctxt = static_cast<AggregateCtxt<ContextType>*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt<ContextType>)));
869+
if(!ctxt) return;
870+
if(!ctxt->constructed) new(ctxt) AggregateCtxt<ContextType>();
871+
step<Count, Functions>(db, count, vals, ctxt->obj);
872+
}
873+
874+
template<
875+
std::size_t Count,
876+
typename Functions,
877+
typename... Values
878+
>
879+
inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step(
880+
sqlite3_context* db,
881+
int count,
882+
sqlite3_value** vals,
883+
Values&&... values
884+
) {
885+
typename utility::function_traits<typename Functions::first_type>::template argument<sizeof...(Values)> value{};
886+
get_val_from_db(vals[sizeof...(Values) - 1], value);
887+
888+
step<Count, Functions>(db, count, vals, std::forward<Values>(values)..., std::move(value));
889+
}
890+
891+
template<
892+
std::size_t Count,
893+
typename Functions,
894+
typename... Values
895+
>
896+
inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step(
897+
sqlite3_context* db,
898+
int,
899+
sqlite3_value**,
900+
Values&&... values
901+
) {
902+
static_cast<Functions*>(sqlite3_user_data(db))->first(std::forward<Values>(values)...);
903+
};
904+
905+
template<
906+
typename ContextType,
907+
typename Functions
908+
>
909+
inline void final(sqlite3_context* db) {
910+
try {
911+
auto ctxt = static_cast<AggregateCtxt<ContextType>*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt<ContextType>)));
912+
if(!ctxt) return;
913+
if(!ctxt->constructed) new(ctxt) AggregateCtxt<ContextType>();
914+
store_result_in_db(db,
915+
static_cast<Functions*>(sqlite3_user_data(db))->second(ctxt->obj));
916+
} catch(sqlite_exception &e) {
917+
sqlite3_result_error_code(db, e.get_code());
918+
sqlite3_result_error(db, e.what(), -1);
919+
} catch(std::exception &e) {
920+
sqlite3_result_error(db, e.what(), -1);
921+
} catch(...) {
922+
sqlite3_result_error(db, "Unknown error", -1);
923+
}
924+
}
925+
795926
template<
796927
std::size_t Count,
797928
typename Function,
@@ -822,7 +953,7 @@ namespace sqlite {
822953
) {
823954
try {
824955
store_result_in_db(db,
825-
(*static_cast<Function*>(sqlite3_user_data(db)))(std::move(values)...));
956+
(*static_cast<Function*>(sqlite3_user_data(db)))(std::forward<Values>(values)...));
826957
} catch(sqlite_exception &e) {
827958
sqlite3_result_error_code(db, e.get_code());
828959
sqlite3_result_error(db, e.what(), -1);

0 commit comments

Comments
 (0)