Skip to content

Commit abf9bc3

Browse files
committed
Introduce support for user defined functions
1 parent f9cf0fd commit abf9bc3

File tree

1 file changed

+200
-1
lines changed

1 file changed

+200
-1
lines changed

hdr/sqlite_modern_cpp.h

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,32 @@ namespace sqlite {
316316
}
317317
};
318318

319+
namespace sql_function_binder {
320+
template<
321+
std::size_t Count,
322+
typename Function,
323+
typename... Values
324+
>
325+
inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar(
326+
sqlite3_context* db,
327+
int count,
328+
sqlite3_value** vals,
329+
Values&&... values
330+
);
331+
332+
template<
333+
std::size_t Count,
334+
typename Function,
335+
typename... Values
336+
>
337+
inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar(
338+
sqlite3_context* db,
339+
int,
340+
sqlite3_value**,
341+
Values&&... values
342+
);
343+
}
344+
319345
class database {
320346
private:
321347
std::shared_ptr<sqlite3> _db;
@@ -362,6 +388,19 @@ namespace sqlite {
362388
return sqlite3_last_insert_rowid(_db.get());
363389
}
364390

391+
template <typename Function>
392+
void define(const std::string &name, Function&& func) {
393+
typedef utility::function_traits<Function> traits;
394+
395+
auto funcPtr = new auto(std::forward<Function>(func));
396+
sqlite3_create_function_v2(
397+
_db.get(), name.c_str(), traits::arity, SQLITE_UTF8, funcPtr,
398+
sql_function_binder::scalar<traits::arity, typename std::remove_reference<Function>::type>,
399+
nullptr, nullptr, [](void* ptr){
400+
delete static_cast<decltype(funcPtr)>(ptr);
401+
});
402+
}
403+
365404
};
366405

367406
template<std::size_t Count>
@@ -420,6 +459,9 @@ namespace sqlite {
420459
}
421460
++db._inx;
422461
return db;
462+
}
463+
inline void store_result_in_db(sqlite3_context* db, const int& val) {
464+
sqlite3_result_int(db, val);
423465
}
424466
inline void get_col_from_db(database_binder& db, int inx, int& val) {
425467
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
@@ -428,6 +470,13 @@ namespace sqlite {
428470
val = sqlite3_column_int(db._stmt.get(), inx);
429471
}
430472
}
473+
inline void get_val_from_db(sqlite3_value *value, int& val) {
474+
if(sqlite3_value_type(value) == SQLITE_NULL) {
475+
val = 0;
476+
} else {
477+
val = sqlite3_value_int(value);
478+
}
479+
}
431480

432481
// sqlite_int64
433482
inline database_binder& operator <<(database_binder& db, const sqlite_int64& val) {
@@ -438,6 +487,9 @@ namespace sqlite {
438487

439488
++db._inx;
440489
return db;
490+
}
491+
inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) {
492+
sqlite3_result_int64(db, val);
441493
}
442494
inline void get_col_from_db(database_binder& db, int inx, sqlite3_int64& i) {
443495
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
@@ -446,6 +498,13 @@ namespace sqlite {
446498
i = sqlite3_column_int64(db._stmt.get(), inx);
447499
}
448500
}
501+
inline void get_val_from_db(sqlite3_value *value, sqlite3_int64& i) {
502+
if(sqlite3_value_type(value) == SQLITE_NULL) {
503+
i = 0;
504+
} else {
505+
i = sqlite3_value_int64(value);
506+
}
507+
}
449508

450509
// float
451510
inline database_binder& operator <<(database_binder& db, const float& val) {
@@ -456,6 +515,9 @@ namespace sqlite {
456515

457516
++db._inx;
458517
return db;
518+
}
519+
inline void store_result_in_db(sqlite3_context* db, const float& val) {
520+
sqlite3_result_double(db, val);
459521
}
460522
inline void get_col_from_db(database_binder& db, int inx, float& f) {
461523
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
@@ -464,6 +526,13 @@ namespace sqlite {
464526
f = float(sqlite3_column_double(db._stmt.get(), inx));
465527
}
466528
}
529+
inline void get_val_from_db(sqlite3_value *value, float& f) {
530+
if(sqlite3_value_type(value) == SQLITE_NULL) {
531+
f = 0;
532+
} else {
533+
f = float(sqlite3_value_double(value));
534+
}
535+
}
467536

468537
// double
469538
inline database_binder& operator <<(database_binder& db, const double& val) {
@@ -474,6 +543,9 @@ namespace sqlite {
474543

475544
++db._inx;
476545
return db;
546+
}
547+
inline void store_result_in_db(sqlite3_context* db, const double& val) {
548+
sqlite3_result_double(db, val);
477549
}
478550
inline void get_col_from_db(database_binder& db, int inx, double& d) {
479551
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
@@ -482,6 +554,13 @@ namespace sqlite {
482554
d = sqlite3_column_double(db._stmt.get(), inx);
483555
}
484556
}
557+
inline void get_val_from_db(sqlite3_value *value, double& d) {
558+
if(sqlite3_value_type(value) == SQLITE_NULL) {
559+
d = 0;
560+
} else {
561+
d = sqlite3_value_double(value);
562+
}
563+
}
485564

486565
// vector<T, A>
487566
template<typename T, typename A> inline database_binder& operator<<(database_binder& db, const std::vector<T, A>& vec) {
@@ -494,6 +573,11 @@ namespace sqlite {
494573
++db._inx;
495574
return db;
496575
}
576+
template<typename T, typename A> inline void store_result_in_db(sqlite3_context* db, const std::vector<T, A>& vec) {
577+
void const* buf = reinterpret_cast<void const *>(vec.data());
578+
int bytes = vec.size() * sizeof(T);
579+
sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT);
580+
}
497581
template<typename T, typename A> inline void get_col_from_db(database_binder& db, int inx, std::vector<T, A>& vec) {
498582
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
499583
vec.clear();
@@ -503,6 +587,15 @@ namespace sqlite {
503587
vec = std::vector<T, A>(buf, buf + bytes/sizeof(T));
504588
}
505589
}
590+
template<typename T, typename A> inline void get_val_from_db(sqlite3_value *value, std::vector<T, A>& vec) {
591+
if(sqlite3_value_type(value) == SQLITE_NULL) {
592+
vec.clear();
593+
} else {
594+
int bytes = sqlite3_value_bytes(value);
595+
T const* buf = reinterpret_cast<T const *>(sqlite3_value_blob(value));
596+
vec = std::vector<T, A>(buf, buf + bytes/sizeof(T));
597+
}
598+
}
506599

507600
/* for nullptr support */
508601
inline database_binder& operator <<(database_binder& db, std::nullptr_t) {
@@ -512,6 +605,9 @@ namespace sqlite {
512605
}
513606
++db._inx;
514607
return db;
608+
}
609+
inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) {
610+
sqlite3_result_null(db);
515611
}
516612
/* for nullptr support */
517613
template<typename T> inline database_binder& operator <<(database_binder& db, const std::unique_ptr<T>& val) {
@@ -532,6 +628,15 @@ namespace sqlite {
532628
_ptr_.reset(underling_ptr);
533629
}
534630
}
631+
template<typename T> inline void get_val_from_db(sqlite3_value *value, std::unique_ptr<T>& _ptr_) {
632+
if(sqlite3_value_type(value) == SQLITE_NULL) {
633+
_ptr_ = nullptr;
634+
} else {
635+
auto underling_ptr = new T();
636+
get_val_from_db(value, *underling_ptr);
637+
_ptr_.reset(underling_ptr);
638+
}
639+
}
535640

536641
// std::string
537642
inline void get_col_from_db(database_binder& db, int inx, std::string & s) {
@@ -542,6 +647,14 @@ namespace sqlite {
542647
s = std::string(reinterpret_cast<char const *>(sqlite3_column_text(db._stmt.get(), inx)));
543648
}
544649
}
650+
inline void get_val_from_db(sqlite3_value *value, std::string & s) {
651+
if(sqlite3_value_type(value) == SQLITE_NULL) {
652+
s = std::string();
653+
} else {
654+
sqlite3_value_bytes(value);
655+
s = std::string(reinterpret_cast<char const *>(sqlite3_value_text(value)));
656+
}
657+
}
545658

546659
// Convert char* to string to trigger op<<(..., const std::string )
547660
template<std::size_t N> inline database_binder& operator <<(database_binder& db, const char(&STR)[N]) { return db << std::string(STR); }
@@ -555,6 +668,9 @@ namespace sqlite {
555668

556669
++db._inx;
557670
return db;
671+
}
672+
inline void store_result_in_db(sqlite3_context* db, const std::string& val) {
673+
sqlite3_result_text(db, val.data(), -1, SQLITE_TRANSIENT);
558674
}
559675
// std::u16string
560676
inline void get_col_from_db(database_binder& db, int inx, std::u16string & w) {
@@ -565,6 +681,14 @@ namespace sqlite {
565681
w = std::u16string(reinterpret_cast<char16_t const *>(sqlite3_column_text16(db._stmt.get(), inx)));
566682
}
567683
}
684+
inline void get_val_from_db(sqlite3_value *value, std::u16string & w) {
685+
if(sqlite3_value_type(value) == SQLITE_NULL) {
686+
w = std::u16string();
687+
} else {
688+
sqlite3_value_bytes16(value);
689+
w = std::u16string(reinterpret_cast<char16_t const *>(sqlite3_value_text16(value)));
690+
}
691+
}
568692

569693

570694
inline database_binder& operator <<(database_binder& db, const std::u16string& txt) {
@@ -575,6 +699,9 @@ namespace sqlite {
575699

576700
++db._inx;
577701
return db;
702+
}
703+
inline void store_result_in_db(sqlite3_context* db, const std::u16string& val) {
704+
sqlite3_result_text16(db, val.data(), -1, SQLITE_TRANSIENT);
578705
}
579706
// std::optional support for NULL values
580707
#ifdef _MODERN_SQLITE_STD_OPTIONAL_SUPPORT
@@ -590,13 +717,28 @@ namespace sqlite {
590717
++db._inx;
591718
return db;
592719
}
720+
template <typename OptionalT> inline void store_result_in_db(sqlite3_context* db, const std::optional<OptionalT>& val) {
721+
if(val) {
722+
store_result_in_db(db, *val);
723+
}
724+
sqlite3_result_null(db);
725+
}
593726

594727
template <typename OptionalT> inline void get_col_from_db(database_binder& db, int inx, std::optional<OptionalT>& o) {
595728
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
596729
o.reset();
597730
} else {
598731
OptionalT v;
599-
get_col_from_db(db, inx, v);
732+
get_col_from_db(value, v);
733+
o = std::move(v);
734+
}
735+
}
736+
template <typename OptionalT> inline void get_val_from_db(sqlite3_value *value, std::optional<OptionalT>& o) {
737+
if(sqlite3_value_type(value) == SQLITE_NULL) {
738+
o.reset();
739+
} else {
740+
OptionalT v;
741+
get_val_from_db(value, v);
600742
o = std::move(v);
601743
}
602744
}
@@ -616,6 +758,12 @@ namespace sqlite {
616758
++db._inx;
617759
return db;
618760
}
761+
template <typename BoostOptionalT> inline void store_result_in_db(sqlite3_context* db, const boost::optional<BoostOptionalT>& val) {
762+
if(val) {
763+
store_result_in_db(db, *val);
764+
}
765+
sqlite3_result_null(db);
766+
}
619767

620768
template <typename BoostOptionalT> inline void get_col_from_db(database_binder& db, int inx, boost::optional<BoostOptionalT>& o) {
621769
if(sqlite3_column_type(db._stmt.get(), inx) == SQLITE_NULL) {
@@ -626,6 +774,15 @@ namespace sqlite {
626774
o = std::move(v);
627775
}
628776
}
777+
template <typename BoostOptionalT> inline void get_val_from_db(sqlite3_value *value, boost::optional<BoostOptionalT>& o) {
778+
if(sqlite3_value_type(value) == SQLITE_NULL) {
779+
o.reset();
780+
} else {
781+
BoostOptionalT v;
782+
get_val_from_db(value, v);
783+
o = std::move(v);
784+
}
785+
}
629786
#endif
630787

631788
// Some ppl are lazy so we have a operator for proper prep. statemant handling.
@@ -634,4 +791,46 @@ namespace sqlite {
634791
// Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!)
635792
template<typename T> database_binder& operator << (database_binder&& db, const T& val) { return db << val; }
636793

794+
namespace sql_function_binder {
795+
template<
796+
std::size_t Count,
797+
typename Function,
798+
typename... Values
799+
>
800+
inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar(
801+
sqlite3_context* db,
802+
int count,
803+
sqlite3_value** vals,
804+
Values&&... values
805+
) {
806+
typename utility::function_traits<Function>::template argument<sizeof...(Values)> value{};
807+
get_val_from_db(vals[sizeof...(Values)], value);
808+
809+
scalar<Count, Function>(db, count, vals, std::forward<Values>(values)..., std::move(value));
810+
}
811+
812+
template<
813+
std::size_t Count,
814+
typename Function,
815+
typename... Values
816+
>
817+
inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar(
818+
sqlite3_context* db,
819+
int,
820+
sqlite3_value**,
821+
Values&&... values
822+
) {
823+
try {
824+
store_result_in_db(db,
825+
(*static_cast<Function*>(sqlite3_user_data(db)))(std::move(values)...));
826+
} catch(sqlite_exception &e) {
827+
sqlite3_result_error_code(db, e.get_code());
828+
sqlite3_result_error(db, e.what(), -1);
829+
} catch(std::exception &e) {
830+
sqlite3_result_error(db, e.what(), -1);
831+
} catch(...) {
832+
sqlite3_result_error(db, "Unknown error", -1);
833+
}
834+
}
835+
}
637836
}

0 commit comments

Comments
 (0)