diff --git a/.flake8 b/.flake8
index d2f09077..4b5a4c0b 100644
--- a/.flake8
+++ b/.flake8
@@ -3,7 +3,7 @@ exclude =
     docs/*
     resources/*
     licenses/*
-max-complexity = 45
+max-complexity = 50
 max-line-length = 90
 min-python-version = 3.9.0
 per-file-ignores =
diff --git a/.gitignore b/.gitignore
index 5fc32a4a..02a1b941 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,3 +99,4 @@
 dask-worker-space
 singlestoredb/mysql/tests/databases.json
 trees
+CODIFI_**
diff --git a/accel.c b/accel.c
index a785b975..d9540841 100644
--- a/accel.c
+++ b/accel.c
@@ -357,6 +357,58 @@ inline int IMIN(int a, int b) { return((a) < (b) ? a : b); }
 static PyObject *create_numpy_array(PyObject *py_memview, char *data_format, int data_type, PyObject *py_objs);
 
+static PyObject *apply_transformer(PyObject *transformer, PyObject *value) {
+    if (value == NULL) {
+        return value;
+    }
+
+    if (transformer == NULL || transformer == Py_None) {
+        // No transformation needed, return value as-is
+        // We steal the reference from the caller
+        return value;
+    }
+
+    PyObject *out = PyObject_CallFunction(transformer, "O", value);
+    // Always decref value since we're stealing the reference from the caller
+    // This includes Py_None when it's passed as a new reference (e.g., from PyIter_Next)
+    Py_DECREF(value);
+
+    return out;
+}
+
+/*
+static PyObject *apply_transformer_numpy(PyObject *numpy_vectorize, PyObject *transformer, PyObject *array) {
+    if (array == NULL) {
+        return NULL;
+    }
+
+    if (transformer == NULL || transformer == Py_None) {
+        // Increment refcount of Py_None since we are returning it
+        if (array == Py_None) {
+            Py_INCREF(array);
+        }
+        return array;
+    }
+
+    // Convert function to vectorized function
+    PyObject *py_vec_func = PyObject_CallFunction(numpy_vectorize, "O", transformer);
+    if (!py_vec_func) {
+        PyErr_SetString(PyExc_ValueError, "unable to vectorize transformer function");
+        return NULL;
+    }
+
+    PyObject *out = PyObject_CallFunction(py_vec_func, "O", array);
+    // Don't decref Py_None since we assume it was passed in literally and not from a new reference
+    if (array != NULL && array != Py_None) {
+        Py_DECREF(array);
+    }
+
+    Py_DECREF(py_vec_func);
+
+    return out;
+}
+*/
+
 char *_PyUnicode_AsUTF8(PyObject *unicode) {
     PyObject *bytes = PyUnicode_AsEncodedString(unicode, "utf-8", "strict");
     if (!bytes) return NULL;
@@ -364,11 +416,13 @@ char *_PyUnicode_AsUTF8(PyObject *unicode) {
     char *str = NULL;
     Py_ssize_t str_l = 0;
     if (PyBytes_AsStringAndSize(bytes, &str, &str_l) < 0) {
+        Py_DECREF(bytes);
        return NULL;
     }
 
     char *out = calloc(str_l + 1, 1);
     memcpy(out, str, str_l);
+    Py_DECREF(bytes);
 
     return out;
 }
@@ -907,7 +961,7 @@ static int State_init(StateObject *self, PyObject *args, PyObject *kwds) {
                        NULL : _PyUnicode_AsUTF8(py_encoding);
         self->py_invalid_values[i] = (!py_invalid_value || py_invalid_value == Py_None) ?
- NULL : py_converter; + NULL : py_invalid_value; Py_XINCREF(self->py_invalid_values[i]); self->py_converters[i] = ((!py_converter || py_converter == py_default_converter) @@ -953,13 +1007,13 @@ static int State_init(StateObject *self, PyObject *args, PyObject *kwds) { py_args = PyTuple_New(2); if (!py_args) goto error; + Py_INCREF(PyStr.Row); rc = PyTuple_SetItem(py_args, 0, PyStr.Row); if (rc) goto error; - Py_INCREF(PyStr.Row); + Py_INCREF(self->py_names_list); rc = PyTuple_SetItem(py_args, 1, self->py_names_list); if (rc) goto error; - Py_INCREF(self->py_names_list); self->py_namedtuple = PyObject_Call( PyFunc.collections_namedtuple, @@ -1271,9 +1325,9 @@ static PyObject *read_packet(StateObject *py_state) { if (!py_recv_data) goto error; py_new_buff = PyByteArray_Concat(py_buff, py_recv_data); + if (!py_new_buff) goto error; Py_CLEAR(py_recv_data); Py_CLEAR(py_buff); - if (!py_new_buff) goto error; py_buff = py_new_buff; py_new_buff = NULL; @@ -1949,7 +2003,10 @@ static PyObject *read_row_from_packet( break; case ACCEL_OUT_DICTS: case ACCEL_OUT_ARROW: - PyDict_SetItem(py_result, py_state->py_names[i], py_item); + if (PyDict_SetItem(py_result, py_state->py_names[i], py_item) < 0) { + Py_DECREF(py_item); + goto error; + } Py_DECREF(py_item); break; default: @@ -2014,10 +2071,10 @@ static PyObject *read_rowdata_packet(PyObject *self, PyObject *args, PyObject *k PyObject *py_args = PyTuple_New(2); if (!py_args) goto error; - PyTuple_SetItem(py_args, 0, py_res); - PyTuple_SetItem(py_args, 1, py_requested_n_rows); Py_INCREF(py_res); Py_INCREF(py_requested_n_rows); + PyTuple_SetItem(py_args, 0, py_res); + PyTuple_SetItem(py_args, 1, py_requested_n_rows); py_state = (StateObject*)PyObject_CallObject((PyObject*)StateType, py_args); if (!py_state) { Py_DECREF(py_args); goto error; } @@ -2222,6 +2279,7 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k PyObject *py_objs = NULL; PyObject *py_mask = NULL; PyObject *py_pair = NULL; + PyObject **py_transformers = NULL; Py_ssize_t length = 0; uint8_t is_null = 0; int8_t i8 = 0; @@ -2266,6 +2324,8 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Determine column types ctypes = calloc(sizeof(int), n_cols); if (!ctypes) goto error; + py_transformers = calloc(sizeof(PyObject*), n_cols); + if (!py_transformers) goto error; for (i = 0; i < n_cols; i++) { PyObject *py_cspec = PySequence_GetItem(py_colspec, i); if (!py_cspec) goto error; @@ -2273,6 +2333,8 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k if (!py_ctype) { Py_DECREF(py_cspec); goto error; } ctypes[i] = (int)PyLong_AsLong(py_ctype); Py_DECREF(py_ctype); + py_transformers[i] = PySequence_GetItem(py_cspec, 2); + if (!py_transformers[i]) { Py_DECREF(py_cspec); goto error; } Py_DECREF(py_cspec); if (PyErr_Occurred()) { goto error; } } @@ -2689,6 +2751,12 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k } free(mask_cols); } + if (py_transformers) { + for (i = 0; i < n_cols; i++) { + Py_XDECREF(py_transformers[i]); + } + free(py_transformers); + } if (out_row_ids) free(out_row_ids); if (data_formats) free(data_formats); if (item_sizes) free(item_sizes); @@ -2875,6 +2943,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k PyObject *py_row_ids = NULL; PyObject *py_cols = NULL; PyObject *py_out = NULL; + PyObject **py_transformers = NULL; unsigned long long n_cols = 0; unsigned long long n_rows = 0; uint8_t is_null = 0; @@ -2965,11 
+3034,23 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k goto error; } + // Get transformers + py_transformers = malloc(sizeof(PyObject*) * n_cols); + if (!py_transformers) { + PyErr_SetString(PyExc_MemoryError, "failed to allocate transformers array"); + goto error; + } + for (i = 0; i < n_cols; i++) { - PyObject *py_item = PySequence_GetItem(py_returns, i); - if (!py_item) goto error; - returns[i] = (int)PyLong_AsLong(py_item); - Py_DECREF(py_item); + PyObject *py_retspec = PySequence_GetItem(py_returns, i); + if (!py_retspec) goto error; + PyObject *py_rtype = PySequence_GetItem(py_retspec, 1); + if (!py_rtype) { Py_DECREF(py_retspec); goto error; } + returns[i] = (int)PyLong_AsLong(py_rtype); + Py_DECREF(py_rtype); + py_transformers[i] = PySequence_GetItem(py_retspec, 2); + if (!py_transformers[i]) { Py_DECREF(py_retspec); goto error; } + Py_DECREF(py_retspec); if (PyErr_Occurred()) { goto error; } } @@ -4010,7 +4091,10 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k out_idx += 8; } else { PyObject *py_bytes = PyUnicode_AsEncodedString(py_str, "utf-8", "strict"); - if (!py_bytes) goto error; + if (!py_bytes) { + PyErr_SetString(PyExc_ValueError, "unsupported numpy data type for character output types"); + goto error; + } char *str = NULL; Py_ssize_t str_l = 0; @@ -4088,6 +4172,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k char *str = NULL; Py_ssize_t str_l = 0; if (PyBytes_AsStringAndSize(py_bytes, &str, &str_l) < 0) { + PyErr_SetString(PyExc_ValueError, "unsupported numpy data type for binary output types"); goto error; } @@ -4134,8 +4219,6 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) PyObject *py_out_rows = NULL; PyObject *py_row = NULL; PyObject *py_colspec = NULL; - PyObject *py_str = NULL; - PyObject *py_blob = NULL; Py_ssize_t length = 0; uint64_t row_id = 0; uint8_t is_null = 0; @@ -4150,11 +4233,14 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) float flt = 0; double dbl = 0; int *ctypes = NULL; + PyObject **py_transformers = NULL; char *data = NULL; char *end = NULL; unsigned long long colspec_l = 0; unsigned long long i = 0; char *keywords[] = {"colspec", "data", NULL}; + PyObject *t = NULL; + PyObject *py_value = NULL; // Parse function args. 
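Both readers above now consume three-element column specs instead of bare type codes: element 1 is the MySQL type code and element 2 is an optional transformer callable, mirrored on the Python side by the `Transformer` type used in `arrow.py` below. Each ROWDAT_1 row begins with an 8-byte row ID, then carries a 1-byte null flag plus a fixed-width or length-prefixed value per column, and the transformer, when present, is applied to every decoded value, including NULLs. A minimal Python sketch of the convention; the type codes are real MySQL codes, while `colspec` and the column names are purely illustrative:

```python
# Sketch of the (name, type_code, transformer) column-spec convention,
# assuming the same null-tolerant dispatch as the C apply_transformer().
import json
from typing import Any, Callable, List, Optional, Tuple

ColSpec = Tuple[str, int, Optional[Callable[[Any], Any]]]

def apply_transformer(transformer: Optional[Callable[[Any], Any]], value: Any) -> Any:
    # No transformer means pass the value through unchanged.
    if transformer is None:
        return value
    return transformer(value)

colspec: List[ColSpec] = [
    ('id', 8, None),           # MYSQL_TYPE_LONGLONG, no transformation
    ('doc', 245, json.loads),  # MYSQL_TYPE_JSON, parse string payloads
]

row = [1, '{"a": 1}']
print([apply_transformer(spec[2], v) for spec, v in zip(colspec, row)])
# [1, {'a': 1}]
```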
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", keywords, &py_colspec, &py_data)) { @@ -4165,15 +4251,25 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) end = data + (unsigned long long)length; colspec_l = PyObject_Length(py_colspec); - ctypes = malloc(sizeof(int) * colspec_l); + ctypes = calloc(sizeof(int), colspec_l); + if (!ctypes) goto error; + py_transformers = calloc(sizeof(PyObject*), colspec_l); + if (!py_transformers) goto error; for (i = 0; i < colspec_l; i++) { PyObject *py_cspec = PySequence_GetItem(py_colspec, i); if (!py_cspec) goto error; + + // Extract type (second element) PyObject *py_ctype = PySequence_GetItem(py_cspec, 1); if (!py_ctype) { Py_DECREF(py_cspec); goto error; } ctypes[i] = (int)PyLong_AsLong(py_ctype); Py_DECREF(py_ctype); + + // Extract transformer (third element) + py_transformers[i] = PySequence_GetItem(py_cspec, 2); + if (!py_transformers[i]) { Py_DECREF(py_cspec); goto error; } + Py_DECREF(py_cspec); if (PyErr_Occurred()) { goto error; } } @@ -4188,32 +4284,42 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) if (!py_out) { Py_DECREF(py_out_row_ids); Py_DECREF(py_out_rows); goto error; } if (PyTuple_SetItem(py_out, 0, py_out_row_ids) < 0) { + // PyTuple_SetItem steals reference on success, so on failure we still own it Py_DECREF(py_out_row_ids); Py_DECREF(py_out_rows); goto error; } + // py_out_row_ids reference now stolen by tuple if (PyTuple_SetItem(py_out, 1, py_out_rows) < 0) { + // py_out_row_ids already stolen by first SetItem, only clean up py_out_rows Py_DECREF(py_out_rows); goto error; } + // py_out_rows reference now stolen by tuple while (end > data) { py_row = PyTuple_New(colspec_l); if (!py_row) goto error; row_id = *(int64_t*)data; data += 8; - CHECKRC(PyList_Append(py_out_row_ids, PyLong_FromLongLong(row_id))); + PyObject *py_row_id = PyLong_FromLongLong(row_id); + if (!py_row_id) goto error; + CHECKRC(PyList_Append(py_out_row_ids, py_row_id)); + Py_DECREF(py_row_id); for (unsigned long long i = 0; i < colspec_l; i++) { is_null = data[0] == '\x01'; data += 1; if (is_null) Py_INCREF(Py_None); + t = py_transformers[i]; + switch (ctypes[i]) { case MYSQL_TYPE_NULL: data += 1; - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); + py_value = apply_transformer(t, Py_None); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_BIT: @@ -4222,108 +4328,78 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) case MYSQL_TYPE_TINY: i8 = *(int8_t*)data; data += 1; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i8))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromLong((long)i8)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; // Use negative to indicate unsigned case -MYSQL_TYPE_TINY: u8 = *(uint8_t*)data; data += 1; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u8))); - } + py_value = apply_transformer(t, (is_null) ? 
Py_None : PyLong_FromUnsignedLong((unsigned long)u8)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_SHORT: i16 = *(int16_t*)data; data += 2; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i16))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromLong((long)i16)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; // Use negative to indicate unsigned case -MYSQL_TYPE_SHORT: u16 = *(uint16_t*)data; data += 2; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromUnsignedLong((unsigned long)u16)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_LONG: case MYSQL_TYPE_INT24: i32 = *(int32_t*)data; data += 4; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i32))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromLong((long)i32)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; // Use negative to indicate unsigned case -MYSQL_TYPE_LONG: case -MYSQL_TYPE_INT24: u32 = *(uint32_t*)data; data += 4; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u32))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromUnsignedLong((unsigned long)u32)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_LONGLONG: i64 = *(int64_t*)data; data += 8; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLongLong((long long)i64))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromLongLong((long long)i64)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; // Use negative to indicate unsigned case -MYSQL_TYPE_LONGLONG: u64 = *(uint64_t*)data; data += 8; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLongLong((unsigned long long)u64))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromUnsignedLongLong((unsigned long long)u64)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_FLOAT: flt = *(float*)data; data += 4; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)flt))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyFloat_FromDouble((double)flt)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_DOUBLE: dbl = *(double*)data; data += 8; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)dbl))); - } + py_value = apply_transformer(t, (is_null) ? 
Py_None : PyFloat_FromDouble((double)dbl)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_DECIMAL: @@ -4350,12 +4426,9 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) case MYSQL_TYPE_YEAR: u16 = *(uint16_t*)data; data += 2; - if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); - } else { - CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); - } + py_value = apply_transformer(t, (is_null) ? Py_None : PyLong_FromUnsignedLong((unsigned long)u16)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); break; case MYSQL_TYPE_VARCHAR: @@ -4371,13 +4444,14 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) case MYSQL_TYPE_BLOB: i64 = *(int64_t*)data; data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); + py_value = apply_transformer(t, Py_None); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); } else { - py_str = PyUnicode_FromStringAndSize(data, (Py_ssize_t)i64); + py_value = apply_transformer(t, PyUnicode_FromStringAndSize(data, (Py_ssize_t)i64)); data += i64; - if (!py_str) goto error; - CHECKRC(PyTuple_SetItem(py_row, i, py_str)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); } break; @@ -4395,13 +4469,14 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) case -MYSQL_TYPE_BLOB: i64 = *(int64_t*)data; data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); - Py_INCREF(Py_None); + py_value = apply_transformer(t, Py_None); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); } else { - py_blob = PyBytes_FromStringAndSize(data, (Py_ssize_t)i64); + py_value = apply_transformer(t, PyBytes_FromStringAndSize(data, (Py_ssize_t)i64)); data += i64; - if (!py_blob) goto error; - CHECKRC(PyTuple_SetItem(py_row, i, py_blob)); + if (!py_value) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_value)); } break; @@ -4417,6 +4492,12 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) exit: if (ctypes) free(ctypes); + if (py_transformers) { + for (unsigned long long j = 0; j < colspec_l; j++) { + Py_XDECREF(py_transformers[j]); + } + free(py_transformers); + } Py_XDECREF(py_row); @@ -4456,6 +4537,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) unsigned long long out_l = 0; unsigned long long out_idx = 0; int *returns = NULL; + PyObject **py_transformers = NULL; char *keywords[] = {"returns", "row_ids", "data", NULL}; unsigned long long i = 0; unsigned long long n_cols = 0; @@ -4485,14 +4567,26 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } - returns = malloc(sizeof(int) * n_cols); + returns = calloc(sizeof(int), n_cols); if (!returns) goto error; + py_transformers = calloc(sizeof(PyObject*), n_cols); + if (!py_transformers) goto error; for (i = 0; i < n_cols; i++) { - PyObject *py_item = PySequence_GetItem(py_returns, i); - if (!py_item) goto error; - returns[i] = (int)PyLong_AsLong(py_item); - Py_DECREF(py_item); + PyObject *py_retspec = PySequence_GetItem(py_returns, i); + if (!py_retspec) goto error; + + // Extract return type (second element) + PyObject *py_rtype = PySequence_GetItem(py_retspec, 1); + if (!py_rtype) { Py_DECREF(py_retspec); goto error; } + returns[i] = (int)PyLong_AsLong(py_rtype); + 
Py_DECREF(py_rtype); + + // Extract transformer (third element) + py_transformers[i] = PySequence_GetItem(py_retspec, 2); + if (!py_transformers[i]) { Py_DECREF(py_retspec); goto error; } + + Py_DECREF(py_retspec); if (PyErr_Occurred()) { goto error; } } @@ -4538,6 +4632,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) while ((py_item = PyIter_Next(py_row_iter))) { + py_item = apply_transformer(py_transformers[i], py_item); is_null = (uint8_t)(py_item == Py_None); CHECKMEM(1); @@ -4751,6 +4846,12 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) exit: if (out) free(out); if (returns) free(returns); + if (py_transformers) { + for (unsigned long long j = 0; j < n_cols; j++) { + Py_XDECREF(py_transformers[j]); + } + free(py_transformers); + } Py_XDECREF(py_item); Py_XDECREF(py_row_iter); diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 3da98ff4..e71cfc32 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -9,7 +9,7 @@ from typing import Union from . import utils -from .dtypes import SQLString +from .sql_types import SQLString ParameterType = Union[ diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py index 0fe26a45..fce31037 100644 --- a/singlestoredb/functions/dtypes.py +++ b/singlestoredb/functions/dtypes.py @@ -1,1793 +1,9 @@ -#!/usr/bin/env python3 -import base64 -import datetime -import decimal -import re -from typing import Any -from typing import Callable -from typing import Optional -from typing import Union +from warnings import warn -from ..converters import converters -from ..mysql.converters import escape_item # type: ignore -from ..utils.dtypes import DEFAULT_VALUES # noqa -from ..utils.dtypes import NUMPY_TYPE_MAP # noqa -from ..utils.dtypes import PANDAS_TYPE_MAP # noqa -from ..utils.dtypes import POLARS_TYPE_MAP # noqa -from ..utils.dtypes import PYARROW_TYPE_MAP # noqa +from .sql_types import * # noqa: F403, F401 - -DataType = Union[str, Callable[..., Any]] - - -class SQLString(str): - """SQL string type.""" - name: Optional[str] = None - - -class NULL: - """NULL (for use in default values).""" - pass - - -def escape_name(name: str) -> str: - """Escape a function parameter name.""" - if '`' in name: - name = name.replace('`', '``') - return f'`{name}`' - - -# charsets -utf8mb4 = 'utf8mb4' -utf8 = 'utf8' -binary = 'binary' - -# collations -utf8_general_ci = 'utf8_general_ci' -utf8_bin = 'utf8_bin' -utf8_unicode_ci = 'utf8_unicode_ci' -utf8_icelandic_ci = 'utf8_icelandic_ci' -utf8_latvian_ci = 'utf8_latvian_ci' -utf8_romanian_ci = 'utf8_romanian_ci' -utf8_slovenian_ci = 'utf8_slovenian_ci' -utf8_polish_ci = 'utf8_polish_ci' -utf8_estonian_ci = 'utf8_estonian_ci' -utf8_spanish_ci = 'utf8_spanish_ci' -utf8_swedish_ci = 'utf8_swedish_ci' -utf8_turkish_ci = 'utf8_turkish_ci' -utf8_czech_ci = 'utf8_czech_ci' -utf8_danish_ci = 'utf8_danish_ci' -utf8_lithuanian_ci = 'utf8_lithuanian_ci' -utf8_slovak_ci = 'utf8_slovak_ci' -utf8_spanish2_ci = 'utf8_spanish2_ci' -utf8_roman_ci = 'utf8_roman_ci' -utf8_persian_ci = 'utf8_persian_ci' -utf8_esperanto_ci = 'utf8_esperanto_ci' -utf8_hungarian_ci = 'utf8_hungarian_ci' -utf8_sinhala_ci = 'utf8_sinhala_ci' -utf8mb4_general_ci = 'utf8mb4_general_ci' -utf8mb4_bin = 'utf8mb4_bin' -utf8mb4_unicode_ci = 'utf8mb4_unicode_ci' -utf8mb4_icelandic_ci = 'utf8mb4_icelandic_ci' -utf8mb4_latvian_ci = 'utf8mb4_latvian_ci' -utf8mb4_romanian_ci = 'utf8mb4_romanian_ci' 
-utf8mb4_slovenian_ci = 'utf8mb4_slovenian_ci' -utf8mb4_polish_ci = 'utf8mb4_polish_ci' -utf8mb4_estonian_ci = 'utf8mb4_estonian_ci' -utf8mb4_spanish_ci = 'utf8mb4_spanish_ci' -utf8mb4_swedish_ci = 'utf8mb4_swedish_ci' -utf8mb4_turkish_ci = 'utf8mb4_turkish_ci' -utf8mb4_czech_ci = 'utf8mb4_czech_ci' -utf8mb4_danish_ci = 'utf8mb4_danish_ci' -utf8mb4_lithuanian_ci = 'utf8mb4_lithuanian_ci' -utf8mb4_slovak_ci = 'utf8mb4_slovak_ci' -utf8mb4_spanish2_ci = 'utf8mb4_spanish2_ci' -utf8mb4_roman_ci = 'utf8mb4_roman_ci' -utf8mb4_persian_ci = 'utf8mb4_persian_ci' -utf8mb4_esperanto_ci = 'utf8mb4_esperanto_ci' -utf8mb4_hungarian_ci = 'utf8mb4_hungarian_ci' -utf8mb4_sinhala_ci = 'utf8mb4_sinhala_ci' - - -def identity(x: Any) -> Any: - return x - - -def utf8str(x: Any) -> Optional[str]: - if x is None: - return x - if isinstance(x, str): - return x - return str(x, 'utf-8') - - -def bytestr(x: Any) -> Optional[bytes]: - if x is None: - return x - if isinstance(x, bytes): - return x - return base64.b64decode(x) - - -PYTHON_CONVERTERS = { - -1: converters[1], - -2: converters[2], - -3: converters[3], - -8: converters[8], - -9: converters[9], - 15: utf8str, - -15: bytestr, - 249: utf8str, - -249: bytestr, - 250: utf8str, - -250: bytestr, - 251: utf8str, - -251: bytestr, - 252: utf8str, - -252: bytestr, - 254: utf8str, - -254: bytestr, - 255: utf8str, -} - -PYTHON_CONVERTERS = dict(list(converters.items()) + list(PYTHON_CONVERTERS.items())) - - -def _modifiers( - *, - nullable: Optional[bool] = None, - charset: Optional[str] = None, - collate: Optional[str] = None, - default: Optional[Any] = None, - unsigned: Optional[bool] = None, -) -> str: - """ - Format type modifiers. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - charset : str, optional - Character set - collate : str, optional - Collation - default ; Any, optional - Default value - unsigned : bool, optional - Is the value unsigned? (ints only) - - Returns - ------- - str - - """ - out = [] - - if unsigned is not None: - if unsigned: - out.append('UNSIGNED') - - if charset is not None: - if not re.match(r'^[A-Za-z0-9_]+$', charset): - raise ValueError(f'charset value is invalid: {charset}') - out.append(f'CHARACTER SET {charset}') - - if collate is not None: - if not re.match(r'^[A-Za-z0-9_]+$', collate): - raise ValueError(f'collate value is invalid: {collate}') - out.append(f'COLLATE {collate}') - - if nullable is not None: - if nullable: - out.append('NULL') - else: - out.append('NOT NULL') - - if default is NULL: - out.append('DEFAULT NULL') - elif default is not None: - out.append(f'DEFAULT {escape_item(default, "utf-8")}') - - return ' ' + ' '.join(out) - - -def _bool(x: Optional[bool] = None) -> Optional[bool]: - """Cast bool.""" - if x is None: - return None - return bool(x) - - -def BOOL( - *, - nullable: bool = True, - default: Optional[bool] = None, - name: Optional[str] = None, -) -> SQLString: - """ - BOOL type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - default : bool, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString('BOOL' + _modifiers(nullable=nullable, default=_bool(default))) - out.name = name - return out - - -def BOOLEAN( - *, - nullable: bool = True, - default: Optional[bool] = None, - name: Optional[str] = None, -) -> SQLString: - """ - BOOLEAN type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? 
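Every type builder in this module renders through the helpers shown above: `escape_name` (earlier) doubles embedded backticks before wrapping the identifier, and `_modifiers` (just above) appends clauses in a fixed order (UNSIGNED, CHARACTER SET, COLLATE, NULL/NOT NULL, DEFAULT). The same names should now resolve through `singlestoredb.functions.sql_types`, given the star re-export shim at the end of this file's diff; treating that as an assumption, the behavior looks like:

```python
# Behavior implied by the deleted helpers above; the sql_types path is the
# post-rename home of these names (assumed from the star re-export shim).
from singlestoredb.functions.sql_types import BOOL, escape_name

print(escape_name('a`b'))    # `a``b`
print(BOOL())                # BOOL NULL
print(BOOL(nullable=False))  # BOOL NOT NULL
```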
- default : bool, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString('BOOLEAN' + _modifiers(nullable=nullable, default=_bool(default))) - out.name = name - return out - - -def BIT( - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - BIT type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString('BIT' + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def TINYINT( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - unsigned: bool = False, - name: Optional[str] = None, -) -> SQLString: - """ - TINYINT type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - unsigned : bool, optional - Is the int unsigned? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TINYINT({display_width})' if display_width else 'TINYINT' - out = SQLString( - out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), - ) - out.name = name - return out - - -def TINYINT_UNSIGNED( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - TINYINT UNSIGNED type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TINYINT({display_width})' if display_width else 'TINYINT' - out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) - out.name = name - return out - - -def SMALLINT( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - unsigned: bool = False, - name: Optional[str] = None, -) -> SQLString: - """ - SMALLINT type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - unsigned : bool, optional - Is the int unsigned? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'SMALLINT({display_width})' if display_width else 'SMALLINT' - out = SQLString( - out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), - ) - out.name = name - return out - - -def SMALLINT_UNSIGNED( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - SMALLINT UNSIGNED type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? 
- default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'SMALLINT({display_width})' if display_width else 'SMALLINT' - out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) - out.name = name - return out - - -def MEDIUMINT( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - unsigned: bool = False, - name: Optional[str] = None, -) -> SQLString: - """ - MEDIUMINT type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - unsigned : bool, optional - Is the int unsigned? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'MEDIUMINT({display_width})' if display_width else 'MEDIUMINT' - out = SQLString( - out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), - ) - out.name = name - return out - - -def MEDIUMINT_UNSIGNED( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - MEDIUMINT UNSIGNED type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'MEDIUMINT({display_width})' if display_width else 'MEDIUMINT' - out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) - out.name = name - return out - - -def INT( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - unsigned: bool = False, - name: Optional[str] = None, -) -> SQLString: - """ - INT type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - unsigned : bool, optional - Is the int unsigned? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'INT({display_width})' if display_width else 'INT' - out = SQLString( - out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), - ) - out.name = name - return out - - -def INT_UNSIGNED( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - INT UNSIGNED type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'INT({display_width})' if display_width else 'INT' - out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) - out.name = name - return out - - -def INTEGER( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - unsigned: bool = False, - name: Optional[str] = None, -) -> SQLString: - """ - INTEGER type specification. 
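The `*_UNSIGNED` variants are plain wrappers that force `unsigned=True` into `_modifiers`; note that the C reader in accel.c expresses the same property separately, by negating the wire type code. For example:

```python
# Unsigned integer builders; output follows the _modifiers clause order.
from singlestoredb.functions.sql_types import INT, INT_UNSIGNED, MEDIUMINT

print(INT(unsigned=True))  # INT UNSIGNED NULL
print(INT_UNSIGNED(11))    # INT(11) UNSIGNED NULL
print(MEDIUMINT())         # MEDIUMINT NULL
```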
- - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - unsigned : bool, optional - Is the int unsigned? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'INTEGER({display_width})' if display_width else 'INTEGER' - out = SQLString( - out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), - ) - out.name = name - return out - - -def INTEGER_UNSIGNED( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - INTEGER UNSIGNED type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'INTEGER({display_width})' if display_width else 'INTEGER' - out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) - out.name = name - return out - - -def BIGINT( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - unsigned: bool = False, - name: Optional[str] = None, -) -> SQLString: - """ - BIGINT type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - unsigned : bool, optional - Is the int unsigned? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'BIGINT({display_width})' if display_width else 'BIGINT' - out = SQLString( - out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), - ) - out.name = name - return out - - -def BIGINT_UNSIGNED( - display_width: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - BIGINT UNSIGNED type specification. - - Parameters - ---------- - display_width : int, optional - Display width used by some clients - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'BIGINT({int(display_width)})' if display_width else 'BIGINT' - out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) - out.name = name - return out - - -def FLOAT( - display_decimals: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[float] = None, - name: Optional[str] = None, -) -> SQLString: - """ - FLOAT type specification. - - Parameters - ---------- - display_decimals : int, optional - Number of decimal places to display - nullable : bool, optional - Can the value be NULL? 
- default : float, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'FLOAT({int(display_decimals)})' if display_decimals else 'FLOAT' - out = SQLString(out + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def DOUBLE( - display_decimals: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[float] = None, - name: Optional[str] = None, -) -> SQLString: - """ - DOUBLE type specification. - - Parameters - ---------- - display_decimals : int, optional - Number of decimal places to display - nullable : bool, optional - Can the value be NULL? - default : float, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'DOUBLE({int(display_decimals)})' if display_decimals else 'DOUBLE' - out = SQLString(out + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def REAL( - display_decimals: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[float] = None, - name: Optional[str] = None, -) -> SQLString: - """ - REAL type specification. - - Parameters - ---------- - display_decimals : int, optional - Number of decimal places to display - nullable : bool, optional - Can the value be NULL? - default : float, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'REAL({int(display_decimals)})' if display_decimals else 'REAL' - out = SQLString(out + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def DECIMAL( - precision: int, - scale: int, - *, - nullable: bool = True, - default: Optional[Union[str, decimal.Decimal]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - DECIMAL type specification. - - Parameters - ---------- - precision : int - Decimal precision - scale : int - Decimal scale - nullable : bool, optional - Can the value be NULL? - default : str or decimal.Decimal, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString( - f'DECIMAL({int(precision)}, {int(scale)})' + - _modifiers(nullable=nullable, default=default), - ) - out.name = name - return out - - -def DEC( - precision: int, - scale: int, - *, - nullable: bool = True, - default: Optional[Union[str, decimal.Decimal]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - DEC type specification. - - Parameters - ---------- - precision : int - Decimal precision - scale : int - Decimal scale - nullable : bool, optional - Can the value be NULL? - default : str or decimal.Decimal, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString( - f'DEC({int(precision)}, {int(scale)})' + - _modifiers(nullable=nullable, default=default), - ) - out.name = name - return out - - -def FIXED( - precision: int, - scale: int, - *, - nullable: bool = True, - default: Optional[Union[str, decimal.Decimal]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - FIXED type specification. - - Parameters - ---------- - precision : int - Decimal precision - scale : int - Decimal scale - nullable : bool, optional - Can the value be NULL? 
- default : str or decimal.Decimal, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString( - f'FIXED({int(precision)}, {int(scale)})' + - _modifiers(nullable=nullable, default=default), - ) - out.name = name - return out - - -def NUMERIC( - precision: int, - scale: int, - *, - nullable: bool = True, - default: Optional[Union[str, decimal.Decimal]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - NUMERIC type specification. - - Parameters - ---------- - precision : int - Decimal precision - scale : int - Decimal scale - nullable : bool, optional - Can the value be NULL? - default : str or decimal.Decimal, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString( - f'NUMERIC({int(precision)}, {int(scale)})' + - _modifiers(nullable=nullable, default=default), - ) - out.name = name - return out - - -def DATE( - *, - nullable: bool = True, - default: Optional[Union[str, datetime.date]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - DATE type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - default : str or datetime.date, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString('DATE' + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def TIME( - precision: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[Union[str, datetime.timedelta]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - TIME type specification. - - Parameters - ---------- - precision : int, optional - Sub-second precision - nullable : bool, optional - Can the value be NULL? - default : str or datetime.timedelta, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TIME({int(precision)})' if precision else 'TIME' - out = SQLString(out + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def DATETIME( - precision: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[Union[str, datetime.datetime]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - DATETIME type specification. - - Parameters - ---------- - precision : int, optional - Sub-second precision - nullable : bool, optional - Can the value be NULL? - default : str or datetime.datetime, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'DATETIME({int(precision)})' if precision else 'DATETIME' - out = SQLString(out + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def TIMESTAMP( - precision: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[Union[str, datetime.datetime]] = None, - name: Optional[str] = None, -) -> SQLString: - """ - TIMESTAMP type specification. - - Parameters - ---------- - precision : int, optional - Sub-second precision - nullable : bool, optional - Can the value be NULL? 
- default : str or datetime.datetime, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TIMESTAMP({int(precision)})' if precision else 'TIMESTAMP' - out = SQLString(out + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def YEAR( - *, - nullable: bool = True, - default: Optional[int] = None, - name: Optional[str] = None, -) -> SQLString: - """ - YEAR type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - default : int, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString('YEAR' + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def CHAR( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - CHAR type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'CHAR({int(length)})' if length else 'CHAR' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def VARCHAR( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - VARCHAR type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'VARCHAR({int(length)})' if length else 'VARCHAR' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def LONGTEXT( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - LONGTEXT type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'LONGTEXT({int(length)})' if length else 'LONGTEXT' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def MEDIUMTEXT( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - MEDIUMTEXT type specification. 
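The character types accept the charset and collation constants defined near the top of this module; `_modifiers` emits them between the unsigned flag and the nullability clause, and the optional `name` keyword is stored on the returned `SQLString` rather than in the SQL text. A quick illustration:

```python
from singlestoredb.functions.sql_types import VARCHAR, utf8mb4, utf8mb4_bin

spec = VARCHAR(32, charset=utf8mb4, collate=utf8mb4_bin, nullable=False)
print(spec)       # VARCHAR(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL
print(spec.name)  # None (set via the name= keyword when needed)
```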
- - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'MEDIUMTEXT({int(length)})' if length else 'MEDIUMTEXT' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def TEXT( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - TEXT type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TEXT({int(length)})' if length else 'TEXT' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def TINYTEXT( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - TINYTEXT type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TINYTEXT({int(length)})' if length else 'TINYTEXT' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def BINARY( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[bytes] = None, - collate: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - BINARY type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'BINARY({int(length)})' if length else 'BINARY' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, collate=collate, - ), - ) - out.name = name - return out - - -def VARBINARY( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[bytes] = None, - collate: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - VARBINARY type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? 
- default : str, optional - Default value - collate : str, optional - Collation - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'VARBINARY({int(length)})' if length else 'VARBINARY' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, collate=collate, - ), - ) - out.name = name - return out - - -def LONGBLOB( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[bytes] = None, - collate: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - LONGBLOB type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'LONGBLOB({int(length)})' if length else 'LONGBLOB' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, collate=collate, - ), - ) - out.name = name - return out - - -def MEDIUMBLOB( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[bytes] = None, - collate: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - MEDIUMBLOB type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'MEDIUMBLOB({int(length)})' if length else 'MEDIUMBLOB' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, collate=collate, - ), - ) - out.name = name - return out - - -def BLOB( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[bytes] = None, - collate: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - BLOB type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'BLOB({int(length)})' if length else 'BLOB' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, collate=collate, - ), - ) - out.name = name - return out - - -def TINYBLOB( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[bytes] = None, - collate: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - TINYBLOB type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? 
- default : str, optional - Default value - collate : str, optional - Collation - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'TINYBLOB({int(length)})' if length else 'TINYBLOB' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, collate=collate, - ), - ) - out.name = name - return out - - -def JSON( - length: Optional[int] = None, - *, - nullable: bool = True, - default: Optional[str] = None, - collate: Optional[str] = None, - charset: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - JSON type specification. - - Parameters - ---------- - length : int, optional - Maximum string length - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - collate : str, optional - Collation - charset : str, optional - Character set - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = f'JSON({int(length)})' if length else 'JSON' - out = SQLString( - out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, - ), - ) - out.name = name - return out - - -def GEOGRAPHYPOINT( - *, - nullable: bool = True, - default: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - GEOGRAPHYPOINT type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - out = SQLString('GEOGRAPHYPOINT' + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -def GEOGRAPHY( - *, - nullable: bool = True, - default: Optional[str] = None, - name: Optional[str] = None, -) -> SQLString: - """ - GEOGRAPHYPOINT type specification. - - Parameters - ---------- - nullable : bool, optional - Can the value be NULL? - default : str, optional - Default value - - Returns - ------- - str - - """ - out = SQLString('GEOGRAPHY' + _modifiers(nullable=nullable, default=default)) - out.name = name - return out - - -# def RECORD( -# *args: Tuple[str, DataType], -# nullable: bool = True, -# name: Optional[str] = None, -# ) -> SQLString: -# """ -# RECORD type specification. -# -# Parameters -# ---------- -# *args : Tuple[str, DataType] -# Field specifications -# nullable : bool, optional -# Can the value be NULL? -# name : str, optional -# Name of the column / parameter -# -# Returns -# ------- -# SQLString -# -# """ -# assert len(args) > 0 -# fields = [] -# for name, value in args: -# if callable(value): -# fields.append(f'{escape_name(name)} {value()}') -# else: -# fields.append(f'{escape_name(name)} {value}') -# out = SQLString(f'RECORD({", ".join(fields)})' + _modifiers(nullable=nullable)) -# out.name = name -# return out - - -# def ARRAY( -# dtype: DataType, -# nullable: bool = True, -# name: Optional[str] = None, -# ) -> SQLString: -# """ -# ARRAY type specification. -# -# Parameters -# ---------- -# dtype : DataType -# The data type of the array elements -# nullable : bool, optional -# Can the value be NULL? 
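The net effect of this file's diff is that dtypes.py shrinks to a compatibility shim: a star re-export from the new sql_types module plus an import-time DeprecationWarning, as shown just below. Old imports keep working; new code should target sql_types directly. A sketch of both paths (a fresh interpreter is assumed, since module caching suppresses the warning on re-import):

```python
import warnings

# Old path: still works, but now warns on first import.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    from singlestoredb.functions import dtypes  # noqa: F401
print(any(issubclass(w.category, DeprecationWarning) for w in caught))  # True

# Preferred path after this change:
from singlestoredb.functions.sql_types import SQLString, VARCHAR  # noqa: F401
```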
-# name : str, optional -# Name of the column / parameter -# -# Returns -# ------- -# SQLString -# -# """ -# if callable(dtype): -# dtype = dtype() -# out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable)) -# out.name = name -# return out - - -# F32 = 'F32' -# F64 = 'F64' -# I8 = 'I8' -# I16 = 'I16' -# I32 = 'I32' -# I64 = 'I64' - - -# def VECTOR( -# length: int, -# element_type: str = F32, -# *, -# nullable: bool = True, -# default: Optional[bytes] = None, -# name: Optional[str] = None, -# ) -> SQLString: -# """ -# VECTOR type specification. -# -# Parameters -# ---------- -# n : int -# Number of elements in vector -# element_type : str, optional -# Type of the elements in the vector: -# F32, F64, I8, I16, I32, I64 -# nullable : bool, optional -# Can the value be NULL? -# default : str, optional -# Default value -# name : str, optional -# Name of the column / parameter -# -# Returns -# ------- -# SQLString -# -# """ -# out = f'VECTOR({int(length)}, {element_type})' -# out = SQLString( -# out + _modifiers( -# nullable=nullable, default=default, -# ), -# ) -# out.name = name -# return out +warn( + 'The dtypes module has been renamed to sql_types. ' + 'Please update your imports to remove this warning.', + DeprecationWarning, stacklevel=2, +) diff --git a/singlestoredb/functions/ext/arrow.py b/singlestoredb/functions/ext/arrow.py index 34bcf6e1..8f5ab49f 100644 --- a/singlestoredb/functions/ext/arrow.py +++ b/singlestoredb/functions/ext/arrow.py @@ -5,6 +5,9 @@ from typing import Optional from typing import Tuple +from singlestoredb.functions.ext.utils import apply_transformer +from singlestoredb.functions.ext.utils import Transformer + try: import numpy as np has_numpy = True @@ -32,7 +35,7 @@ def load( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[List[int], List[Any]]: ''' @@ -40,7 +43,7 @@ def load( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, Optional[Transformer]]] An List of column data types data : bytes The data in Apache Feather format @@ -57,12 +60,19 @@ def load( row_ids = table.column(0).to_pylist() rows = [] for row in table.to_pylist(): - rows.append([row[c] for c in table.column_names[1:]]) + converted_row = [] + for i, col_name in enumerate(table.column_names[1:]): + value = apply_transformer( + colspec[i][2], + row[col_name], + ) + converted_row.append(value) + rows.append(converted_row) return row_ids, rows def _load_vectors( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pa.Array[pa.int64]', @@ -73,7 +83,7 @@ def _load_vectors( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, Optional[Transformer]]] An List of column data types data : bytes The data in Apache Feather format @@ -86,16 +96,23 @@ def _load_vectors( if not has_pyarrow: raise RuntimeError('pyarrow must be installed for this operation') + import pyarrow as pa + table = pa.feather.read_table(BytesIO(data)) row_ids = table.column(0) out = [] for i, col in enumerate(table.columns[1:]): - out.append((col, col.is_null())) + out.append(( + apply_transformer( + colspec[i][2], col, + ), col.is_null(), + )) + return row_ids, out def load_pandas( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pd.Series[np.int64]', @@ -106,7 +123,7 @@ def load_pandas( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, 
Optional[Transformer]]] An List of column data types data : bytes The data in Apache Feather format @@ -128,12 +145,12 @@ def load_pandas( data.to_pandas().reindex(index), mask.to_pandas().reindex(index), ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def load_polars( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pl.Series[pl.Int64]', @@ -144,7 +161,7 @@ def load_polars( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, Optional[Transformer]]] An List of column data types data : bytes The data in Apache Feather format @@ -166,13 +183,13 @@ def load_polars( pl.from_arrow(data), # type: ignore pl.from_arrow(mask), # type: ignore ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ], ) def load_numpy( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'np.typing.NDArray[np.int64]', @@ -183,7 +200,7 @@ def load_numpy( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, Optional[Transformer]]] An List of column data types data : bytes The data in Apache Feather format @@ -204,12 +221,12 @@ def load_numpy( data.to_numpy(), mask.to_numpy(), ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def load_arrow( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pa.Array[pa.int64()]', @@ -220,7 +237,7 @@ def load_arrow( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, Optional[Transformer]]] An List of column data types data : bytes The data in Apache Feather format @@ -237,7 +254,7 @@ def load_arrow( def dump( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: List[int], rows: List[List[Any]], ) -> bytes: @@ -246,7 +263,7 @@ def dump( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] The row IDs @@ -262,11 +279,23 @@ def dump( raise RuntimeError('pyarrow must be installed for this operation') if len(rows) == 0 or len(row_ids) == 0: - return BytesIO().getbuffer() + return BytesIO().getvalue() colnames = ['col{}'.format(x) for x in range(len(rows[0]))] - tbl = pa.Table.from_pylist([dict(list(zip(colnames, row))) for row in rows]) + # Process rows to handle JSON serialization + processed_rows = [] + for row in rows: + processed_row = [] + for i, value in enumerate(row): + processed_row.append( + apply_transformer( + returns[i][2], value, + ), + ) + processed_rows.append(processed_row) + + tbl = pa.Table.from_pylist([dict(list(zip(colnames, row))) for row in processed_rows]) tbl = tbl.add_column(0, '__index__', pa.array(row_ids)) sink = pa.BufferOutputStream() @@ -278,7 +307,7 @@ def dump( def _dump_vectors( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pa.Array[pa.int64]', cols: List[Tuple['pa.Array[Any]', Optional['pa.Array[pa.bool_]']]], ) -> bytes: @@ -287,7 +316,7 @@ def _dump_vectors( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] The row IDs @@ -303,10 +332,19 @@ def _dump_vectors( raise RuntimeError('pyarrow must be installed for this operation') if 
len(cols) == 0 or len(row_ids) == 0: - return BytesIO().getbuffer() + return BytesIO().getvalue() + + # Process columns to handle JSON serialization + processed_cols = [] + for i, (data, mask) in enumerate(cols): + processed_cols.append(( + apply_transformer( + returns[i][2], data, + ), mask, + )) tbl = pa.Table.from_arrays( - [pa.array(data, mask=mask) for data, mask in cols], + [pa.array(data, mask=mask) for data, mask in processed_cols], names=['col{}'.format(x) for x in range(len(cols))], ) tbl = tbl.add_column(0, '__index__', row_ids) @@ -320,7 +358,7 @@ def _dump_vectors( def dump_arrow( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pa.Array[int]', cols: List[Tuple['pa.Array[Any]', 'pa.Array[bool]']], ) -> bytes: @@ -331,7 +369,7 @@ def dump_arrow( def dump_numpy( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'np.typing.NDArray[np.int64]', cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ) -> bytes: @@ -346,7 +384,7 @@ def dump_numpy( def dump_pandas( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pd.Series[np.int64]', cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']], ) -> bytes: @@ -361,7 +399,7 @@ def dump_pandas( def dump_polars( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pl.Series[pl.Int64]', cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']], ) -> bytes: diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index d9090b38..b9bf6273 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -49,8 +49,10 @@ import zipfile import zipimport from collections.abc import Awaitable +from collections.abc import Coroutine from collections.abc import Iterable from collections.abc import Sequence +from threading import Event from types import ModuleType from typing import Any from typing import Callable @@ -61,6 +63,11 @@ from typing import Tuple from typing import Union +try: + from typing import TypeAlias # type: ignore +except ImportError: + from typing_extensions import TypeAlias # type: ignore + from . import arrow from . import json as jdata from . import rowdat_1 @@ -69,13 +76,13 @@ from ... 
import manage_workspaces from ...config import get_option from ...mysql.constants import FIELD_TYPE as ft +from ...docstring.parser import parse from ..signature import get_signature from ..signature import signature_to_sql +from ..sql_types import escape_name from ..typing import Masked from ..typing import Table from .timer import Timer -from singlestoredb.docstring.parser import parse -from singlestoredb.functions.dtypes import escape_name try: import cloudpickle @@ -128,6 +135,7 @@ async def to_thread( 'float64': ft.DOUBLE, 'str': ft.STRING, 'bytes': -ft.STRING, + 'json': ft.JSON, } @@ -165,32 +173,37 @@ def get_func_names(funcs: str) -> List[Tuple[str, str]]: return out -def as_tuple(x: Any) -> Any: - """Convert object to tuple.""" - if has_pydantic and isinstance(x, BaseModel): - return tuple(x.model_dump().values()) - if dataclasses.is_dataclass(x): - return dataclasses.astuple(x) # type: ignore - if isinstance(x, dict): - return tuple(x.values()) - return tuple(x) - - -def as_list_of_tuples(x: Any) -> Any: - """Convert object to a list of tuples.""" +def extend_rows(rows: List[Any], x: Any) -> int: + """Extend list of rows with data from object.""" if isinstance(x, Table): x = x[0] + if isinstance(x, (list, tuple)) and len(x) > 0: + if isinstance(x[0], (list, tuple)): - return x - if has_pydantic and isinstance(x[0], BaseModel): - return [tuple(y.model_dump().values()) for y in x] - if dataclasses.is_dataclass(x[0]): - return [dataclasses.astuple(y) for y in x] - if isinstance(x[0], dict): - return [tuple(y.values()) for y in x] - return [(y,) for y in x] - return x + rows.extend(x) + + elif has_pydantic and isinstance(x[0], BaseModel): + for y in x: + rows.append(tuple(y.model_dump().values())) + + elif dataclasses.is_dataclass(x[0]): + for y in x: + rows.append(dataclasses.astuple(y)) + + elif isinstance(x[0], dict): + for y in x: + rows.append(tuple(y.values())) + + else: + for y in x: + rows.append((y,)) + + return len(x) + + rows.append((x,)) + + return 1 def get_dataframe_columns(df: Any) -> List[Any]: @@ -208,16 +221,22 @@ def get_dataframe_columns(df: Any) -> List[Any]: return list(df) rtype = str(type(df)).lower() + + # Pandas or polars type of dataframe if 'dataframe' in rtype: return [df[x] for x in df.columns] + # PyArrow table elif 'table' in rtype: return df.columns + # Pandas or polars series elif 'series' in rtype: return [df] + # Numpy array elif 'array' in rtype: return [df] - elif 'tuple' in rtype: - return list(df) + # List of objects + elif 'list' in rtype: + return [df] raise TypeError( 'Unsupported data type for dataframe columns: ' @@ -225,24 +244,37 @@ def get_dataframe_columns(df: Any) -> List[Any]: ) -def get_array_class(data_format: str) -> Callable[..., Any]: +def get_array_class(array: Any) -> Callable[..., Any]: """ Get the array class for the current data format. 
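# A quick illustration of extend_rows above (an internal helper of this
# module; the import path follows this diff). It flattens each supported
# return object into row tuples and reports how many rows were added.
import dataclasses
from singlestoredb.functions.ext.asgi import extend_rows

@dataclasses.dataclass
class Point:
    x: int
    y: int

rows = []
assert extend_rows(rows, [Point(1, 2), Point(3, 4)]) == 2   # dataclasses -> astuple
assert extend_rows(rows, [5, 6]) == 2                       # scalars -> 1-tuples
assert rows == [(1, 2), (3, 4), (5,), (6,)]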
""" - if data_format == 'polars': + mod = inspect.getmodule(type(array)) + if mod: + array_type = mod.__name__.split('.')[0] + else: + raise TypeError(f'Unsupported array type: {type(array)}') + + if array_type == 'polars': import polars as pl - array_cls = pl.Series - elif data_format == 'arrow': + return pl.Series + + if array_type == 'pyarrow': import pyarrow as pa - array_cls = pa.array - elif data_format == 'pandas': + return pa.array + + if array_type == 'pandas': import pandas as pd - array_cls = pd.Series - else: + return pd.Series + + if array_type == 'numpy': import numpy as np - array_cls = np.array - return array_cls + return np.array + + if isinstance(array, list): + return list + + raise TypeError(f'Unsupported array type: {type(array)}') def get_masked_params(func: Callable[..., Any]) -> List[bool]: @@ -295,19 +327,31 @@ def cancel_on_event( ) -def build_udf_endpoint( +RowIDs: TypeAlias = Sequence[int] +VectorInput: TypeAlias = Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]] +ScalarInput: TypeAlias = Sequence[Sequence[Any]] +UDFInput = Union[VectorInput, ScalarInput] +VectorOutput: TypeAlias = List[Tuple[Sequence[Any], Optional[Sequence[bool]]]] +ScalarOutput: TypeAlias = List[Tuple[Any, ...]] +UDFOutput = Union[VectorOutput, ScalarOutput] + + +def scalar_in_scalar_out( func: Callable[..., Any], - returns_data_format: str, -) -> Callable[..., Any]: + function_type: str = 'udf', +) -> Callable[ + [Event, Timer, RowIDs, ScalarInput], + Coroutine[Any, Any, Tuple[RowIDs, ScalarOutput]], +]: """ - Build a UDF endpoint for scalar / list types (row-based). + Create a scalar in, scalar out function endpoint. Parameters ---------- func : Callable The function to call as the endpoint - returns_data_format : str - The format of the return values + function_type : str, optional + The type of function: 'udf' or 'tvf' Returns ------- @@ -315,45 +359,57 @@ def build_udf_endpoint( The function endpoint """ - if returns_data_format in ['scalar', 'list']: - - is_async = asyncio.iscoroutinefunction(func) - - async def do_func( - cancel_event: threading.Event, - timer: Timer, - row_ids: Sequence[int], - rows: Sequence[Sequence[Any]], - ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: - '''Call function on given rows of data.''' - out = [] - async with timer('call_function'): - for row in rows: - cancel_on_event(cancel_event) - if is_async: - out.append(await func(*row)) - else: - out.append(func(*row)) - return row_ids, list(zip(out)) + is_async = asyncio.iscoroutinefunction(func) + is_udf = function_type == 'udf' + + async def do_scalar_in_scalar_out_func( + cancel_event: threading.Event, + timer: Timer, + row_ids: RowIDs, + rows: ScalarInput, + ) -> Tuple[RowIDs, ScalarOutput]: + """Call function on given rows of data.""" + cancel_on_event(cancel_event) + + async with (timer('call_function')): + out_ids = [] + out_rows: ScalarOutput = [] - return do_func + for i, row in zip(row_ids, rows): + func_res = await func(*row) if is_async else func(*row) - return build_vector_udf_endpoint(func, returns_data_format) + cancel_on_event(cancel_event) + n_rows = extend_rows(out_rows, func_res) -def build_vector_udf_endpoint( + if is_udf and n_rows != 1: + raise ValueError('UDF must return a single value per input row') + + out_ids.extend([i] * n_rows) + + cancel_on_event(cancel_event) + + return out_ids, out_rows + + return do_scalar_in_scalar_out_func + + +def scalar_in_vector_out( func: Callable[..., Any], - returns_data_format: str, -) -> Callable[..., Any]: + function_type: str = 'udf', +) 
-> Callable[ + [Event, Timer, RowIDs, ScalarInput], + Coroutine[Any, Any, Tuple[RowIDs, VectorOutput]], +]: """ - Build a UDF endpoint for vector formats (column-based). + Create a scalar in, vector out function endpoint. Parameters ---------- func : Callable The function to call as the endpoint - returns_data_format : str - The format of the return values + function_type : str, optional + The type of function: 'udf' or 'tvf' Returns ------- @@ -361,64 +417,62 @@ def build_vector_udf_endpoint( The function endpoint """ - masks = get_masked_params(func) - array_cls = get_array_class(returns_data_format) is_async = asyncio.iscoroutinefunction(func) + is_udf = function_type == 'udf' - async def do_func( + async def do_scalar_in_vector_out_func( cancel_event: threading.Event, timer: Timer, - row_ids: Sequence[int], - cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], - ) -> Tuple[ - Sequence[int], - List[Tuple[Sequence[Any], Optional[Sequence[bool]]]], - ]: - '''Call function on given columns of data.''' - row_ids = array_cls(row_ids) - - # Call the function with `cols` as the function parameters - async with timer('call_function'): - if cols and cols[0]: - if is_async: - out = await func(*[x if m else x[0] for x, m in zip(cols, masks)]) - else: - out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) - else: - if is_async: - out = await func() - else: - out = func() - + row_ids: RowIDs, + rows: ScalarInput, + ) -> Tuple[RowIDs, VectorOutput]: + """Call function on given rows of data.""" cancel_on_event(cancel_event) - # Single masked value - if isinstance(out, Masked): - return row_ids, [tuple(out)] + async with (timer('call_function')): + out_vectors = [] + out_ids = [] + for i, row in zip(row_ids, rows): + func_res = await func(*row) if is_async else func(*row) - # Multiple return values - if isinstance(out, tuple): - return row_ids, [build_tuple(x) for x in out] + cancel_on_event(cancel_event) + + res = get_dataframe_columns(func_res) + + ref = res[0][0] if isinstance(res[0], Masked) else res[0] + if is_udf and len(ref) != 1: + raise ValueError('UDF must return a single value per input row') + + out_ids.extend([i] * len(ref)) + + out_vectors.append([build_tuple(x) for x in res]) - # Single return value - return row_ids, [(out, None)] + cancel_on_event(cancel_event) - return do_func + # Concatenate vector results from all rows + out = concatenate_vectors(out_vectors) + return get_array_class(out[0][0][0])(out_ids), out -def build_tvf_endpoint( + return do_scalar_in_vector_out_func + + +def vector_in_vector_out( func: Callable[..., Any], - returns_data_format: str, -) -> Callable[..., Any]: + function_type: str = 'udf', +) -> Callable[ + [Event, Timer, RowIDs, VectorInput], + Coroutine[Any, Any, Tuple[RowIDs, VectorOutput]], +]: """ - Build a TVF endpoint for scalar / list types (row-based). + Create a vector in, vector out function endpoint. 
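# The TVF row-id bookkeeping above, in miniature: each input row id is
# repeated once per output row its call produced (a UDF is required to
# produce exactly one). Values are hypothetical.
row_ids = [10, 11]
per_row_results = [['a', 'b', 'c'], ['d']]   # 3 output rows, then 1

out_ids = []
for i, res in zip(row_ids, per_row_results):
    out_ids.extend([i] * len(res))

assert out_ids == [10, 10, 10, 11]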
Parameters ---------- func : Callable The function to call as the endpoint - returns_data_format : str - The format of the return values + function_type : str, optional + The type of function: 'udf' or 'tvf' Returns ------- @@ -426,49 +480,62 @@ def build_tvf_endpoint( The function endpoint """ - if returns_data_format in ['scalar', 'list']: - - is_async = asyncio.iscoroutinefunction(func) - - async def do_func( - cancel_event: threading.Event, - timer: Timer, - row_ids: Sequence[int], - rows: Sequence[Sequence[Any]], - ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: - '''Call function on given rows of data.''' - out_ids: List[int] = [] - out = [] - # Call function on each row of data - async with timer('call_function'): - for i, row in zip(row_ids, rows): - cancel_on_event(cancel_event) - if is_async: - res = await func(*row) - else: - res = func(*row) - out.extend(as_list_of_tuples(res)) - out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) - return out_ids, out + masks = get_masked_params(func) + is_async = asyncio.iscoroutinefunction(func) + is_udf = function_type == 'udf' + + async def do_vector_in_vector_out_func( + cancel_event: threading.Event, + timer: Timer, + row_ids: RowIDs, + cols: VectorInput, + ) -> Tuple[RowIDs, VectorOutput]: + """Call function on given columns of data.""" + cancel_on_event(cancel_event) + + args = [] + + async with timer('call_function'): + # Remove masks from args if mask is None + if cols and cols[0]: + args = [x if m else x[0] for x, m in zip(cols, masks)] + + func_res = await func(*args) if is_async else func(*args) + + cancel_on_event(cancel_event) + + out = get_dataframe_columns(func_res) + + ref = out[0][0] if isinstance(out[0], Masked) else out[0] + array_cls = get_array_class(ref) + if is_udf: + if len(ref) != len(row_ids): + raise ValueError('UDF must return a single value per input row') + row_ids = array_cls(row_ids) + else: + row_ids = array_cls([row_ids[0]] * len(ref)) - return do_func + return row_ids, [build_tuple(x) for x in out] - return build_vector_tvf_endpoint(func, returns_data_format) + return do_vector_in_vector_out_func -def build_vector_tvf_endpoint( +def vector_in_scalar_out( func: Callable[..., Any], - returns_data_format: str, -) -> Callable[..., Any]: + function_type: str = 'udf', +) -> Callable[ + [Event, Timer, RowIDs, VectorInput], + Coroutine[Any, Any, Tuple[RowIDs, ScalarOutput]], +]: """ - Build a TVF endpoint for vector formats (column-based). + Create a vector in, scalar out function endpoint. Parameters ---------- func : Callable The function to call as the endpoint - returns_data_format : str - The format of the return values + function_type : str, optional + The type of function: 'udf' or 'tvf' Returns ------- @@ -477,54 +544,192 @@ def build_vector_tvf_endpoint( """ masks = get_masked_params(func) - array_cls = get_array_class(returns_data_format) + is_async = asyncio.iscoroutinefunction(func) + is_udf = function_type == 'udf' - async def do_func( + async def do_vector_in_scalar_out_func( cancel_event: threading.Event, timer: Timer, - row_ids: Sequence[int], - cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], - ) -> Tuple[ - Sequence[int], - List[Tuple[Sequence[Any], Optional[Sequence[bool]]]], - ]: - '''Call function on given columns of data.''' - # NOTE: There is no way to determine which row ID belongs to - # each result row, so we just have to use the same - # row ID for all rows in the result. 
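# How the mask flags are applied, in miniature: each column arrives as a
# (data, mask) pair; parameters declared as Masked receive the whole pair,
# all others receive just the data vector. Toy values, plain lists.
cols = [
    ([1, 2, 3], [False, True, False]),   # param 0 is Masked: gets the pair
    (['a', 'b', 'c'], None),             # param 1 is plain: gets data only
]
masks = [True, False]

args = [x if m else x[0] for x, m in zip(cols, masks)]
assert args[0] == ([1, 2, 3], [False, True, False])
assert args[1] == ['a', 'b', 'c']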
-
-        is_async = asyncio.iscoroutinefunction(func)
-
-        # Call function on each column of data
+        row_ids: RowIDs,
+        cols: VectorInput,
+    ) -> Tuple[RowIDs, ScalarOutput]:
+        """Call function on given columns of data."""
+        cancel_on_event(cancel_event)
+
+        out_ids = []
+        out_rows: ScalarOutput = []
+        args = []
+        async with timer('call_function'):
+            # Remove masks from args if mask is None
             if cols and cols[0]:
-                if is_async:
-                    func_res = await func(
-                        *[x if m else x[0] for x, m in zip(cols, masks)],
-                    )
-                else:
-                    func_res = func(
-                        *[x if m else x[0] for x, m in zip(cols, masks)],
-                    )
+                args = [x if m else x[0] for x, m in zip(cols, masks)]
+
+            func_res = await func(*args) if is_async else func(*args)
+
+            cancel_on_event(cancel_event)
+
+            n_rows = extend_rows(out_rows, func_res)
+
+            if is_udf:
+                if n_rows != len(row_ids):
+                    raise ValueError('UDF must return a single value per input row')
+                out_ids = list(row_ids)
             else:
-                if is_async:
-                    func_res = await func()
-                else:
-                    func_res = func()
+                out_ids.extend([row_ids[0]] * n_rows)
-        res = get_dataframe_columns(func_res)
+        return out_ids, out_rows
-        cancel_on_event(cancel_event)
+    return do_vector_in_scalar_out_func
-        # Generate row IDs
-        if isinstance(res[0], Masked):
-            row_ids = array_cls([row_ids[0]] * len(res[0][0]))
-        else:
-            row_ids = array_cls([row_ids[0]] * len(res[0]))
-        return row_ids, [build_tuple(x) for x in res]
+def concatenate_vectors(segments: List[VectorOutput]) -> VectorOutput:
+    """
+    Concatenate lists of vectors with optional masks.
+
+    Parameters
+    ----------
+    segments : List[VectorOutput]
+        List of vectors to concatenate. Each vector is a list of tuples,
+        where each tuple contains an array and an optional mask.
+
+    Returns
+    -------
+    VectorOutput
+        Concatenated vector with optional mask.
+
+    Raises
+    ------
+    ValueError
+        If masks are used on some but not all elements.
+
+    """
+    if not segments:
+        return []
+
+    # One accumulator per column, created once up front. Note: list
+    # comprehensions rather than `[[]] * n`, so each column gets its own
+    # list instead of n aliases of one shared list, and the accumulators
+    # are not re-created (discarding prior segments) on every iteration.
+    n_cols = len(segments[0])
+    columns: List[List[Sequence[Any]]] = [[] for _ in range(n_cols)]
+    masks: List[List[Sequence[bool]]] = [[] for _ in range(n_cols)]
+    has_masks: List[bool] = [False] * n_cols
+
+    for s in segments:
+        for i, v in enumerate(s):
+            columns[i].append(v[0])
+            if v[1] is not None:
+                masks[i].append(v[1])
+
+    for i, mask in enumerate(masks):
+        if mask and len(mask) != len(columns[i]):
+            raise ValueError('Vector masks must be used on either all or no elements')
+        if mask:
+            has_masks[i] = True
+
+    return [
+        (_concatenate_arrays(c), _concatenate_arrays(m) if has_masks[i] else None)
+        for i, (c, m) in enumerate(zip(columns, masks))
+    ]
+
+
+def _concatenate_arrays(
+    arrays: Sequence[Sequence[Any]],
+) -> Sequence[Any]:
+    """
+    Concatenate lists of arrays from various formats.
+
+    Parameters
+    ----------
+    arrays : Sequence[Sequence[Any]]
+        List of arrays to concatenate.
Supported formats: + - PyArrow arrays + - NumPy arrays + - Pandas Series + - Polars Series + - Python lists + + Returns + ------- + Sequence[Any] + Concatenated array in the same format as input arrays, + or None if input is None + + Raises + ------ + ValueError + If arrays list contains a mix of None and non-None values + TypeError + If arrays contain mixed or unsupported types + + """ + if arrays[0] is None: + raise ValueError('Cannot concatenate None arrays') + + mod = inspect.getmodule(type(arrays[0])) + if mod: + array_type = mod.__name__.split('.')[0] + else: + raise TypeError(f'Unsupported array type: {type(arrays[0])}') + + if array_type == 'numpy': + import numpy as np + return np.concatenate(arrays) - return do_func + if array_type == 'pyarrow': + import pyarrow as pa + return pa.concat_arrays(arrays) + + if array_type == 'pandas': + import pandas as pd + return pd.concat(arrays, ignore_index=True) + + if array_type == 'polars': + import polars as pl + return pl.concat(arrays) + + if isinstance(arrays[0], list): + result: List[Any] = [] + for arr in arrays: + result.extend(arr) + return result + + raise TypeError(f'Unsupported array type: {type(arrays[0])}') + + +def build_udf_endpoint( + func: Callable[..., Any], + args_data_format: str, + returns_data_format: str, + function_type: str = 'udf', +) -> Callable[ + [Event, Timer, RowIDs, Any], + Coroutine[Any, Any, Tuple[RowIDs, Any]], +]: + """ + Build a UDF endpoint for scalar / list types (row-based). + + Parameters + ---------- + func : Callable + The function to call as the endpoint + args_data_format : str + The format of the argument values + returns_data_format : str + The format of the return values + function_type : str, optional + The type of function: 'udf' or 'tvf' + + Returns + ------- + Callable + The function endpoint + + """ + if args_data_format in ['scalar'] and returns_data_format in ['scalar']: + return scalar_in_scalar_out(func, function_type=function_type) + elif args_data_format in ['scalar'] and returns_data_format not in ['scalar']: + return scalar_in_vector_out(func, function_type=function_type) + elif args_data_format not in ['scalar'] and returns_data_format in ['scalar']: + return vector_in_scalar_out(func, function_type=function_type) + else: + return vector_in_vector_out(func, function_type=function_type) def make_func( @@ -560,10 +765,12 @@ def make_func( get_option('external_function.timeout') ) - if function_type == 'tvf': - do_func = build_tvf_endpoint(func, returns_data_format) - else: - do_func = build_udf_endpoint(func, returns_data_format) + do_func = build_udf_endpoint( + func, + args_data_format, + returns_data_format, + function_type=function_type, + ) do_func.__name__ = name do_func.__doc__ = func.__doc__ @@ -590,7 +797,7 @@ def make_func( dtype = x['dtype'].replace('?', '') if dtype not in rowdat_1_type_map: raise TypeError(f'no data type mapping for {dtype}') - colspec.append((x['name'], rowdat_1_type_map[dtype])) + colspec.append((x['name'], rowdat_1_type_map[dtype], x['transformer'])) info['colspec'] = colspec # Setup return type @@ -599,7 +806,7 @@ def make_func( dtype = x['dtype'].replace('?', '') if dtype not in rowdat_1_type_map: raise TypeError(f'no data type mapping for {dtype}') - returns.append((x['name'], rowdat_1_type_map[dtype])) + returns.append((x['name'], rowdat_1_type_map[dtype], x['transformer'])) info['returns'] = returns return do_func, info @@ -767,8 +974,8 @@ class Application(object): response=rowdat_1_response_dict, ), (b'application/octet-stream', b'1.0', 
'list'): dict( - load=rowdat_1.load, - dump=rowdat_1.dump, + load=rowdat_1.load_list, + dump=rowdat_1.dump_list, response=rowdat_1_response_dict, ), (b'application/octet-stream', b'1.0', 'pandas'): dict( @@ -797,8 +1004,8 @@ class Application(object): response=json_response_dict, ), (b'application/json', b'1.0', 'list'): dict( - load=jdata.load, - dump=jdata.dump, + load=jdata.load_list, + dump=jdata.dump_list, response=json_response_dict, ), (b'application/json', b'1.0', 'pandas'): dict( @@ -1233,7 +1440,7 @@ async def __call__( with timer('format_output'): body = output_handler['dump']( - [x[1] for x in func_info['returns']], *result, # type: ignore + func_info['returns'], *result, # type: ignore ) await send(output_handler['response']) diff --git a/singlestoredb/functions/ext/json.py b/singlestoredb/functions/ext/json.py index 05710247..a5498a61 100644 --- a/singlestoredb/functions/ext/json.py +++ b/singlestoredb/functions/ext/json.py @@ -3,15 +3,18 @@ import json from typing import Any from typing import List +from typing import Optional from typing import Tuple from typing import TYPE_CHECKING -from ..dtypes import DEFAULT_VALUES -from ..dtypes import NUMPY_TYPE_MAP -from ..dtypes import PANDAS_TYPE_MAP -from ..dtypes import POLARS_TYPE_MAP -from ..dtypes import PYARROW_TYPE_MAP -from ..dtypes import PYTHON_CONVERTERS +from ..sql_types import DEFAULT_VALUES +from ..sql_types import NUMPY_TYPE_MAP +from ..sql_types import PANDAS_TYPE_MAP +from ..sql_types import POLARS_TYPE_MAP +from ..sql_types import PYARROW_TYPE_MAP +from ..sql_types import PYTHON_CONVERTERS +from .utils import apply_transformer +from .utils import Transformer if TYPE_CHECKING: try: @@ -40,19 +43,33 @@ def default(self, obj: Any) -> Any: return json.JSONEncoder.default(self, obj) -def decode_row(coltypes: List[int], row: List[Any]) -> List[Any]: +def decode_row( + colspec: List[Tuple[str, int, Optional[Transformer]]], + row: List[Any], +) -> List[Any]: out = [] - for dtype, item in zip(coltypes, row): - out.append(PYTHON_CONVERTERS[dtype](item)) # type: ignore + for (_, dtype, transformer), item in zip(colspec, row): + out.append( + apply_transformer( + transformer, + PYTHON_CONVERTERS[dtype](item), # type: ignore + ), + ) return out -def decode_value(coltype: int, data: Any) -> Any: - return PYTHON_CONVERTERS[coltype](data) # type: ignore +def decode_value( + colspec: Tuple[str, int, Optional[Transformer]], + data: Any, +) -> Any: + return apply_transformer( + colspec[2], + PYTHON_CONVERTERS[colspec[1]](data), # type: ignore + ) def load( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[List[int], List[Any]]: ''' @@ -60,7 +77,7 @@ def load( Parameters ---------- - colspec : Iterable[Tuple[str, int]] + colspec : Iterable[Tuple[str, int, Optional[Transformer]]] An Iterable of column data types data : bytes The data in JSON format @@ -74,12 +91,12 @@ def load( rows = [] for row_id, *row in json.loads(data.decode('utf-8'))['data']: row_ids.append(row_id) - rows.append(decode_row([x[1] for x in colspec], row)) + rows.append(decode_row(colspec, row)) return row_ids, rows def _load_vectors( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[List[int], List[Any]]: ''' @@ -87,7 +104,7 @@ def _load_vectors( Parameters ---------- - colspec : Iterable[Tuple[str, int]] + colspec : Iterable[Tuple[str, int, Optional[Transformer]]] An Iterable of column data types data : bytes The data in JSON format @@ 
-107,13 +124,13 @@ def _load_vectors( if not cols: cols = [([], []) for _ in row] for i, (spec, x) in enumerate(zip(colspec, row)): - cols[i][0].append(decode_value(spec[1], x) if x is not None else defaults[i]) + cols[i][0].append(decode_value(spec, x) if x is not None else defaults[i]) cols[i][1].append(False if x is not None else True) return row_ids, cols def load_pandas( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[List[int], List[Any]]: ''' @@ -121,7 +138,7 @@ def load_pandas( Parameters ---------- - colspec : Iterable[Tuple[str, int]] + colspec : Iterable[Tuple[str, int, Optional[Transformer]]] An Iterable of column data types data : bytes The data in JSON format @@ -149,7 +166,7 @@ def load_pandas( def load_polars( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[List[int], List[Any]]: ''' @@ -180,7 +197,7 @@ def load_polars( def load_numpy( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[Any, List[Any]]: ''' @@ -188,7 +205,7 @@ def load_numpy( Parameters ---------- - colspec : Iterable[Tuple[str, int]] + colspec : Iterable[Tuple[str, int, Optional[Transformer]]] An Iterable of column data types data : bytes The data in JSON format @@ -211,7 +228,7 @@ def load_numpy( def load_arrow( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[Any, List[Any]]: ''' @@ -219,7 +236,7 @@ def load_arrow( Parameters ---------- - colspec : Iterable[Tuple[str, int]] + colspec : Iterable[Tuple[str, int, Optional[Transformer]]] An Iterable of column data types data : bytes The data in JSON format @@ -240,12 +257,12 @@ def load_arrow( ), pa.array(mask, type=pa.bool_()), ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def dump( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: List[int], rows: List[List[Any]], ) -> bytes: @@ -254,7 +271,7 @@ def dump( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] Row IDs @@ -266,12 +283,20 @@ def dump( bytes ''' + rows = list(rows) + transformers = [] + for i, (_, _, transformer) in enumerate(returns): + if transformer is not None: + transformers.append((i, transformer)) + for (i, transformer) in transformers: + for row in rows: + row[i] = apply_transformer(transformer, row[i]) data = list(zip(row_ids, *list(zip(*rows)))) return json.dumps(dict(data=data), cls=JSONEncoder).encode('utf-8') def _dump_vectors( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: List[int], cols: List[Tuple[Any, Any]], ) -> bytes: @@ -295,9 +320,18 @@ def _dump_vectors( masked_cols = [] for i, (data, mask) in enumerate(cols): if mask is not None: - masked_cols.append([d if m is not None else None for d, m in zip(data, mask)]) + masked_cols.append([ + apply_transformer( + returns[i][2], d, + ) if m is not None else None for d, m in zip(data, mask) + ]) else: - masked_cols.append(cols[i][0]) + masked_cols.append( + apply_transformer( + returns[i][2], + cols[i][0], + ), + ) data = list(zip(row_ids, *masked_cols)) return json.dumps(dict(data=data), cls=JSONEncoder).encode('utf-8') @@ -307,7 +341,7 @@ def _dump_vectors( def dump_pandas( - returns: List[int], + returns: List[Tuple[str, int, 
Optional[Transformer]]], row_ids: 'pd.Series[int]', cols: List[Tuple['pd.Series[int]', 'pd.Series[bool]']], ) -> bytes: @@ -316,7 +350,7 @@ def dump_pandas( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : pd.Series[int] Row IDs @@ -329,13 +363,34 @@ def dump_pandas( ''' import pandas as pd + row_ids.index = row_ids - df = pd.concat([row_ids] + [x[0] for x in cols], axis=1) + + for i, ((data, mask), (_, dtype, transformer)) in enumerate(zip(cols, returns)): + data.index = row_ids.index + if mask is not None: + mask.index = row_ids.index + cols[i] = pd.Series( + [ + apply_transformer(transformer, d) if not m else None + for d, m in zip(data, mask) + ], index=row_ids.index, name=data.name, dtype=PANDAS_TYPE_MAP[dtype], + ) + else: + cols[i] = pd.Series( + apply_transformer(transformer, data), + index=row_ids.index, + name=data.name, + dtype=PANDAS_TYPE_MAP[dtype], + ) + + df = pd.concat([row_ids] + cols, axis=1) + return ('{"data": %s}' % df.to_json(orient='values')).encode('utf-8') def dump_polars( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pl.Series[int]', cols: List[Tuple['pl.Series[Any]', 'pl.Series[int]']], ) -> bytes: @@ -344,7 +399,7 @@ def dump_polars( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] cols : List[Tuple[polars.Series[Any], polars.Series[bool]] @@ -363,7 +418,7 @@ def dump_polars( def dump_numpy( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'np.typing.NDArray[np.int64]', cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ) -> bytes: @@ -372,7 +427,7 @@ def dump_numpy( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] Row IDs @@ -392,7 +447,7 @@ def dump_numpy( def dump_arrow( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pa.Array[int]', cols: List[Tuple['pa.Array[int]', 'pa.Array[bool]']], ) -> bytes: @@ -401,7 +456,7 @@ def dump_arrow( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : pyarrow.Array[int] Row IDs diff --git a/singlestoredb/functions/ext/rowdat_1.py b/singlestoredb/functions/ext/rowdat_1.py index 94e966b7..267d9664 100644 --- a/singlestoredb/functions/ext/rowdat_1.py +++ b/singlestoredb/functions/ext/rowdat_1.py @@ -11,11 +11,13 @@ from ...config import get_option from ...mysql.constants import FIELD_TYPE as ft -from ..dtypes import DEFAULT_VALUES -from ..dtypes import NUMPY_TYPE_MAP -from ..dtypes import PANDAS_TYPE_MAP -from ..dtypes import POLARS_TYPE_MAP -from ..dtypes import PYARROW_TYPE_MAP +from ..sql_types import DEFAULT_VALUES +from ..sql_types import NUMPY_TYPE_MAP +from ..sql_types import PANDAS_TYPE_MAP +from ..sql_types import POLARS_TYPE_MAP +from ..sql_types import PYARROW_TYPE_MAP +from .utils import apply_transformer +from .utils import Transformer if TYPE_CHECKING: try: @@ -79,17 +81,17 @@ ft.FLOAT: 4, ft.DOUBLE: 8, } -medium_int_types = set([ft.INT24, -ft.INT24]) -int_types = set([ - ft.TINY, -ft.TINY, ft.SHORT, -ft.SHORT, ft.INT24, -ft.INT24, - ft.LONG, -ft.LONG, ft.LONGLONG, -ft.LONGLONG, -]) -string_types = set([15, 245, 247, 248, 249, 250, 251, 252, 253, 254]) +medium_int_types = {ft.INT24, -ft.INT24} +int_types = { + ft.TINY, 
-ft.TINY, ft.SHORT, -ft.SHORT, ft.INT24, -ft.INT24, ft.LONG, + -ft.LONG, ft.LONGLONG, -ft.LONGLONG, +} +string_types = {15, 245, 247, 248, 249, 250, 251, 252, 253, 254} binary_types = set([-x for x in string_types]) def _load( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[List[int], List[Any]]: ''' @@ -97,7 +99,7 @@ def _load( Parameters ---------- - colspec : List[str] + colspec : List[Tuple[str, int, Optional[Transformer]]] An List of column data types data : bytes The data in rowdat_1 format @@ -115,19 +117,28 @@ def _load( while data_io.tell() < data_len: row_ids.append(struct.unpack(' Tuple[List[int], List[Tuple[Sequence[Any], Optional[Sequence[Any]]]]]: ''' @@ -144,7 +155,7 @@ def _load_vectors( Parameters ---------- - colspec : List[str] + colspec : List[str, int, Optional[Transformer]] An List of column data types data : bytes The data in rowdat_1 format @@ -162,20 +173,29 @@ def _load_vectors( val = None while data_io.tell() < data_len: row_ids.append(struct.unpack(' Tuple[ 'pd.Series[np.int64]', @@ -195,7 +215,7 @@ def _load_pandas( Parameters ---------- - colspec : List[str] + colspec : List[str, int, Optional[Transformer]] An List of column data types data : bytes The data in rowdat_1 format @@ -215,12 +235,12 @@ def _load_pandas( pd.Series(data, index=index, name=name, dtype=PANDAS_TYPE_MAP[dtype]), pd.Series(mask, index=index, dtype=np.bool_), ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def _load_polars( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pl.Series[pl.Int64]', @@ -231,7 +251,7 @@ def _load_polars( Parameters ---------- - colspec : List[str] + colspec : List[str, int, Optional[Transformer]] An List of column data types data : bytes The data in rowdat_1 format @@ -244,18 +264,19 @@ def _load_polars( import polars as pl row_ids, cols = _load_vectors(colspec, data) + return pl.Series(None, row_ids, dtype=pl.Int64), \ [ ( pl.Series(name=name, values=data, dtype=POLARS_TYPE_MAP[dtype]), pl.Series(values=mask, dtype=pl.Boolean), ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def _load_numpy( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'np.typing.NDArray[np.int64]', @@ -266,7 +287,7 @@ def _load_numpy( Parameters ---------- - colspec : List[str] + colspec : List[str, int, Optional[Transformer]] An List of column data types data : bytes The data in rowdat_1 format @@ -279,18 +300,19 @@ def _load_numpy( import numpy as np row_ids, cols = _load_vectors(colspec, data) + return np.asarray(row_ids, dtype=np.int64), \ [ ( np.asarray(data, dtype=NUMPY_TYPE_MAP[dtype]), # type: ignore np.asarray(mask, dtype=np.bool_), # type: ignore ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def _load_arrow( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pa.Array[pa.int64]', @@ -301,7 +323,7 @@ def _load_arrow( Parameters ---------- - colspec : List[str] + colspec : List[str, int, Optional[Transformer]] An List of column data types data : bytes The data in rowdat_1 format @@ -314,6 +336,7 @@ def _load_arrow( import pyarrow as pa row_ids, cols = _load_vectors(colspec, data) + return pa.array(row_ids, 
type=pa.int64()), \ [ ( @@ -323,12 +346,12 @@ def _load_arrow( ), pa.array(mask, type=pa.bool_()), ) - for (data, mask), (name, dtype) in zip(cols, colspec) + for (data, mask), (name, dtype, _) in zip(cols, colspec) ] def _dump( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: List[int], rows: List[List[Any]], ) -> bytes: @@ -337,7 +360,7 @@ def _dump( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] The row IDs @@ -352,13 +375,14 @@ def _dump( out = BytesIO() if len(rows) == 0 or len(row_ids) == 0: - return out.getbuffer() + return out.getvalue() for row_id, *values in zip(row_ids, *list(zip(*rows))): out.write(struct.pack(' bytes: @@ -406,7 +430,7 @@ def _dump_vectors( Parameters ---------- - returns : List[int] + returns : List[Tuple[str, int, Optional[Transformer]]] The returned data type row_ids : List[int] The row IDs @@ -421,14 +445,14 @@ def _dump_vectors( out = BytesIO() if len(cols) == 0 or len(row_ids) == 0: - return out.getbuffer() + return out.getvalue() for j, row_id in enumerate(row_ids): out.write(struct.pack(' bytes: @@ -490,7 +520,7 @@ def _dump_arrow( def _dump_numpy( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'np.typing.NDArray[np.int64]', cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ) -> bytes: @@ -502,7 +532,7 @@ def _dump_numpy( def _dump_pandas( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pd.Series[np.int64]', cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']], ) -> bytes: @@ -514,7 +544,7 @@ def _dump_pandas( def _dump_polars( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pl.Series[pl.Int64]', cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']], ) -> bytes: @@ -526,7 +556,7 @@ def _dump_polars( def _load_numpy_accel( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'np.typing.NDArray[np.int64]', @@ -535,22 +565,50 @@ def _load_numpy_accel( if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') - return _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) + import numpy as np + + numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) + + for i, (_, dtype, transformer) in enumerate(colspec): + if transformer is not None: + # Numpy will try to be "helpful" and create multidimensional arrays + # from nested iterables. We don't usually want that. What we want is + # numpy arrays of Python objects (e.g., lists, dicts, etc). To do that, + # we have to create an empty array of the correct length and dtype=object, + # then fill it in with the transformed values. The transformer may have + # an output_type attribute that we can use to create a more specific type. 
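# In miniature (comment-only sketch; toy values, with json.loads standing
# in for a transformer):
#
#     import json
#     import numpy as np
#     auto = np.array([json.loads(v) for v in ['[1, 2]', '[3, 4]']])
#     # auto.shape == (2, 2) -- nested lists were promoted to a 2-D array
#     col = np.empty(2, dtype=object)
#     col[:] = [json.loads(v) for v in ['{"a": 1}', '{"a": 2}']]
#     # col.shape == (2,) -- one Python dict per element, as intended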
+ if getattr(transformer, 'output_type', None): + new_col = np.empty( + len(numpy_cols[i][0]), + dtype=transformer.output_type, # type: ignore + ) + new_col[:] = list(map(transformer, numpy_cols[i][0])) + else: + new_col = np.array(list(map(transformer, numpy_cols[i][0]))) + numpy_cols[i] = (new_col, numpy_cols[i][1]) + + return numpy_ids, numpy_cols def _dump_numpy_accel( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'np.typing.NDArray[np.int64]', cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ) -> bytes: if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') + import numpy as np + + for i, (_, dtype, transformer) in enumerate(returns): + if transformer is not None: + cols[i] = (np.array(list(map(transformer, cols[i][0]))), cols[i][1]) + return _singlestoredb_accel.dump_rowdat_1_numpy(returns, row_ids, cols) def _load_pandas_accel( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pd.Series[np.int64]', @@ -562,19 +620,21 @@ def _load_pandas_accel( import numpy as np import pandas as pd - numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) + numpy_ids, numpy_cols = _load_numpy_accel(colspec, data) + cols = [ ( pd.Series(data, name=name, dtype=PANDAS_TYPE_MAP[dtype]), pd.Series(mask, dtype=np.bool_), ) - for (name, dtype), (data, mask) in zip(colspec, numpy_cols) + for (name, dtype, _), (data, mask) in zip(colspec, numpy_cols) ] + return pd.Series(numpy_ids, dtype=np.int64), cols def _dump_pandas_accel( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pd.Series[np.int64]', cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']], ) -> bytes: @@ -589,11 +649,12 @@ def _dump_pandas_accel( ) for data, mask in cols ] - return _singlestoredb_accel.dump_rowdat_1_numpy(returns, numpy_ids, numpy_cols) + + return _dump_numpy_accel(returns, numpy_ids, numpy_cols) def _load_polars_accel( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pl.Series[pl.Int64]', @@ -604,7 +665,8 @@ def _load_polars_accel( import polars as pl - numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) + numpy_ids, numpy_cols = _load_numpy_accel(colspec, data) + cols = [ ( pl.Series( @@ -614,32 +676,42 @@ def _load_polars_accel( ), pl.Series(values=mask, dtype=pl.Boolean), ) - for (name, dtype), (data, mask) in zip(colspec, numpy_cols) + for (name, dtype, _), (data, mask) in zip(colspec, numpy_cols) ] + return pl.Series(values=numpy_ids, dtype=pl.Int64), cols def _dump_polars_accel( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pl.Series[pl.Int64]', cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']], ) -> bytes: if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') + import numpy as np + import polars as pl + numpy_ids = row_ids.to_numpy() numpy_cols = [ ( - data.to_numpy(), + # Polars will try to be "helpful" and convert nested iterables into + # multidimensional arrays. We don't usually want that. What we want is + # numpy arrays of Python objects (e.g., lists, dicts, etc). To + # do that, we have to convert the Series to a list first. 
+ np.array(data.to_list()) + if isinstance(data.dtype, (pl.Struct, pl.Object)) else data.to_numpy(), mask.to_numpy() if mask is not None else None, ) for data, mask in cols ] - return _singlestoredb_accel.dump_rowdat_1_numpy(returns, numpy_ids, numpy_cols) + + return _dump_numpy_accel(returns, numpy_ids, numpy_cols) def _load_arrow_accel( - colspec: List[Tuple[str, int]], + colspec: List[Tuple[str, int, Optional[Transformer]]], data: bytes, ) -> Tuple[ 'pa.Array[pa.int64]', @@ -650,13 +722,13 @@ def _load_arrow_accel( import pyarrow as pa - numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) + numpy_ids, numpy_cols = _load_numpy_accel(colspec, data) cols = [ ( pa.array(data, type=PYARROW_TYPE_MAP[dtype], mask=mask), pa.array(mask, type=pa.bool_()), ) - for (data, mask), (name, dtype) in zip(numpy_cols, colspec) + for (data, mask), (name, dtype, _) in zip(numpy_cols, colspec) ] return pa.array(numpy_ids, type=pa.int64()), cols @@ -670,11 +742,11 @@ def _create_arrow_mask( if mask is None: return data.is_null().to_numpy(zero_copy_only=False) - return pc.or_(data.is_null(), mask.is_null()).to_numpy(zero_copy_only=False) + return pc.or_(data.is_null(), mask).to_numpy(zero_copy_only=False) def _dump_arrow_accel( - returns: List[int], + returns: List[Tuple[str, int, Optional[Transformer]]], row_ids: 'pa.Array[pa.int64]', cols: List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_]']], ) -> bytes: @@ -686,11 +758,27 @@ def _dump_arrow_accel( data.fill_null(DEFAULT_VALUES[dtype]).to_numpy(zero_copy_only=False), _create_arrow_mask(data, mask), ) - for (data, mask), dtype in zip(cols, returns) + for (data, mask), (_, dtype, _) in zip(cols, returns) ] - return _singlestoredb_accel.dump_rowdat_1_numpy( - returns, row_ids.to_numpy(), numpy_cols, - ) + + return _dump_numpy_accel(returns, row_ids.to_numpy(), numpy_cols) + + +def _dump_rowdat_1_accel( + returns: List[Tuple[str, int, Optional[Transformer]]], + row_ids: List[int], + rows: List[List[Any]], +) -> bytes: + # C function now handles transformers internally + return _singlestoredb_accel.dump_rowdat_1(returns, row_ids, rows) + + +def _load_rowdat_1_accel( + colspec: List[Tuple[str, int, Optional[Transformer]]], + data: bytes, +) -> Tuple[List[int], List[Any]]: + # C function now handles transformers internally + return _singlestoredb_accel.load_rowdat_1(colspec, data) if not has_accel: @@ -708,8 +796,8 @@ def _dump_arrow_accel( dump_polars = _dump_polars_accel = _dump_polars # noqa: F811 else: - _load_accel = _singlestoredb_accel.load_rowdat_1 - _dump_accel = _singlestoredb_accel.dump_rowdat_1 + _load_accel = _load_rowdat_1_accel + _dump_accel = _dump_rowdat_1_accel load = _load_accel dump = _dump_accel load_list = _load_vectors diff --git a/singlestoredb/functions/ext/utils.py b/singlestoredb/functions/ext/utils.py index 4eda318b..360e3857 100644 --- a/singlestoredb/functions/ext/utils.py +++ b/singlestoredb/functions/ext/utils.py @@ -9,6 +9,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Optional from typing import Union try: @@ -30,6 +31,14 @@ def formatMessage(self, record: logging.LogRecord) -> str: recordcopy.__dict__['levelprefix'] = levelname + ':' + seperator return super().formatMessage(recordcopy) +from ..typing import Transformer + + +def apply_transformer(func: Optional[Transformer], v: Any) -> Any: + if func is not None: + return func(v) + return v + class JSONFormatter(logging.Formatter): """Custom JSON formatter for structured logging.""" diff --git 
a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 69cbb437..cbb9ed94 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -25,9 +25,15 @@ except ImportError: has_numpy = False +try: + # Python 3.9+ should use typing directly, else fallback to typing_extensions + from typing import Annotated +except ImportError: + from typing_extensions import Annotated # type: ignore -from . import dtypes as dt +from . import sql_types as dt from . import utils +from .typing import UDFAttrs from .typing import Table from .typing import Masked from ..mysql.converters import escape_item # type: ignore @@ -143,6 +149,7 @@ class NoDefaultType: 'date': 'DATE', 'time': 'TIME', 'time6': 'TIME(6)', + 'json': 'JSON', } sql_to_type_map = { @@ -184,6 +191,7 @@ class NoDefaultType: 'TINYBLOB': 'bytes', 'MEDIUMBLOB': 'bytes', 'LONGBLOB': 'bytes', + 'JSON': 'json', } @@ -223,23 +231,6 @@ class ArrayCollection(Collection): pass -def get_data_format(obj: Any) -> str: - """Return the data format of the DataFrame / Table / vector.""" - # Cheating here a bit so we don't have to import pandas / polars / pyarrow - # unless we absolutely need to - if getattr(obj, '__module__', '').startswith('pandas.'): - return 'pandas' - if getattr(obj, '__module__', '').startswith('polars.'): - return 'polars' - if getattr(obj, '__module__', '').startswith('pyarrow.'): - return 'arrow' - if getattr(obj, '__module__', '').startswith('numpy.'): - return 'numpy' - if isinstance(obj, list): - return 'list' - return 'scalar' - - def escape_name(name: str) -> str: """Escape a function parameter name.""" if '`' in name: @@ -377,8 +368,12 @@ def normalize_dtype(dtype: Any) -> str: origin = typing.get_origin(dtype) if origin is not None: + # Dict type (for JSON support) + if origin is dict: + return 'json' + # Tuple type - if origin is Tuple: + elif origin is Tuple: args = typing.get_args(dtype) item_dtypes = ','.join(normalize_dtype(x) for x in args) return f'tuple[{item_dtypes}]' @@ -784,6 +779,41 @@ def unwrap_optional(annotation: Any) -> Tuple[Any, bool]: return annotation, is_optional +def validate_udf_type(spec: Any, origin: Any, args_origins: List[Any], mode: str) -> None: + """ + Validate that a type specification is valid for UDF parameters or returns. + + Parameters + ---------- + spec : Any + The type specification to validate + origin : Any + The origin type from typing.get_origin() + args_origins : List[Any] + List of origins from the type arguments + mode : str + Either 'parameter' or 'return' + + Raises + ------ + TypeError + If the type is not valid for UDF usage + + """ + # Short circuit check for common valid types + if utils.is_vector(spec) or type(spec) is type and spec in {str, float, int, bytes}: + return + + # Try to catch some common mistakes + if origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec): + type_desc = 'return type' if mode == 'return' else 'parameter types' + usage_desc = 'scalar or vector' if mode == 'return' else 'scalar or vector' + raise TypeError( + f'invalid {type_desc} for a UDF; expecting a {usage_desc}, ' + f'got {getattr(spec, "__name__", spec)}', + ) + + def is_composite_type(spec: Any) -> bool: """ Check if the object is a composite type (e.g., dataclass, TypedDict, etc.). 
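# The new dict branch above is what routes dict-typed annotations to the SQL
# JSON type: normalize_dtype yields 'json', which the updated type maps
# translate to 'JSON' / ft.JSON downstream. A small check (normalize_dtype
# is an internal function of this module; import path per this diff):
from typing import Dict
from singlestoredb.functions.signature import normalize_dtype

assert normalize_dtype(Dict[str, int]) == 'json'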
@@ -843,10 +873,38 @@ def check_composite_type(colspec: List[ParamSpec], mode: str, type_name: str) -> return False +def unpack_annotated(spec: Any) -> Tuple[Any, UDFAttrs]: + """ + Unpack an Annotated type into its base type and metadata. + + Parameters + ---------- + spec : Any + The type annotation to unpack + + Returns + ------- + Tuple[Any, UDFAttrs] + A tuple containing: + - The base type of the annotation + - The Apply metadata, or an empty Apply if none exists + + """ + if typing.get_origin(spec) is Annotated: + args = typing.get_args(spec) + base_type = args[0] + metadata = [x for x in args[1:] if isinstance(x, UDFAttrs)] + if metadata: + return base_type, metadata[0] + return base_type, UDFAttrs() + return spec, UDFAttrs() + + def get_schema( spec: Any, overrides: Optional[List[ParamSpec]] = None, mode: str = 'parameter', + masks: Optional[List[bool]] = None, ) -> Tuple[List[ParamSpec], str, str]: """ Expand a return type annotation into a list of types and field names. @@ -859,6 +917,8 @@ def get_schema( List of SQL type specifications for the return type mode : str The mode of the function, either 'parameter' or 'return' + is_masked : bool + Whether the type is wrapped in a Masked type Returns ------- @@ -873,10 +933,19 @@ def get_schema( udf_parameter = '`returns=`' if mode == 'return' else '`args=`' spec, is_optional = unwrap_optional(spec) + spec, udf_attrs = unpack_annotated(spec) origin = typing.get_origin(spec) args = typing.get_args(spec) args_origins = [typing.get_origin(x) if x is not None else None for x in args] + if not overrides and udf_attrs.sql_type: + overrides = [ + ParamSpec( + dtype=normalize_dtype(udf_attrs.sql_type), + sql_type=udf_attrs.sql_type, + ), + ] + # Make sure that the result of a TVF is a list or dataframe if mode == 'return': @@ -885,33 +954,41 @@ def get_schema( function_type = 'tvf' - if utils.is_dataframe(args[0]): + unpacked_spec = [x[0] for x in (unpack_annotated(x) for x in args)] + + if utils.is_dataframe(unpacked_spec[0]): if not overrides: raise TypeError( 'column types must be specified by the ' '`returns=` parameter of the @udf decorator', ) - if utils.get_module(args[0]) in ['pandas', 'polars', 'pyarrow']: - data_format = utils.get_module(args[0]) - spec = args[0] + if utils.get_module(unpacked_spec[0]) in ['pandas', 'polars', 'pyarrow']: + data_format = utils.get_module(unpacked_spec[0]) + spec = unpacked_spec[0] else: raise TypeError( 'only pandas.DataFrames, polars.DataFrames, ' 'and pyarrow.Tables are supported as tables.', ) - elif typing.get_origin(args[0]) is list: + elif typing.get_origin(unpacked_spec[0]) is list: if len(args) != 1: raise TypeError( 'only one list is supported within a table; to ' 'return multiple columns, use a tuple, NamedTuple, ' 'dataclass, TypedDict, or pydantic model', ) - spec = typing.get_args(args[0])[0] - data_format = 'list' + spec = typing.get_args(unpacked_spec[0])[0] + # Lists as output from TVFs are considered scalar outputs + # since they correspond to individual Python objects, not + # a true vector type. 
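# unpack_annotated in miniature (comment-only sketch; constructing UDFAttrs
# with a sql_type keyword is an assumption based on the attributes read
# elsewhere in this file):
#
#     from typing import Annotated
#     base, attrs = unpack_annotated(Annotated[str, UDFAttrs(sql_type='JSON')])
#     # -> base is str, attrs.sql_type == 'JSON'
#     base, attrs = unpack_annotated(int)
#     # -> base is int, paired with an empty UDFAttrs()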
+ if function_type == 'tvf': + data_format = 'scalar' + else: + data_format = 'list' - elif all([utils.is_vector(x, include_masks=True) for x in args]): + elif all([utils.is_vector(x, include_masks=True) for x in unpacked_spec]): pass else: @@ -920,27 +997,19 @@ def get_schema( 'or tuple of vectors', ) - # Short circuit check for common valid types - elif utils.is_vector(spec) or spec in {str, float, int, bytes}: + elif overrides: pass - # Try to catch some common mistakes - elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec): - raise TypeError( - 'invalid return type for a UDF; expecting a scalar or vector, ' - f'but got {getattr(spec, "__name__", spec)}', - ) + # Validate the return type + else: + validate_udf_type(spec, origin, args_origins, mode) - # Short circuit check for common valid types - elif utils.is_vector(spec) or spec in {str, float, int, bytes}: + elif overrides: pass - # Error out for incorrect parameter types - elif origin in [tuple, dict] or tuple in args_origins or is_composite_type(spec): - raise TypeError( - 'parameter types must be scalar or vector, ' - f'got {getattr(spec, "__name__", spec)}', - ) + # Validate parameter types + else: + validate_udf_type(spec, origin, args_origins, mode) # # Process each parameter / return type into a colspec @@ -1031,7 +1100,21 @@ def get_schema( # Plain list vector elif typing.get_origin(spec) is list: data_format = 'list' - colspec = [ParamSpec(dtype=typing.get_args(spec)[0], is_optional=is_optional)] + _, inner_apply_meta = unpack_annotated(typing.get_args(spec)[0]) + if inner_apply_meta.sql_type: + udf_attrs = inner_apply_meta + colspec = get_schema( + typing.get_args(spec)[0], + mode=mode, + masks=[masks[0]] if masks else None, + )[0] + else: + colspec = [ + ParamSpec( + dtype=typing.get_args(spec)[0], + is_optional=is_optional, + ), + ] # Multiple return values elif inspect.isclass(typing.get_origin(spec)) \ @@ -1055,6 +1138,7 @@ def get_schema( overrides=[overrides[i]] if overrides else [], # Always pass UDF mode for individual items mode=mode, + masks=[masks[i]] if masks else None, ) # Use the name from the overrides if specified @@ -1084,7 +1168,7 @@ def get_schema( # Use overrides if specified elif overrides: if not data_format: - data_format = get_data_format(spec) + data_format = utils.get_data_format(spec) colspec = overrides # Single value, no override @@ -1096,21 +1180,45 @@ def get_schema( out = [] # Normalize colspec data types - for c in colspec: + for i, c in enumerate(colspec): + # if the dtype is a string, it is resolved already if isinstance(c.dtype, str): dtype = c.dtype + + # As long as we don't have explicit overrides, + # use the sql_type from the annotation + elif not overrides and udf_attrs.sql_type: + dtype = normalize_dtype(udf_attrs.sql_type) + + # Otherwise, normalize the dtype from the signature value else: dtype = collapse_dtypes( [normalize_dtype(x) for x in simplify_dtype(c.dtype)], include_null=c.is_optional, ) + sql_type = c.sql_type if isinstance(c.sql_type, str) else udf_attrs.sql_type + + is_optional = ( + c.is_optional + or bool(dtype and dtype.endswith('?')) + or bool(masks and masks[i]) + ) + + if is_optional: + if dtype and not dtype.endswith('?'): + dtype += '?' 
+ if sql_type and re.search(r' NOT NULL\b', sql_type): + sql_type = re.sub(r' NOT NULL\b', r' NULL', sql_type) + p = ParamSpec( name=c.name, dtype=dtype, - sql_type=c.sql_type if isinstance(c.sql_type, str) else None, - is_optional=c.is_optional, + sql_type=sql_type, + is_optional=is_optional, + transformer=udf_attrs.args_transformer + if mode == 'parameter' else udf_attrs.returns_transformer, ) out.append(p) @@ -1248,6 +1356,7 @@ def get_signature( unpack_masked_type(param.annotation), overrides=[args_colspec[i]] if args_colspec else [], mode='parameter', + masks=[args_masks[i]] if args_masks else [], ) args_data_formats.append(args_data_format) @@ -1285,8 +1394,8 @@ def get_signature( name=pspec.name, dtype=pspec.dtype, sql=sql, - **default_option, transformer=pspec.transformer, + **default_option, ), ) @@ -1307,6 +1416,7 @@ def get_signature( unpack_masked_type(signature.return_annotation), overrides=returns_colspec if returns_colspec else None, mode='return', + masks=ret_masks or [], ) rdf = out['returns_data_format'] = out['returns_data_format'] or 'scalar' diff --git a/singlestoredb/functions/sql_types.py b/singlestoredb/functions/sql_types.py new file mode 100644 index 00000000..5364624b --- /dev/null +++ b/singlestoredb/functions/sql_types.py @@ -0,0 +1,1794 @@ +#!/usr/bin/env python3 +import base64 +import datetime +import decimal +import re +from typing import Any +from typing import Callable +from typing import Optional +from typing import Union + +from ..converters import converters +from ..mysql.converters import escape_item # type: ignore +from ..utils.dtypes import DEFAULT_VALUES # noqa +from ..utils.dtypes import NUMPY_TYPE_MAP # noqa +from ..utils.dtypes import PANDAS_TYPE_MAP # noqa +from ..utils.dtypes import POLARS_TYPE_MAP # noqa +from ..utils.dtypes import PYARROW_TYPE_MAP # noqa + + +DataType = Union[str, Callable[..., Any]] + + +class SQLString(str): + """SQL string type.""" + name: Optional[str] = None + + +class NULL: + """NULL (for use in default values).""" + pass + + +def escape_name(name: str) -> str: + """Escape a function parameter name.""" + if '`' in name: + name = name.replace('`', '``') + return f'`{name}`' + + +# charsets +utf8mb4 = 'utf8mb4' +utf8 = 'utf8' +binary = 'binary' + +# collations +utf8_general_ci = 'utf8_general_ci' +utf8_bin = 'utf8_bin' +utf8_unicode_ci = 'utf8_unicode_ci' +utf8_icelandic_ci = 'utf8_icelandic_ci' +utf8_latvian_ci = 'utf8_latvian_ci' +utf8_romanian_ci = 'utf8_romanian_ci' +utf8_slovenian_ci = 'utf8_slovenian_ci' +utf8_polish_ci = 'utf8_polish_ci' +utf8_estonian_ci = 'utf8_estonian_ci' +utf8_spanish_ci = 'utf8_spanish_ci' +utf8_swedish_ci = 'utf8_swedish_ci' +utf8_turkish_ci = 'utf8_turkish_ci' +utf8_czech_ci = 'utf8_czech_ci' +utf8_danish_ci = 'utf8_danish_ci' +utf8_lithuanian_ci = 'utf8_lithuanian_ci' +utf8_slovak_ci = 'utf8_slovak_ci' +utf8_spanish2_ci = 'utf8_spanish2_ci' +utf8_roman_ci = 'utf8_roman_ci' +utf8_persian_ci = 'utf8_persian_ci' +utf8_esperanto_ci = 'utf8_esperanto_ci' +utf8_hungarian_ci = 'utf8_hungarian_ci' +utf8_sinhala_ci = 'utf8_sinhala_ci' +utf8mb4_general_ci = 'utf8mb4_general_ci' +utf8mb4_bin = 'utf8mb4_bin' +utf8mb4_unicode_ci = 'utf8mb4_unicode_ci' +utf8mb4_icelandic_ci = 'utf8mb4_icelandic_ci' +utf8mb4_latvian_ci = 'utf8mb4_latvian_ci' +utf8mb4_romanian_ci = 'utf8mb4_romanian_ci' +utf8mb4_slovenian_ci = 'utf8mb4_slovenian_ci' +utf8mb4_polish_ci = 'utf8mb4_polish_ci' +utf8mb4_estonian_ci = 'utf8mb4_estonian_ci' +utf8mb4_spanish_ci = 'utf8mb4_spanish_ci' +utf8mb4_swedish_ci = 'utf8mb4_swedish_ci' 
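# The SQLString / escape_name helpers introduced in sql_types.py above,
# exercised standalone (same definitions as the patch, illustrative prints):
from typing import Optional


class SQLString(str):
    """A SQL type string that can also carry a column / parameter name."""
    name: Optional[str] = None


def escape_name(name: str) -> str:
    # Backtick-quote an identifier, doubling any embedded backticks
    if '`' in name:
        name = name.replace('`', '``')
    return f'`{name}`'


s = SQLString('BIGINT NOT NULL')
s.name = 'id'
print(f'{escape_name(s.name)} {s}')    # `id` BIGINT NOT NULL
print(escape_name('weird`name'))       # `weird``name`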
+utf8mb4_turkish_ci = 'utf8mb4_turkish_ci' +utf8mb4_czech_ci = 'utf8mb4_czech_ci' +utf8mb4_danish_ci = 'utf8mb4_danish_ci' +utf8mb4_lithuanian_ci = 'utf8mb4_lithuanian_ci' +utf8mb4_slovak_ci = 'utf8mb4_slovak_ci' +utf8mb4_spanish2_ci = 'utf8mb4_spanish2_ci' +utf8mb4_roman_ci = 'utf8mb4_roman_ci' +utf8mb4_persian_ci = 'utf8mb4_persian_ci' +utf8mb4_esperanto_ci = 'utf8mb4_esperanto_ci' +utf8mb4_hungarian_ci = 'utf8mb4_hungarian_ci' +utf8mb4_sinhala_ci = 'utf8mb4_sinhala_ci' + + +def identity(x: Any) -> Any: + return x + + +def utf8str(x: Any) -> Optional[str]: + if x is None: + return x + if isinstance(x, str): + return x + return str(x, 'utf-8') + + +def bytestr(x: Any) -> Optional[bytes]: + if x is None: + return x + if isinstance(x, bytes): + return x + return base64.b64decode(x) + + +PYTHON_CONVERTERS = { + -1: converters[1], + -2: converters[2], + -3: converters[3], + -8: converters[8], + -9: converters[9], + 15: utf8str, + -15: bytestr, + 245: utf8str, + 249: utf8str, + -249: bytestr, + 250: utf8str, + -250: bytestr, + 251: utf8str, + -251: bytestr, + 252: utf8str, + -252: bytestr, + 254: utf8str, + -254: bytestr, + 255: utf8str, +} + +PYTHON_CONVERTERS = dict(list(converters.items()) + list(PYTHON_CONVERTERS.items())) + + +def _modifiers( + *, + nullable: Optional[bool] = None, + charset: Optional[str] = None, + collate: Optional[str] = None, + default: Optional[Any] = None, + unsigned: Optional[bool] = None, +) -> str: + """ + Format type modifiers. + + Parameters + ---------- + nullable : bool, optional + Can the value be NULL? + charset : str, optional + Character set + collate : str, optional + Collation + default ; Any, optional + Default value + unsigned : bool, optional + Is the value unsigned? (ints only) + + Returns + ------- + str + + """ + out = [] + + if unsigned is not None: + if unsigned: + out.append('UNSIGNED') + + if charset is not None: + if not re.match(r'^[A-Za-z0-9_]+$', charset): + raise ValueError(f'charset value is invalid: {charset}') + out.append(f'CHARACTER SET {charset}') + + if collate is not None: + if not re.match(r'^[A-Za-z0-9_]+$', collate): + raise ValueError(f'collate value is invalid: {collate}') + out.append(f'COLLATE {collate}') + + if nullable is not None: + if nullable: + out.append('NULL') + else: + out.append('NOT NULL') + + if default is NULL: + out.append('DEFAULT NULL') + elif default is not None: + out.append(f'DEFAULT {escape_item(default, "utf-8")}') + + return ' ' + ' '.join(out) + + +def _bool(x: Optional[bool] = None) -> Optional[bool]: + """Cast bool.""" + if x is None: + return None + return bool(x) + + +def BOOL( + *, + nullable: bool = True, + default: Optional[bool] = None, + name: Optional[str] = None, +) -> SQLString: + """ + BOOL type specification. + + Parameters + ---------- + nullable : bool, optional + Can the value be NULL? + default : bool, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString('BOOL' + _modifiers(nullable=nullable, default=_bool(default))) + out.name = name + return out + + +def BOOLEAN( + *, + nullable: bool = True, + default: Optional[bool] = None, + name: Optional[str] = None, +) -> SQLString: + """ + BOOLEAN type specification. + + Parameters + ---------- + nullable : bool, optional + Can the value be NULL? 
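# Simplified stand-in for sql_types._modifiers showing how the clauses
# compose in order (unsigned, charset, collate, nullability, default).
# The real function escapes defaults with
# singlestoredb.mysql.converters.escape_item; repr() keeps this sketch
# dependency-free.
import re
from typing import Any, Optional


def _modifiers_sketch(
    *,
    nullable: Optional[bool] = None,
    charset: Optional[str] = None,
    collate: Optional[str] = None,
    default: Optional[Any] = None,
    unsigned: Optional[bool] = None,
) -> str:
    out = []
    if unsigned:
        out.append('UNSIGNED')
    if charset is not None:
        if not re.match(r'^[A-Za-z0-9_]+$', charset):
            raise ValueError(f'charset value is invalid: {charset}')
        out.append(f'CHARACTER SET {charset}')
    if collate is not None:
        out.append(f'COLLATE {collate}')
    if nullable is not None:
        out.append('NULL' if nullable else 'NOT NULL')
    if default is not None:
        out.append(f'DEFAULT {default!r}')
    return ' ' + ' '.join(out)


print('VARCHAR(64)' + _modifiers_sketch(nullable=False, charset='utf8mb4', default='n/a'))
# VARCHAR(64) CHARACTER SET utf8mb4 NOT NULL DEFAULT 'n/a'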
+ default : bool, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString('BOOLEAN' + _modifiers(nullable=nullable, default=_bool(default))) + out.name = name + return out + + +def BIT( + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + BIT type specification. + + Parameters + ---------- + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString('BIT' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def TINYINT( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + unsigned: bool = False, + name: Optional[str] = None, +) -> SQLString: + """ + TINYINT type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + unsigned : bool, optional + Is the int unsigned? + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'TINYINT({display_width})' if display_width else 'TINYINT' + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out + + +def TINYINT_UNSIGNED( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + TINYINT UNSIGNED type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'TINYINT({display_width})' if display_width else 'TINYINT' + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out + + +def SMALLINT( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + unsigned: bool = False, + name: Optional[str] = None, +) -> SQLString: + """ + SMALLINT type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + unsigned : bool, optional + Is the int unsigned? + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'SMALLINT({display_width})' if display_width else 'SMALLINT' + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out + + +def SMALLINT_UNSIGNED( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + SMALLINT UNSIGNED type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? 
+ default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'SMALLINT({display_width})' if display_width else 'SMALLINT' + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out + + +def MEDIUMINT( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + unsigned: bool = False, + name: Optional[str] = None, +) -> SQLString: + """ + MEDIUMINT type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + unsigned : bool, optional + Is the int unsigned? + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'MEDIUMINT({display_width})' if display_width else 'MEDIUMINT' + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out + + +def MEDIUMINT_UNSIGNED( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + MEDIUMINT UNSIGNED type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'MEDIUMINT({display_width})' if display_width else 'MEDIUMINT' + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out + + +def INT( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + unsigned: bool = False, + name: Optional[str] = None, +) -> SQLString: + """ + INT type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + unsigned : bool, optional + Is the int unsigned? + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'INT({display_width})' if display_width else 'INT' + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out + + +def INT_UNSIGNED( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + INT UNSIGNED type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'INT({display_width})' if display_width else 'INT' + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out + + +def INTEGER( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + unsigned: bool = False, + name: Optional[str] = None, +) -> SQLString: + """ + INTEGER type specification. 
+ + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + unsigned : bool, optional + Is the int unsigned? + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'INTEGER({display_width})' if display_width else 'INTEGER' + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out + + +def INTEGER_UNSIGNED( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + INTEGER UNSIGNED type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'INTEGER({display_width})' if display_width else 'INTEGER' + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out + + +def BIGINT( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + unsigned: bool = False, + name: Optional[str] = None, +) -> SQLString: + """ + BIGINT type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + unsigned : bool, optional + Is the int unsigned? + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'BIGINT({display_width})' if display_width else 'BIGINT' + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out + + +def BIGINT_UNSIGNED( + display_width: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + BIGINT UNSIGNED type specification. + + Parameters + ---------- + display_width : int, optional + Display width used by some clients + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'BIGINT({int(display_width)})' if display_width else 'BIGINT' + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out + + +def FLOAT( + display_decimals: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[float] = None, + name: Optional[str] = None, +) -> SQLString: + """ + FLOAT type specification. + + Parameters + ---------- + display_decimals : int, optional + Number of decimal places to display + nullable : bool, optional + Can the value be NULL? 
+ default : float, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'FLOAT({int(display_decimals)})' if display_decimals else 'FLOAT' + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def DOUBLE( + display_decimals: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[float] = None, + name: Optional[str] = None, +) -> SQLString: + """ + DOUBLE type specification. + + Parameters + ---------- + display_decimals : int, optional + Number of decimal places to display + nullable : bool, optional + Can the value be NULL? + default : float, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'DOUBLE({int(display_decimals)})' if display_decimals else 'DOUBLE' + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def REAL( + display_decimals: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[float] = None, + name: Optional[str] = None, +) -> SQLString: + """ + REAL type specification. + + Parameters + ---------- + display_decimals : int, optional + Number of decimal places to display + nullable : bool, optional + Can the value be NULL? + default : float, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'REAL({int(display_decimals)})' if display_decimals else 'REAL' + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def DECIMAL( + precision: int, + scale: int, + *, + nullable: bool = True, + default: Optional[Union[str, decimal.Decimal]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + DECIMAL type specification. + + Parameters + ---------- + precision : int + Decimal precision + scale : int + Decimal scale + nullable : bool, optional + Can the value be NULL? + default : str or decimal.Decimal, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString( + f'DECIMAL({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out + + +def DEC( + precision: int, + scale: int, + *, + nullable: bool = True, + default: Optional[Union[str, decimal.Decimal]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + DEC type specification. + + Parameters + ---------- + precision : int + Decimal precision + scale : int + Decimal scale + nullable : bool, optional + Can the value be NULL? + default : str or decimal.Decimal, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString( + f'DEC({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out + + +def FIXED( + precision: int, + scale: int, + *, + nullable: bool = True, + default: Optional[Union[str, decimal.Decimal]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + FIXED type specification. + + Parameters + ---------- + precision : int + Decimal precision + scale : int + Decimal scale + nullable : bool, optional + Can the value be NULL? 
+ default : str or decimal.Decimal, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString( + f'FIXED({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out + + +def NUMERIC( + precision: int, + scale: int, + *, + nullable: bool = True, + default: Optional[Union[str, decimal.Decimal]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + NUMERIC type specification. + + Parameters + ---------- + precision : int + Decimal precision + scale : int + Decimal scale + nullable : bool, optional + Can the value be NULL? + default : str or decimal.Decimal, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString( + f'NUMERIC({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out + + +def DATE( + *, + nullable: bool = True, + default: Optional[Union[str, datetime.date]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + DATE type specification. + + Parameters + ---------- + nullable : bool, optional + Can the value be NULL? + default : str or datetime.date, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString('DATE' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def TIME( + precision: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[Union[str, datetime.timedelta]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + TIME type specification. + + Parameters + ---------- + precision : int, optional + Sub-second precision + nullable : bool, optional + Can the value be NULL? + default : str or datetime.timedelta, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'TIME({int(precision)})' if precision else 'TIME' + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def DATETIME( + precision: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[Union[str, datetime.datetime]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + DATETIME type specification. + + Parameters + ---------- + precision : int, optional + Sub-second precision + nullable : bool, optional + Can the value be NULL? + default : str or datetime.datetime, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'DATETIME({int(precision)})' if precision else 'DATETIME' + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def TIMESTAMP( + precision: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[Union[str, datetime.datetime]] = None, + name: Optional[str] = None, +) -> SQLString: + """ + TIMESTAMP type specification. + + Parameters + ---------- + precision : int, optional + Sub-second precision + nullable : bool, optional + Can the value be NULL? 
+ default : str or datetime.datetime, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'TIMESTAMP({int(precision)})' if precision else 'TIMESTAMP' + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def YEAR( + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: + """ + YEAR type specification. + + Parameters + ---------- + nullable : bool, optional + Can the value be NULL? + default : int, optional + Default value + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = SQLString('YEAR' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out + + +def CHAR( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[str] = None, + collate: Optional[str] = None, + charset: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + CHAR type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + charset : str, optional + Character set + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'CHAR({int(length)})' if length else 'CHAR' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), + ) + out.name = name + return out + + +def VARCHAR( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[str] = None, + collate: Optional[str] = None, + charset: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + VARCHAR type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + charset : str, optional + Character set + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'VARCHAR({int(length)})' if length else 'VARCHAR' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), + ) + out.name = name + return out + + +def LONGTEXT( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[str] = None, + collate: Optional[str] = None, + charset: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + LONGTEXT type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + charset : str, optional + Character set + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'LONGTEXT({int(length)})' if length else 'LONGTEXT' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), + ) + out.name = name + return out + + +def MEDIUMTEXT( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[str] = None, + collate: Optional[str] = None, + charset: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + MEDIUMTEXT type specification. 
+ + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + charset : str, optional + Character set + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'MEDIUMTEXT({int(length)})' if length else 'MEDIUMTEXT' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), + ) + out.name = name + return out + + +def TEXT( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[str] = None, + collate: Optional[str] = None, + charset: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + TEXT type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + charset : str, optional + Character set + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'TEXT({int(length)})' if length else 'TEXT' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), + ) + out.name = name + return out + + +def TINYTEXT( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[str] = None, + collate: Optional[str] = None, + charset: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + TINYTEXT type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + charset : str, optional + Character set + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'TINYTEXT({int(length)})' if length else 'TINYTEXT' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), + ) + out.name = name + return out + + +def BINARY( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[bytes] = None, + collate: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + BINARY type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'BINARY({int(length)})' if length else 'BINARY' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), + ) + out.name = name + return out + + +def VARBINARY( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[bytes] = None, + collate: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + VARBINARY type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? 
+ default : str, optional + Default value + collate : str, optional + Collation + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'VARBINARY({int(length)})' if length else 'VARBINARY' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), + ) + out.name = name + return out + + +def LONGBLOB( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[bytes] = None, + collate: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + LONGBLOB type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'LONGBLOB({int(length)})' if length else 'LONGBLOB' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), + ) + out.name = name + return out + + +def MEDIUMBLOB( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[bytes] = None, + collate: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + MEDIUMBLOB type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'MEDIUMBLOB({int(length)})' if length else 'MEDIUMBLOB' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), + ) + out.name = name + return out + + +def BLOB( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[bytes] = None, + collate: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + BLOB type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? + default : str, optional + Default value + collate : str, optional + Collation + name : str, optional + Name of the column / parameter + + Returns + ------- + SQLString + + """ + out = f'BLOB({int(length)})' if length else 'BLOB' + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), + ) + out.name = name + return out + + +def TINYBLOB( + length: Optional[int] = None, + *, + nullable: bool = True, + default: Optional[bytes] = None, + collate: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: + """ + TINYBLOB type specification. + + Parameters + ---------- + length : int, optional + Maximum string length + nullable : bool, optional + Can the value be NULL? 
+    default : str, optional
+        Default value
+    collate : str, optional
+        Collation
+    name : str, optional
+        Name of the column / parameter
+
+    Returns
+    -------
+    SQLString
+
+    """
+    out = f'TINYBLOB({int(length)})' if length else 'TINYBLOB'
+    out = SQLString(
+        out + _modifiers(
+            nullable=nullable, default=default, collate=collate,
+        ),
+    )
+    out.name = name
+    return out
+
+
+def JSON(
+    length: Optional[int] = None,
+    *,
+    nullable: bool = True,
+    default: Optional[str] = None,
+    collate: Optional[str] = None,
+    charset: Optional[str] = None,
+    name: Optional[str] = None,
+) -> SQLString:
+    """
+    JSON type specification.
+
+    Parameters
+    ----------
+    length : int, optional
+        Maximum string length
+    nullable : bool, optional
+        Can the value be NULL?
+    default : str, optional
+        Default value
+    collate : str, optional
+        Collation
+    charset : str, optional
+        Character set
+    name : str, optional
+        Name of the column / parameter
+
+    Returns
+    -------
+    SQLString
+
+    """
+    out = f'JSON({int(length)})' if length else 'JSON'
+    out = SQLString(
+        out + _modifiers(
+            nullable=nullable, default=default,
+            collate=collate, charset=charset,
+        ),
+    )
+    out.name = name
+    return out
+
+
+def GEOGRAPHYPOINT(
+    *,
+    nullable: bool = True,
+    default: Optional[str] = None,
+    name: Optional[str] = None,
+) -> SQLString:
+    """
+    GEOGRAPHYPOINT type specification.
+
+    Parameters
+    ----------
+    nullable : bool, optional
+        Can the value be NULL?
+    default : str, optional
+        Default value
+    name : str, optional
+        Name of the column / parameter
+
+    Returns
+    -------
+    SQLString
+
+    """
+    out = SQLString('GEOGRAPHYPOINT' + _modifiers(nullable=nullable, default=default))
+    out.name = name
+    return out
+
+
+def GEOGRAPHY(
+    *,
+    nullable: bool = True,
+    default: Optional[str] = None,
+    name: Optional[str] = None,
+) -> SQLString:
+    """
+    GEOGRAPHY type specification.
+
+    Parameters
+    ----------
+    nullable : bool, optional
+        Can the value be NULL?
+    default : str, optional
+        Default value
+    name : str, optional
+        Name of the column / parameter
+
+    Returns
+    -------
+    SQLString
+
+    """
+    out = SQLString('GEOGRAPHY' + _modifiers(nullable=nullable, default=default))
+    out.name = name
+    return out
+
+
+# def RECORD(
+#     *args: Tuple[str, DataType],
+#     nullable: bool = True,
+#     name: Optional[str] = None,
+# ) -> SQLString:
+#     """
+#     RECORD type specification.
+#
+#     Parameters
+#     ----------
+#     *args : Tuple[str, DataType]
+#         Field specifications
+#     nullable : bool, optional
+#         Can the value be NULL?
+#     name : str, optional
+#         Name of the column / parameter
+#
+#     Returns
+#     -------
+#     SQLString
+#
+#     """
+#     assert len(args) > 0
+#     fields = []
+#     for name, value in args:
+#         if callable(value):
+#             fields.append(f'{escape_name(name)} {value()}')
+#         else:
+#             fields.append(f'{escape_name(name)} {value}')
+#     out = SQLString(f'RECORD({", ".join(fields)})' + _modifiers(nullable=nullable))
+#     out.name = name
+#     return out
+
+
+# def ARRAY(
+#     dtype: DataType,
+#     nullable: bool = True,
+#     name: Optional[str] = None,
+# ) -> SQLString:
+#     """
+#     ARRAY type specification.
+#
+#     Parameters
+#     ----------
+#     dtype : DataType
+#         The data type of the array elements
+#     nullable : bool, optional
+#         Can the value be NULL?
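# A few illustrative calls against the builders defined above; the printed
# outputs assume this patch's _modifiers ordering and the default
# nullable=True behavior.
from singlestoredb.functions import sql_types as st

print(st.VARCHAR(255, nullable=False))     # VARCHAR(255) NOT NULL
print(st.DECIMAL(10, 2, default='0.00'))   # DECIMAL(10, 2) NULL DEFAULT '0.00'
print(st.DATETIME(6))                      # DATETIME(6) NULL
print(st.TEXT(name='body').name)           # body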
+# name : str, optional +# Name of the column / parameter +# +# Returns +# ------- +# SQLString +# +# """ +# if callable(dtype): +# dtype = dtype() +# out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable)) +# out.name = name +# return out + + +# F32 = 'F32' +# F64 = 'F64' +# I8 = 'I8' +# I16 = 'I16' +# I32 = 'I32' +# I64 = 'I64' + + +# def VECTOR( +# length: int, +# element_type: str = F32, +# *, +# nullable: bool = True, +# default: Optional[bytes] = None, +# name: Optional[str] = None, +# ) -> SQLString: +# """ +# VECTOR type specification. +# +# Parameters +# ---------- +# n : int +# Number of elements in vector +# element_type : str, optional +# Type of the elements in the vector: +# F32, F64, I8, I16, I32, I64 +# nullable : bool, optional +# Can the value be NULL? +# default : str, optional +# Default value +# name : str, optional +# Name of the column / parameter +# +# Returns +# ------- +# SQLString +# +# """ +# out = f'VECTOR({int(length)}, {element_type})' +# out = SQLString( +# out + _modifiers( +# nullable=nullable, default=default, +# ), +# ) +# out.name = name +# return out diff --git a/singlestoredb/functions/typing/__init__.py b/singlestoredb/functions/typing/__init__.py index 66bb9852..56a6e19e 100644 --- a/singlestoredb/functions/typing/__init__.py +++ b/singlestoredb/functions/typing/__init__.py @@ -1,18 +1,30 @@ -from collections.abc import Iterable +import dataclasses +import json +from collections.abc import Sequence +from typing import Annotated from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Protocol from typing import Tuple from typing import TypeVar +from typing import Union try: from typing import TypeVarTuple # type: ignore from typing import Unpack # type: ignore + from typing import TypeAlias # type: ignore except ImportError: # Python 3.8 and earlier do not have TypeVarTuple from typing_extensions import TypeVarTuple # type: ignore from typing_extensions import Unpack # type: ignore + from typing_extensions import TypeAlias # type: ignore +from .. import sql_types -T = TypeVar('T', bound=Iterable[Any]) # Generic type for iterable types +T = TypeVar('T', bound=Sequence[Any]) # Generic type for iterable types # # Masked types are used for pairs of vectors where the first element is the @@ -39,3 +51,112 @@ class Table(Tuple[Unpack[Ts]]): def __new__(cls, *args: Unpack[Ts]) -> 'Table[Tuple[Unpack[Ts]]]': # type: ignore return tuple.__new__(cls, args) # type: ignore + + +class TypedTransformer(Protocol): + + output_type: Optional[Any] = None + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + ... + + +Transformer = Union[Callable[..., Any], TypedTransformer] + + +def output_type(dtype: Any) -> Callable[..., Any]: + """ + Decorator that sets the output_type attribute on a function. 
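# Sketch of the output_type decorator defined above: it tags a transformer
# with the dtype its results should be materialized as (e.g. numpy 'object'
# so arrays of parsed JSON stay one-dimensional object arrays).
import json
from typing import Any, Callable, Optional


def output_type(dtype: Any) -> Callable[..., Any]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        func.output_type = dtype  # type: ignore[attr-defined]
        return func
    return decorator


@output_type('object')
def parse(v: Optional[str]) -> Optional[Any]:
    return json.loads(v) if v else None


print(parse.output_type)   # object
print(parse('{"a": 1}'))   # {'a': 1}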
+ + Parameters + ---------- + dtype : Any + The data type to set as the function's output_type attribute + + Returns + ------- + Callable + The decorated function with output_type attribute set + """ + + def decorator(func: Callable[..., Any]) -> Transformer: + func.output_type = dtype # type: ignore + return func # type: ignore + + return decorator + + +@dataclasses.dataclass +class UDFAttrs: + sql_type: Optional[sql_types.SQLString] = None + args_transformer: Optional[Transformer] = None + returns_transformer: Optional[Transformer] = None + + +def json_or_null_dumps(v: Optional[Any], **kwargs: Any) -> Optional[str]: + """ + Serialize a Python object to a JSON string or None + + Parameters + ---------- + v : Optional[Any] + The Python object to serialize. If None or empty, the function returns None. + **kwargs : Any + Additional keyword arguments to pass to `json.dumps`. + + Returns + ------- + Optional[str] + The JSON string representation of the input object, + or None if the input is None or empty + + """ + if not v: + return None + return json.dumps(v, **kwargs) + + +# Force numpy dtype to 'object' to avoid issues with +# numpy trying to infer the dtype and creating multidimensional arrays +# instead of an array of Python objects. +@output_type('object') +def json_or_null_loads(v: Optional[str], **kwargs: Any) -> Optional[Any]: + """ + Deserialize a JSON string to a Python object or None + + Parameters + ---------- + v : Optional[str] + The JSON string to deserialize. If None or empty, the function returns None. + **kwargs : Any + Additional keyword arguments to pass to `json.loads`. + + Returns + ------- + Optional[Any] + The Python object represented by the JSON string, + or None if the input is None or empty + + """ + if not v: + return None + return json.loads(v, **kwargs) + + +JSON: TypeAlias = Annotated[ + Union[Dict[str, Any], List[Any], int, float, str, bool, None], + UDFAttrs( + sql_type=sql_types.JSON(nullable=False), + args_transformer=json_or_null_loads, + returns_transformer=json_or_null_dumps, + ), +] + + +__all__ = [ + 'Table', + 'Masked', + 'JSON', + 'UDFAttrs', + 'Transformer', +] diff --git a/singlestoredb/functions/typing/numpy.py b/singlestoredb/functions/typing/numpy.py index fb3954d2..d1ceed9c 100644 --- a/singlestoredb/functions/typing/numpy.py +++ b/singlestoredb/functions/typing/numpy.py @@ -1,20 +1,110 @@ +import json +from typing import Annotated +from typing import Any + import numpy as np import numpy.typing as npt +from numpy import array # noqa: F401 + +try: + from typing import TypeAlias # type: ignore +except ImportError: + from typing_extensions import TypeAlias # type: ignore + +from . import UDFAttrs +from . import json_or_null_dumps +from . import json_or_null_loads +from .. 
import sql_types NDArray = npt.NDArray -StringArray = StrArray = npt.NDArray[np.str_] -BytesArray = npt.NDArray[np.bytes_] -Float32Array = FloatArray = npt.NDArray[np.float32] -Float64Array = DoubleArray = npt.NDArray[np.float64] -IntArray = npt.NDArray[np.int_] -Int8Array = npt.NDArray[np.int8] -Int16Array = npt.NDArray[np.int16] -Int32Array = npt.NDArray[np.int32] -Int64Array = npt.NDArray[np.int64] -UInt8Array = npt.NDArray[np.uint8] -UInt16Array = npt.NDArray[np.uint16] -UInt32Array = npt.NDArray[np.uint32] -UInt64Array = npt.NDArray[np.uint64] -DateTimeArray = npt.NDArray[np.datetime64] -TimeDeltaArray = npt.NDArray[np.timedelta64] + +StringArray: TypeAlias = Annotated[ + npt.NDArray[np.str_], UDFAttrs(sql_type=sql_types.TEXT(nullable=False)), +] +StrArray: TypeAlias = StringArray + +BytesArray: TypeAlias = Annotated[ + npt.NDArray[np.bytes_], UDFAttrs(sql_type=sql_types.BLOB(nullable=False)), +] + +Float32Array: TypeAlias = Annotated[ + npt.NDArray[np.float32], UDFAttrs(sql_type=sql_types.FLOAT(nullable=False)), +] +FloatArray: TypeAlias = Float32Array + +Float64Array: TypeAlias = Annotated[ + npt.NDArray[np.float64], UDFAttrs(sql_type=sql_types.DOUBLE(nullable=False)), +] +DoubleArray: TypeAlias = Float64Array + +IntArray: TypeAlias = Annotated[ + npt.NDArray[np.int_], UDFAttrs(sql_type=sql_types.INT(nullable=False)), +] + +Int8Array: TypeAlias = Annotated[ + npt.NDArray[np.int8], UDFAttrs(sql_type=sql_types.TINYINT(nullable=False)), +] + +Int16Array: TypeAlias = Annotated[ + npt.NDArray[np.int16], UDFAttrs(sql_type=sql_types.SMALLINT(nullable=False)), +] + +Int32Array: TypeAlias = Annotated[ + npt.NDArray[np.int32], UDFAttrs(sql_type=sql_types.INT(nullable=False)), +] + +Int64Array: TypeAlias = Annotated[ + npt.NDArray[np.int64], UDFAttrs(sql_type=sql_types.BIGINT(nullable=False)), +] + +UInt8Array: TypeAlias = Annotated[ + npt.NDArray[np.uint8], UDFAttrs(sql_type=sql_types.TINYINT_UNSIGNED(nullable=False)), +] + +UInt16Array: TypeAlias = Annotated[ + npt.NDArray[np.uint16], + UDFAttrs(sql_type=sql_types.SMALLINT_UNSIGNED(nullable=False)), +] + +UInt32Array: TypeAlias = Annotated[ + npt.NDArray[np.uint32], UDFAttrs(sql_type=sql_types.INT_UNSIGNED(nullable=False)), +] + +UInt64Array: TypeAlias = Annotated[ + npt.NDArray[np.uint64], UDFAttrs(sql_type=sql_types.BIGINT_UNSIGNED(nullable=False)), +] + +DateTimeArray: TypeAlias = Annotated[ + npt.NDArray[np.datetime64], UDFAttrs(sql_type=sql_types.DATETIME(nullable=False)), +] + +TimeDeltaArray: TypeAlias = Annotated[ + npt.NDArray[np.timedelta64], UDFAttrs(sql_type=sql_types.TIME(nullable=False)), +] + + +class NumpyJSONEncoder(json.JSONEncoder): + """Custom JSON encoder that converts numpy scalar types to Python types.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + +JSONArray: TypeAlias = Annotated[ + npt.NDArray[np.object_], + UDFAttrs( + sql_type=sql_types.JSON(nullable=False), + args_transformer=json_or_null_loads, + returns_transformer=lambda x: json_or_null_dumps(x, cls=NumpyJSONEncoder), + ), +] + + +__all__ = ['array'] + [x for x in globals().keys() if x.endswith('Array')] diff --git a/singlestoredb/functions/typing/pandas.py b/singlestoredb/functions/typing/pandas.py index 23a662c5..7e891fcd 100644 --- a/singlestoredb/functions/typing/pandas.py +++ b/singlestoredb/functions/typing/pandas.py @@ -1,2 +1,112 @@ +import json +from typing 
import Annotated +from typing import Any + +import numpy as np +import pandas as pd from pandas import DataFrame # noqa: F401 from pandas import Series # noqa: F401 + +try: + from typing import TypeAlias # type: ignore +except ImportError: + from typing_extensions import TypeAlias # type: ignore + +from . import UDFAttrs +from . import json_or_null_dumps +from . import json_or_null_loads +from .. import sql_types + + +StringSeries: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.TEXT(nullable=False)), +] +StrSeries: TypeAlias = StringSeries + +BytesSeries: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.BLOB(nullable=False)), +] + +Float32Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.FLOAT(nullable=False)), +] +FloatSeries: TypeAlias = Float32Series + +Float64Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.DOUBLE(nullable=False)), +] +DoubleSeries: TypeAlias = Float64Series + +IntSeries: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.INT(nullable=False)), +] + +Int8Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.TINYINT(nullable=False)), +] + +Int16Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.SMALLINT(nullable=False)), +] + +Int32Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.INT(nullable=False)), +] + +Int64Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.BIGINT(nullable=False)), +] + +UInt8Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.TINYINT_UNSIGNED(nullable=False)), +] + +UInt16Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.SMALLINT_UNSIGNED(nullable=False)), +] + +UInt32Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.INT_UNSIGNED(nullable=False)), +] + +UInt64Series: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.BIGINT_UNSIGNED(nullable=False)), +] + +DateTimeSeries: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.DATETIME(nullable=False)), +] + +TimeSeries: TypeAlias = Annotated[ + pd.Series, UDFAttrs(sql_type=sql_types.TIME(nullable=False)), +] + + +class PandasJSONEncoder(json.JSONEncoder): + """Custom JSON encoder that handles pandas Series and numpy scalar types.""" + + def default(self, obj: Any) -> Any: + if hasattr(obj, 'dtype') and hasattr(obj, 'tolist'): + # Handle pandas Series and numpy arrays + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif hasattr(obj, 'item'): + # Handle pandas scalar types + return obj.item() + return super().default(obj) + + +JSONSeries: TypeAlias = Annotated[ + pd.Series, + UDFAttrs( + sql_type=sql_types.JSON(nullable=False), + args_transformer=json_or_null_loads, + returns_transformer=lambda x: json_or_null_dumps(x, cls=PandasJSONEncoder), + ), +] + + +__all__ = ['DataFrame'] + [x for x in globals().keys() if x.endswith('Series')] diff --git a/singlestoredb/functions/typing/polars.py b/singlestoredb/functions/typing/polars.py index d7556a1e..6a602dba 100644 --- a/singlestoredb/functions/typing/polars.py +++ b/singlestoredb/functions/typing/polars.py @@ -1,2 +1,118 @@ +import json +from typing import Annotated +from typing import Any + +import polars as pl from polars import DataFrame # noqa: F401 from polars import Series # noqa: F401 + +try: + from typing import TypeAlias # type: ignore +except ImportError: + from typing_extensions import TypeAlias # type: 
ignore + +from . import UDFAttrs +from . import json_or_null_dumps +from . import json_or_null_loads +from .. import sql_types + + +StringSeries: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.TEXT(nullable=False)), +] +StrSeries: TypeAlias = StringSeries + +BytesSeries: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.BLOB(nullable=False)), +] + +Float32Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.FLOAT(nullable=False)), +] +FloatSeries: TypeAlias = Float32Series + +Float64Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.DOUBLE(nullable=False)), +] +DoubleSeries: TypeAlias = Float64Series + +IntSeries: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.INT(nullable=False)), +] + +Int8Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.TINYINT(nullable=False)), +] + +Int16Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.SMALLINT(nullable=False)), +] + +Int32Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.INT(nullable=False)), +] + +Int64Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.BIGINT(nullable=False)), +] + +UInt8Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.TINYINT_UNSIGNED(nullable=False)), +] + +UInt16Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.SMALLINT_UNSIGNED(nullable=False)), +] + +UInt32Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.INT_UNSIGNED(nullable=False)), +] + +UInt64Series: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.BIGINT_UNSIGNED(nullable=False)), +] + +DateTimeSeries: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.DATETIME(nullable=False)), +] + +TimeDeltaSeries: TypeAlias = Annotated[ + pl.Series, UDFAttrs(sql_type=sql_types.TIME(nullable=False)), +] + + +class PolarsJSONEncoder(json.JSONEncoder): + """Custom JSON encoder that converts Polars Series / scalar types to Python types.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, pl.Series): + # Convert Polars Series to Python list + return obj.to_list() + elif hasattr(obj, 'dtype') and \ + str(obj.dtype).startswith(('Int', 'UInt', 'Float')): + # Handle Polars scalar integer and float types + return obj.item() if hasattr(obj, 'item') else obj + elif isinstance( + obj, ( + pl.datatypes.Int8, pl.datatypes.Int16, pl.datatypes.Int32, + pl.datatypes.Int64, pl.datatypes.UInt8, pl.datatypes.UInt16, + pl.datatypes.UInt32, pl.datatypes.UInt64, + ), + ): + return int(obj) + elif isinstance(obj, (pl.datatypes.Float32, pl.datatypes.Float64)): + return float(obj) + return super().default(obj) + + +JSONSeries: TypeAlias = Annotated[ + pl.Series, + UDFAttrs( + sql_type=sql_types.JSON(nullable=False), + args_transformer=json_or_null_loads, + returns_transformer=lambda x: json_or_null_dumps(x, cls=PolarsJSONEncoder), + ), +] + + +__all__ = ['DataFrame'] + [x for x in globals().keys() if x.endswith('Series')] diff --git a/singlestoredb/functions/typing/pyarrow.py b/singlestoredb/functions/typing/pyarrow.py index 7c7fce94..64f57ba9 100644 --- a/singlestoredb/functions/typing/pyarrow.py +++ b/singlestoredb/functions/typing/pyarrow.py @@ -1,2 +1,116 @@ +import json +from typing import Annotated +from typing import Any + +import pyarrow as pa from pyarrow import Array # noqa: F401 +from pyarrow import array # noqa: F401 from pyarrow import Table # noqa: F401 + +try: + from typing import TypeAlias # type: ignore +except 
ImportError: + from typing_extensions import TypeAlias # type: ignore + +from . import UDFAttrs +from . import json_or_null_dumps +from . import json_or_null_loads # noqa: F401 +from .. import sql_types + + +StringArray: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.TEXT(nullable=True)), +] +StrArray: TypeAlias = StringArray + +BytesArray: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.BLOB(nullable=True)), +] + +Float32Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.FLOAT(nullable=True)), +] +FloatArray: TypeAlias = Float32Array + +Float64Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.DOUBLE(nullable=True)), +] +DoubleArray: TypeAlias = Float64Array + +IntArray: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.INT(nullable=True)), +] + +Int8Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.TINYINT(nullable=True)), +] + +Int16Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.SMALLINT(nullable=True)), +] + +Int32Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.INT(nullable=True)), +] + +Int64Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.BIGINT(nullable=True)), +] + +UInt8Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.TINYINT_UNSIGNED(nullable=True)), +] + +UInt16Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.SMALLINT_UNSIGNED(nullable=True)), +] + +UInt32Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.INT_UNSIGNED(nullable=True)), +] + +UInt64Array: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.BIGINT_UNSIGNED(nullable=True)), +] + +DateTimeArray: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.DATETIME(nullable=True)), +] + +TimeDeltaArray: TypeAlias = Annotated[ + pa.Array, UDFAttrs(sql_type=sql_types.TIME(nullable=True)), +] + + +class PyArrowJSONEncoder(json.JSONEncoder): + """Custom JSON encoder that converts PyArrow scalar types to Python types.""" + + def default(self, obj: Any) -> Any: + if hasattr(obj, 'as_py'): + # Handle PyArrow scalar types (including individual ints and floats) + return obj.as_py() + elif isinstance(obj, pa.Array): + # Convert PyArrow Array to Python list + return obj.to_pylist() + elif isinstance(obj, pa.Table): + # Convert PyArrow Table to list of dictionaries + return obj.to_pydict() + return super().default(obj) + +# +# NOTE: We don't use input_transformer=json.loads because it doesn't handle +# all cases (e.g., when the input is already a dict/list). 
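# Round-trip sketch for the JSON transformers used by the *Array aliases:
# json_or_null_* map None/empty to None so SQL NULL survives both directions,
# and a custom encoder (NumpyJSONEncoder, PyArrowJSONEncoder, ...) only has
# to cover the scalar/array types json.dumps cannot serialize on its own.
import json
from typing import Any, Optional


def json_or_null_dumps(v: Optional[Any], **kwargs: Any) -> Optional[str]:
    return json.dumps(v, **kwargs) if v else None


def json_or_null_loads(v: Optional[str], **kwargs: Any) -> Optional[Any]:
    return json.loads(v, **kwargs) if v else None


payload = {'ids': [1, 2, 3], 'ok': True}
encoded = json_or_null_dumps(payload)
print(encoded)                      # {"ids": [1, 2, 3], "ok": true}
print(json_or_null_loads(encoded))  # {'ids': [1, 2, 3], 'ok': True}
print(json_or_null_dumps(None))     # None -> stays NULL end to end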
+# + + +JSONArray: TypeAlias = Annotated[ + pa.Array, + UDFAttrs( + sql_type=sql_types.JSON(nullable=True), + # input_transformer=json_or_null_loads, + returns_transformer=lambda x: json_or_null_dumps(x, cls=PyArrowJSONEncoder), + ), +] + + +__all__ = ['Table', 'array'] + [x for x in globals().keys() if x.endswith('Array')] diff --git a/singlestoredb/pytest.py b/singlestoredb/pytest.py index 22efd68d..b18db9a3 100644 --- a/singlestoredb/pytest.py +++ b/singlestoredb/pytest.py @@ -165,8 +165,8 @@ def start(self) -> None: command = ' '.join(self._start_command()) logger.info( - f'Starting container {self.container_name} on ports {self.mysql_port}, ' - f'{self.http_port}, {self.studio_port}', + f'Starting container {self.container_name} on ports ' + f'{self.mysql_port}, {self.http_port}, {self.studio_port}', ) try: license = os.environ['SINGLESTORE_LICENSE'] diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index 0b346db2..bb376bbb 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # mypy: disable-error-code="type-arg" import asyncio +import json import time import typing from typing import List @@ -9,25 +10,111 @@ from typing import Tuple import numpy as np +import pyarrow.compute as pc -import singlestoredb.functions.dtypes as dt +import singlestoredb.functions.sql_types as dt from singlestoredb.functions import Masked from singlestoredb.functions import Table from singlestoredb.functions import udf -from singlestoredb.functions.dtypes import BIGINT -from singlestoredb.functions.dtypes import BLOB -from singlestoredb.functions.dtypes import DOUBLE -from singlestoredb.functions.dtypes import FLOAT -from singlestoredb.functions.dtypes import MEDIUMINT -from singlestoredb.functions.dtypes import SMALLINT -from singlestoredb.functions.dtypes import TEXT -from singlestoredb.functions.dtypes import TINYINT +from singlestoredb.functions.sql_types import BIGINT +from singlestoredb.functions.sql_types import BLOB +from singlestoredb.functions.sql_types import DOUBLE +from singlestoredb.functions.sql_types import FLOAT +from singlestoredb.functions.sql_types import MEDIUMINT +from singlestoredb.functions.sql_types import SMALLINT +from singlestoredb.functions.sql_types import TEXT +from singlestoredb.functions.sql_types import TINYINT +from singlestoredb.functions.typing import JSON from singlestoredb.functions.typing import numpy as npt from singlestoredb.functions.typing import pandas as pdt from singlestoredb.functions.typing import polars as plt from singlestoredb.functions.typing import pyarrow as pat +@udf +def add(x: int, y: int) -> int: + """ + Add two integers. + + Parameters + ---------- + x : int + First integer. + y : int + Second integer. + + Returns + ------- + int + Sum of x and y. + + """ + return x + y + + +@udf +def add_vec(x: npt.Int64Array, y: npt.Int64Array) -> npt.Int64Array: + """ + Add two numpy arrays of int64. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise sum of x and y. + + """ + return x + y + + +@udf +async def async_add(x: int, y: int) -> int: + """ + Asynchronously add two integers. + + Parameters + ---------- + x : int + First integer. + y : int + Second integer. + + Returns + ------- + int + Sum of x and y. 
+ + """ + return x + y + + +@udf +async def async_add_vec_vec(x: npt.Int64Array, y: npt.Int64Array) -> npt.Int64Array: + """ + Asynchronously add two numpy arrays of int64. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise sum of x and y. + + """ + return x + y + + @udf def doc_test(x: int, y: float) -> int: """ @@ -69,27 +156,107 @@ def doc_test(x: int, y: float) -> int: @udf def int_mult(x: int, y: int) -> int: + """ + Multiply two integers. + + Parameters + ---------- + x : int + First integer. + y : int + Second integer. + + Returns + ------- + int + Product of x and y. + + """ return x * y @udf def double_mult(x: float, y: float) -> float: + """ + Multiply two floats. + + Parameters + ---------- + x : float + First float. + y : float + Second float. + + Returns + ------- + float + Product of x and y. + + """ return x * y @udf(timeout=2) def timeout_double_mult(x: float, y: float) -> float: + """ + Multiply two floats after a delay. + + Parameters + ---------- + x : float + First float. + y : float + Second float. + + Returns + ------- + float + Product of x and y. + + """ time.sleep(5) return x * y @udf async def async_double_mult(x: float, y: float) -> float: + """ + Asynchronously multiply two floats. + + Parameters + ---------- + x : float + First float. + y : float + Second float. + + Returns + ------- + float + Product of x and y. + + """ return x * y @udf(timeout=2) async def async_timeout_double_mult(x: float, y: float) -> float: + """ + Asynchronously multiply two floats after a delay. + + Parameters + ---------- + x : float + First float. + y : float + Second float. + + Returns + ------- + float + Product of x and y. + + """ await asyncio.sleep(5) return x * y @@ -99,6 +266,22 @@ async def async_timeout_double_mult(x: float, y: float) -> float: returns=DOUBLE(nullable=False), ) def pandas_double_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of floats. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @@ -107,6 +290,22 @@ def numpy_double_mult( x: npt.Float64Array, y: npt.Float64Array, ) -> npt.Float64Array: + """ + Multiply two numpy arrays of float64. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @@ -115,6 +314,22 @@ async def async_numpy_double_mult( x: npt.Float64Array, y: npt.Float64Array, ) -> npt.Float64Array: + """ + Asynchronously multiply two numpy arrays of float64. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @@ -123,6 +338,22 @@ async def async_numpy_double_mult( returns=DOUBLE(nullable=False), ) def arrow_double_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of doubles. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. 
+ + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -132,11 +363,43 @@ def arrow_double_mult(x: pat.Array, y: pat.Array) -> pat.Array: returns=DOUBLE(nullable=False), ) def polars_double_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of doubles. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @udf def nullable_double_mult(x: Optional[float], y: Optional[float]) -> Optional[float]: + """ + Multiply two optional floats, returning None if either is None. + + Parameters + ---------- + x : float or None + First value. + y : float or None + Second value. + + Returns + ------- + float or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -144,11 +407,43 @@ def nullable_double_mult(x: Optional[float], y: Optional[float]) -> Optional[flo @udf(args=[FLOAT(nullable=False), FLOAT(nullable=False)], returns=FLOAT(nullable=False)) def float_mult(x: float, y: float) -> float: + """ + Multiply two floats. + + Parameters + ---------- + x : float + First float. + y : float + Second float. + + Returns + ------- + float + Product of x and y. + + """ return x * y @udf(args=[FLOAT(nullable=True), FLOAT(nullable=True)], returns=FLOAT(nullable=True)) def nullable_float_mult(x: Optional[float], y: Optional[float]) -> Optional[float]: + """ + Multiply two optional floats, returning None if either is None. + + Parameters + ---------- + x : float or None + First value. + y : float or None + Second value. + + Returns + ------- + float or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -166,6 +461,22 @@ def nullable_float_mult(x: Optional[float], y: Optional[float]) -> Optional[floa @tinyint_udf def tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional tinyints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -173,21 +484,85 @@ def tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @tinyint_udf def pandas_tinyint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of tinyints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @tinyint_udf def polars_tinyint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of tinyints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @tinyint_udf def numpy_tinyint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of tinyints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @tinyint_udf def arrow_tinyint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of tinyints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. 
+ + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -204,6 +579,22 @@ def arrow_tinyint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @smallint_udf def smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional smallints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -211,23 +602,87 @@ def smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @smallint_udf def pandas_smallint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of smallints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @smallint_udf def polars_smallint_mult(x: plt.Series, y: plt.Series) -> plt.Series: - return x * y + """ + Multiply two polars Series of smallints. + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. -@smallint_udf -def numpy_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @smallint_udf -def arrow_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: - import pyarrow.compute as pc - return pc.multiply(x, y) +def numpy_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of smallints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ + return x * y + + +@smallint_udf +def arrow_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of smallints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ + import pyarrow.compute as pc + return pc.multiply(x, y) # @@ -243,6 +698,22 @@ def arrow_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @mediumint_udf def mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional mediumints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -250,21 +721,85 @@ def mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @mediumint_udf def pandas_mediumint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of mediumints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @mediumint_udf def polars_mediumint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of mediumints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. 
+ + """ return x * y @mediumint_udf def numpy_mediumint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of mediumints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @mediumint_udf def arrow_mediumint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of mediumints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -282,6 +817,22 @@ def arrow_mediumint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @bigint_udf def bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional bigints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -289,21 +840,85 @@ def bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @bigint_udf def pandas_bigint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of bigints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @bigint_udf def polars_bigint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of bigints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @bigint_udf def numpy_bigint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of bigints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @bigint_udf def arrow_bigint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of bigints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -321,6 +936,22 @@ def arrow_bigint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @nullable_tinyint_udf def nullable_tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional tinyints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -328,21 +959,85 @@ def nullable_tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @nullable_tinyint_udf def pandas_nullable_tinyint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of nullable tinyints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. 
+ + """ return x * y @nullable_tinyint_udf def polars_nullable_tinyint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of nullable tinyints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @nullable_tinyint_udf def numpy_nullable_tinyint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of nullable tinyints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @nullable_tinyint_udf def arrow_nullable_tinyint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of nullable tinyints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -359,6 +1054,22 @@ def arrow_nullable_tinyint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @nullable_smallint_udf def nullable_smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional smallints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -366,21 +1077,85 @@ def nullable_smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @nullable_smallint_udf def pandas_nullable_smallint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of nullable smallints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @nullable_smallint_udf def polars_nullable_smallint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of nullable smallints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @nullable_smallint_udf def numpy_nullable_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of nullable smallints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @nullable_smallint_udf def arrow_nullable_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of nullable smallints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -398,6 +1173,22 @@ def arrow_nullable_smallint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @nullable_mediumint_udf def nullable_mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional mediumints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. 
+ + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -405,21 +1196,85 @@ def nullable_mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int] @nullable_mediumint_udf def pandas_nullable_mediumint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of nullable mediumints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @nullable_mediumint_udf def polars_nullable_mediumint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of nullable mediumints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @nullable_mediumint_udf def numpy_nullable_mediumint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of nullable mediumints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @nullable_mediumint_udf def arrow_nullable_mediumint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of nullable mediumints. + + Parameters + ---------- + x : pyarrow.Array + First array. + y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @@ -437,6 +1292,22 @@ def arrow_nullable_mediumint_mult(x: pat.Array, y: pat.Array) -> pat.Array: @nullable_bigint_udf def nullable_bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional bigints, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -444,27 +1315,107 @@ def nullable_bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @nullable_bigint_udf def pandas_nullable_bigint_mult(x: pdt.Series, y: pdt.Series) -> pdt.Series: + """ + Multiply two pandas Series of nullable bigints. + + Parameters + ---------- + x : pandas.Series + First series. + y : pandas.Series + Second series. + + Returns + ------- + pandas.Series + Elementwise product of x and y. + + """ return x * y @nullable_bigint_udf def polars_nullable_bigint_mult(x: plt.Series, y: plt.Series) -> plt.Series: + """ + Multiply two polars Series of nullable bigints. + + Parameters + ---------- + x : polars.Series + First series. + y : polars.Series + Second series. + + Returns + ------- + polars.Series + Elementwise product of x and y. + + """ return x * y @nullable_bigint_udf def numpy_nullable_bigint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Multiply two numpy arrays of nullable bigints. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @nullable_bigint_udf def arrow_nullable_bigint_mult(x: pat.Array, y: pat.Array) -> pat.Array: + """ + Multiply two pyarrow arrays of nullable bigints. + + Parameters + ---------- + x : pyarrow.Array + First array. 
+ y : pyarrow.Array + Second array. + + Returns + ------- + pyarrow.Array + Elementwise product of x and y. + + """ import pyarrow.compute as pc return pc.multiply(x, y) @udf def nullable_int_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + """ + Multiply two optional integers, returning None if either is None. + + Parameters + ---------- + x : int or None + First value. + y : int or None + Second value. + + Returns + ------- + int or None + Product of x and y, or None. + + """ if x is None or y is None: return None return x * y @@ -472,11 +1423,46 @@ def nullable_int_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: @udf def string_mult(x: str, times: int) -> str: + """ + Repeat a string a given number of times. + + Parameters + ---------- + x : str + String to repeat. + times : int + Number of repetitions. + + Returns + ------- + str + Repeated string. + + """ return x * times @udf(args=[TEXT(nullable=False), BIGINT(nullable=False)], returns=TEXT(nullable=False)) def pandas_string_mult(x: pdt.Series, times: pdt.Series) -> pdt.Series: + """ + Repeat each string in a pandas Series a given number of times. + + pandas Series do not support annotated element types, so the `args` + parameter of the `@udf` decorator is used to specify the types. + + Parameters + ---------- + x : pandas.Series + Series of strings. + times : pandas.Series + Series of repetition counts. + + Returns + ------- + pandas.Series + Series of repeated strings. + + """ return x * times @@ -484,6 +1470,24 @@ def pandas_string_mult(x: pdt.Series, times: pdt.Series) -> pdt.Series: def numpy_string_mult( x: npt.NDArray[np.str_], times: npt.NDArray[np.int_], ) -> npt.NDArray[np.str_]: + """ + Repeat each string in a numpy array a given number of times. + + Numpy arrays can be used to specify the element types of vector inputs. + + Parameters + ---------- + x : numpy.ndarray + Array of strings. + times : numpy.ndarray + Array of repetition counts. + + Returns + ------- + numpy.ndarray + Array of repeated strings. + + """ return x * times @@ -503,6 +1507,22 @@ def numpy_string_mult( @udf def nullable_string_mult(x: Optional[str], times: Optional[int]) -> Optional[str]: + """ + Repeat a string a given number of times, returning None if either is None. + + Parameters + ---------- + x : str or None + String to repeat. + times : int or None + Number of repetitions. + + Returns + ------- + str or None + Repeated string or None. + + """ if x is None or times is None: return None return x * times @@ -515,6 +1535,31 @@ def nullable_string_mult(x: Optional[str], times: Optional[int]) -> Optional[str def pandas_nullable_tinyint_mult_with_masks( x: Masked[pdt.Series], y: Masked[pdt.Series], ) -> Masked[pdt.Series]: + """ + Multiply two masked pandas Series of nullable tinyints. + + Masks are used to represent null values in vector inputs and + outputs that do not natively support nulls. Each parameter + wrapped by the `Masked` type is a tuple of (data, nulls), + where `data` is the original type and `nulls` is a boolean + array indicating which elements are null. + + The return value is also wrapped in `Masked`, with the + returned vectors being represented by a tuple of (data, nulls). + + Parameters + ---------- + x : Masked[pandas.Series] + First masked series. + y : Masked[pandas.Series] + Second masked series. + + Returns + ------- + Masked[pandas.Series] + Masked elementwise product of x and y. 
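+ + Notes + ----- + A minimal sketch of the unpack/repack pattern used in the body + (illustrative values): data, nulls = Masked(pdt.Series([1, 2]), + pdt.Series([False, True])) leaves `data` holding the values and `nulls` + flagging the second element as NULL; the output mask here is the + elementwise OR of the two input masks.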
+ + """ x_data, x_nulls = x y_data, y_nulls = y return Masked(x_data * y_data, x_nulls | y_nulls) @@ -524,6 +1569,31 @@ def pandas_nullable_tinyint_mult_with_masks( def numpy_nullable_tinyint_mult_with_masks( x: Masked[npt.NDArray[np.int8]], y: Masked[npt.NDArray[np.int8]], ) -> Masked[npt.NDArray[np.int8]]: + """ + Multiply two masked numpy arrays of nullable tinyints. + + Masks are used to represent null values in vector inputs and + outputs that do not natively support nulls. Each parameter + wrapped by the `Masked` type is a tuple of (data, nulls), + where `data` is the original type and `nulls` is a boolean + array indicating which elements are null. + + The return value is also wrapped in `Masked`, with the + returned vectors being represented by a tuple of (data, nulls). + + Parameters + ---------- + x : Masked[numpy.ndarray] + First masked array. + y : Masked[numpy.ndarray] + Second masked array. + + Returns + ------- + Masked[numpy.ndarray] + Masked elementwise product of x and y. + + """ x_data, x_nulls = x y_data, y_nulls = y return Masked(x_data * y_data, x_nulls | y_nulls) @@ -536,6 +1606,32 @@ def numpy_nullable_tinyint_mult_with_masks( def polars_nullable_tinyint_mult_with_masks( x: Masked[plt.Series], y: Masked[plt.Series], ) -> Masked[plt.Series]: + """ + Multiply two masked polars Series of nullable tinyints. + + This function demonstrates how to handle masks in polars Series, + which do not natively support nulls. Each parameter wrapped by the + `Masked` type is a tuple of (data, nulls), where `data` is the + original polars Series and `nulls` is a boolean Series indicating + which elements are null. + + In addition, the element types of the polars Series are annotated + using the `args` parameter of the `@udf` decorator, since polars + Series do not support annotated element types directly. + + Parameters + ---------- + x : Masked[polars.Series] + First masked series. + y : Masked[polars.Series] + Second masked series. + + Returns + ------- + Masked[polars.Series] + Masked elementwise product of x and y. + + """ x_data, x_nulls = x y_data, y_nulls = y return Masked(x_data * y_data, x_nulls | y_nulls) @@ -548,6 +1644,22 @@ def polars_nullable_tinyint_mult_with_masks( def arrow_nullable_tinyint_mult_with_masks( x: Masked[pat.Array], y: Masked[pat.Array], ) -> Masked[pat.Array]: + """ + Multiply two masked pyarrow arrays of nullable tinyints. + + Parameters + ---------- + x : Masked[pyarrow.Array] + First masked array. + y : Masked[pyarrow.Array] + Second masked array. + + Returns + ------- + Masked[pyarrow.Array] + Masked elementwise product of x and y. + + """ import pyarrow.compute as pc x_data, x_nulls = x y_data, y_nulls = y @@ -556,6 +1668,22 @@ def arrow_nullable_tinyint_mult_with_masks( @udf(returns=[TEXT(nullable=False, name='res')]) def numpy_fixed_strings() -> Table[npt.StrArray]: + """ + Return a table of fixed-length numpy strings. + + Table-valued functions must use a `Table` annotation which + encapsulates the return type. The return type can be one or + more vectors. If the return type is more than one native + Python type, they must be wrapped in a `List[Tuple[...]]`. + + The return value mult also be wrapped in a `Table` instance. + + Returns + ------- + Table[numpy.ndarray] + Table containing fixed-length strings. 
+ + """ out = np.array( [ 'hello', @@ -569,6 +1697,15 @@ def numpy_fixed_strings() -> Table[npt.StrArray]: @udf(returns=[TEXT(nullable=False, name='res'), TINYINT(nullable=False, name='res2')]) def numpy_fixed_strings_2() -> Table[npt.StrArray, npt.Int8Array]: + """ + Return a table of fixed-length numpy strings and int8s. + + Returns + ------- + Table[numpy.ndarray, numpy.ndarray] + Table containing fixed-length strings and int8s. + + """ out = np.array( [ 'hello', @@ -582,6 +1719,15 @@ def numpy_fixed_strings_2() -> Table[npt.StrArray, npt.Int8Array]: @udf(returns=[BLOB(nullable=False, name='res')]) def numpy_fixed_binary() -> Table[npt.BytesArray]: + """ + Return a table of fixed-length numpy binary strings. + + Returns + ------- + Table[numpy.ndarray] + Table containing fixed-length binary strings. + + """ out = np.array( [ 'hello'.encode('utf8'), @@ -595,16 +1741,47 @@ def numpy_fixed_binary() -> Table[npt.BytesArray]: @udf def no_args_no_return_value() -> None: + """Function with no arguments and no return value.""" pass @udf def table_function(n: int) -> Table[List[int]]: + """ + Return a table of n tens. + + When returning native Python types from a table-valued function, + the return type must be wrapped in a `Table[List[...]]` annotation. + + Parameters + ---------- + n : int + Number of tens. + + Returns + ------- + Table[List[int]] + Table containing n tens. + + """ return Table([10] * n) @udf async def async_table_function(n: int) -> Table[List[int]]: + """ + Asynchronously return a table of n tens. + + Parameters + ---------- + n : int + Number of tens. + + Returns + ------- + Table[List[int]] + Table containing n tens. + """ return Table([10] * n) @@ -616,6 +1793,23 @@ async def async_table_function(n: int) -> Table[List[int]]: ], ) def table_function_tuple(n: int) -> Table[List[Tuple[int, float, str]]]: + """ + Return a table of tuples (int, float, str). + + To return multiple native Python types from a table-valued function, + the return type must be wrapped in a `Table[List[Tuple[...]]]` annotation + + Parameters + ---------- + n : int + Number of tuples. + + Returns + ------- + Table[List[Tuple[int, float, str]]] + Table containing n tuples. + + """ return Table([(10, 10.0, 'ten')] * n) @@ -627,6 +1821,25 @@ class MyTable(NamedTuple): @udf def table_function_struct(n: int) -> Table[List[MyTable]]: + """ + Return a table of MyTable namedtuples. + + Multiple return values can also be represented using + a NamedTuple, pydantic model, or dataclass. Each + field of the NamedTuple, pydantic model, or dataclass + will be mapped to a column in the returned table. + + Parameters + ---------- + n : int + Number of tuples. + + Returns + ------- + Table[List[MyTable]] + Table containing n MyTable tuples. + + """ return Table([MyTable(10, 10.0, 'ten')] * n) @@ -634,6 +1847,22 @@ def table_function_struct(n: int) -> Table[List[MyTable]]: def vec_function( x: npt.Float64Array, y: npt.Float64Array, ) -> npt.Float64Array: + """ + Multiply two numpy arrays of float64. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @@ -650,6 +1879,28 @@ class VecOutputs(typing.NamedTuple): def vec_function_ints( x: npt.IntArray, y: npt.IntArray, ) -> npt.IntArray: + """ + Multiply two numpy arrays of int. + + You can specify the types of arguments or return values + using a NamedTuple, pydantic model, or dataclass. 
This is + especially useful for vector inputs and outputs, where you + may want to specify the element type of a numpy array or + pandas Series. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + numpy.ndarray + Elementwise product of x and y. + + """ return x * y @@ -662,6 +1913,28 @@ class DFOutputs(typing.NamedTuple): def vec_function_df( x: npt.IntArray, y: npt.IntArray, ) -> Table[pdt.DataFrame]: + """ + Return a pandas DataFrame with two columns. + + When using a `DataFrame` return type, the return type of the UDF + must be wrapped in a `Table` annotation. The columns of the DataFrame + are determined by the fields of the return type NamedTuple, pydantic + model, or dataclass specified in the `returns` parameter of the `@udf` + decorator. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + Table[pandas.DataFrame] + Table containing a DataFrame with columns 'res' and 'res2'. + + """ return pdt.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) @@ -669,6 +1942,22 @@ def vec_function_df( async def async_vec_function_df( x: npt.IntArray, y: npt.IntArray, ) -> Table[pdt.DataFrame]: + """ + Asynchronously return a pandas DataFrame with two columns. + + Parameters + ---------- + x : numpy.ndarray + First array. + y : numpy.ndarray + Second array. + + Returns + ------- + Table[pandas.DataFrame] + Table containing a DataFrame with columns 'res' and 'res2'. + + """ return pdt.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) @@ -680,6 +1969,29 @@ class MaskOutputs(typing.NamedTuple): def vec_function_ints_masked( x: Masked[npt.IntArray], y: Masked[npt.IntArray], ) -> Table[Masked[npt.IntArray]]: + """ + Multiply two masked numpy arrays of int. + + Masked vectors can also be used in table-valued functions. + The return type must be wrapped in a `Table` annotation. + Each masked vector is represented by a `Masked` type, which + encapsulates a tuple of (data, nulls), where `data` is the + original vector type and `nulls` is a boolean array indicating + which elements are null. + + Parameters + ---------- + x : Masked[numpy.ndarray] + First masked array. + y : Masked[numpy.ndarray] + Second masked array. + + Returns + ------- + Table[Masked[numpy.ndarray]] + Table containing masked elementwise product. + + """ x_data, x_nulls = x y_data, y_nulls = y return Table(Masked(x_data * y_data, x_nulls | y_nulls)) @@ -694,9 +2006,976 @@ class MaskOutputs2(typing.NamedTuple): def vec_function_ints_masked2( x: Masked[npt.IntArray], y: Masked[npt.IntArray], ) -> Table[Masked[npt.IntArray], Masked[npt.IntArray]]: + """ + Multiply two masked numpy arrays of int, returning two masked outputs. + + Parameters + ---------- + x : Masked[numpy.ndarray] + First masked array. + y : Masked[numpy.ndarray] + Second masked array. + + Returns + ------- + Table[Masked[numpy.ndarray], Masked[numpy.ndarray]] + Table containing two masked elementwise products. + + """ x_data, x_nulls = x y_data, y_nulls = y return Table( Masked(x_data * y_data, x_nulls | y_nulls), Masked(x_data * y_data, x_nulls | y_nulls), ) + + +# +# Begin JSON UDFs +# + +# numpy + +@udf +def json_object_numpy( + x: npt.IntArray, + y: npt.JSONArray, +) -> npt.JSONArray: + """ + Create a numpy array of JSON objects from int and JSON arrays. + + The JSON type is used to represent JSON objects and lists. 
+ The JSON type has argument and return transformers which will + automatically convert JSON strings from the database into + native Python types (dicts and lists) when passed into a UDF, + and convert native Python types back into JSON strings when + returning from a UDF. + + Parameters + ---------- + x : numpy.ndarray + Array of integers. + y : numpy.ndarray + Array of JSON objects. + + Returns + ------- + numpy.ndarray + Array of JSON objects. + + """ + return npt.array([ + None if a == 0 and b is None else dict( + x=a * 2 if a is not None else None, + y=b['foo'] if b is not None else None, + ) for a, b in zip(x, y) + ]) + + +@udf +def json_object_numpy_masked( + x: Masked[npt.IntArray], + y: Masked[npt.JSONArray], +) -> Masked[npt.JSONArray]: + """ + Create a masked numpy array of JSON objects from masked int and JSON arrays. + + Just as with other types, the JSON type can be used with masked + vectors to represent null values in vector inputs and outputs + that do not natively support nulls. Each parameter wrapped by the + `Masked` type is a tuple of (data, nulls), where `data` is the original + type and `nulls` is a boolean array indicating which elements are null. + + Parameters + ---------- + x : Masked[numpy.ndarray] + Masked array of integers. + y : Masked[numpy.ndarray] + Masked array of JSON objects. + + Returns + ------- + Masked[numpy.ndarray] + Masked array of JSON objects. + + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + npt.array([ + dict( + x=a * 2 if a is not None else 0, + y=b['foo'] if b is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + x_nulls & y_nulls, + ) + + +@udf +def json_list_numpy(x: npt.IntArray, y: npt.JSONArray) -> npt.JSONArray: + """ + Create a numpy array of JSON objects from int and JSON arrays, using list indexing. + + Parameters + ---------- + x : numpy.ndarray + Array of integers. + y : numpy.ndarray + Array of JSON lists. + + Returns + ------- + numpy.ndarray + Array of JSON objects. + + """ + return npt.array([ + None if a == 0 and b is None else dict( + x=a * 2 if a is not None else None, + y=b[0] if b is not None else None, + ) + for a, b in zip(x, y) + ]) + + +@udf +def json_list_numpy_masked( + x: Masked[npt.IntArray], + y: Masked[npt.JSONArray], +) -> Masked[npt.JSONArray]: + """ + Create a masked numpy array of JSON objects from masked int and JSON arrays. + + Parameters + ---------- + x : Masked[numpy.ndarray] + Masked array of integers. + y : Masked[numpy.ndarray] + Masked array of JSON lists. + + Returns + ------- + Masked[numpy.ndarray] + Masked array of JSON objects. + + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + npt.array([ + dict( + x=a * 2 if a is not None else 0, + y=b[0] if b is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + x_nulls & y_nulls, + ) + + +@udf +def json_object_numpy_tvf( + x: npt.IntArray, + y: npt.JSONArray, +) -> Table[npt.IntArray, npt.JSONArray]: + """ + Return a table of int and JSON arrays based on input arrays. + + Parameters + ---------- + x : numpy.ndarray + Array of integers. + y : numpy.ndarray + Array of JSON objects. + + Returns + ------- + Table[numpy.ndarray, numpy.ndarray] + Table containing int and JSON arrays. 
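+ + Notes + ----- + Invoked from SQL roughly as in the test suite: + SELECT * FROM json_object_numpy_tvf(10, '{"foo": "bar", "baz": 2.75}') ORDER BY a; + which expands the scalar inputs into five output rows.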
+ + """ + return Table( + npt.array([x[0] * i for i in range(5)]), + npt.array([ + dict(x=x[0] * i, y=y[0]['foo'] if y[0] is not None else None) + for i in range(5) + ]), + ) + + +@udf +def json_object_numpy_tvf_masked( + x: Masked[npt.IntArray], + y: Masked[npt.JSONArray], +) -> Table[Masked[npt.IntArray], Masked[npt.JSONArray]]: + """ + Return a table of masked int and JSON arrays based on input arrays. + + Parameters + ---------- + x : Masked[numpy.ndarray] + Masked array of integers. + y : Masked[numpy.ndarray] + Masked array of JSON objects. + + Returns + ------- + Table[Masked[numpy.ndarray], Masked[numpy.ndarray]] + Table containing masked int and JSON arrays. + + """ + (x_data, _), (y_data, _) = x, y + return Table( + Masked( + npt.array([ + 0 if x_data[0] == 20 else x_data[0] * i for i in range(5) + ]), + npt.array([False, False, True, False, False]), + ), + Masked( + npt.array([ + dict( + x=x_data[0] * i, y=y_data[0]['foo'] + if i != 4 and y_data[0] is not None else None, + ) + for i in range(5) + ]), + npt.array([False, False, False, False, True]), + ), + ) + + +# pandas + + +@udf +def json_object_pandas( + x: pdt.IntSeries, + y: pdt.JSONSeries, +) -> pdt.JSONSeries: + """ + Create a pandas Series of JSON objects from int and JSON series. + + Parameters + ---------- + x : pandas.Series + Series of integers. + y : pandas.Series + Series of JSON objects. + + Returns + ------- + pandas.Series + Series of JSON objects. + + """ + return pdt.Series([ + None if a == 0 and b is None else dict( + x=a * 2 if a is not None else None, + y=b['foo'] if b is not None else None, + ) for a, b in zip(x, y) + ]) + + +@udf +def json_object_pandas_masked( + x: Masked[pdt.IntSeries], + y: Masked[pdt.JSONSeries], +) -> Masked[pdt.JSONSeries]: + """ + Create a masked pandas Series of JSON objects from masked int and JSON series. + + Parameters + ---------- + x : Masked[pandas.Series] + Masked series of integers. + y : Masked[pandas.Series] + Masked series of JSON objects. + + Returns + ------- + Masked[pandas.Series] + Masked series of JSON objects. + + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + pdt.Series([ + dict( + x=a * 2 if a is not None else 0, + y=b['foo'] if b is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + x_nulls & y_nulls, + ) + + +@udf +def json_list_pandas(x: pdt.IntSeries, y: pdt.JSONSeries) -> pdt.JSONSeries: + """ + Create a pandas Series of JSON objects from int and JSON series. + + Parameters + ---------- + x : pandas.Series + Series of integers. + y : pandas.Series + Series of JSON lists. + + Returns + ------- + pandas.Series + Series of JSON objects. + + """ + return pdt.Series([ + None if a == 0 and b is None else dict( + x=a * 2 if a is not None else None, + y=b[0] if b is not None else None, + ) + for a, b in zip(x, y) + ]) + + +@udf +def json_list_pandas_masked( + x: Masked[pdt.IntSeries], + y: Masked[pdt.JSONSeries], +) -> Masked[pdt.JSONSeries]: + """ + Create a masked pandas Series of JSON objects from masked int and JSON series. + + Parameters + ---------- + x : Masked[pandas.Series] + Masked series of integers. + y : Masked[pandas.Series] + Masked series of JSON lists. + + Returns + ------- + Masked[pandas.Series] + Masked series of JSON objects. 
+ + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + pdt.Series([ + dict( + x=a * 2 if a is not None else 0, + y=b[0] if b is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + x_nulls & y_nulls, + ) + + +@udf +def json_object_pandas_tvf( + x: pdt.IntSeries, + y: pdt.JSONSeries, +) -> Table[pdt.IntSeries, pdt.JSONSeries]: + """ + Return a table of int and JSON arrays based on input series. + + Parameters + ---------- + x : pandas.Series + Series of integers. + y : pandas.Series + Series of JSON objects. + + Returns + ------- + Table[pandas.Series, pandas.Series] + Table containing int and JSON series. + + """ + return Table( + pdt.Series([x[0] * i for i in range(5)]), + pdt.Series([ + dict(x=x[0] * i, y=y[0]['foo'] if y[0] is not None else None) + for i in range(5) + ]), + ) + + +@udf +def json_object_pandas_tvf_masked( + x: Masked[pdt.IntSeries], + y: Masked[pdt.JSONSeries], +) -> Table[Masked[pdt.IntSeries], Masked[pdt.JSONSeries]]: + """ + Return a table of masked int and JSON arrays based on input series. + + Parameters + ---------- + x : Masked[pandas.Series] + Masked series of integers. + y : Masked[pandas.Series] + Masked series of JSON objects. + + Returns + ------- + Table[Masked[pandas.Series], Masked[pandas.Series]] + Table containing masked int and JSON series. + + """ + (x_data, _), (y_data, _) = x, y + return Table( + Masked( + pdt.Series([ + 0 if x_data[0] == 20 else x_data[0] * i for i in range(5) + ]), + pdt.Series([False, False, True, False, False]), + ), + Masked( + pdt.Series([ + dict( + x=x_data[0] * i, y=y_data[0]['foo'] + if i != 4 and y_data[0] is not None else None, + ) + for i in range(5) + ]), + pdt.Series([False, False, False, False, True]), + ), + ) + + +# polars + + +@udf +def json_object_polars( + x: plt.IntSeries, + y: plt.JSONSeries, +) -> plt.JSONSeries: + """ + Create a polars Series of JSON objects from int and JSON series. + + Parameters + ---------- + x : polars.Series + Series of integers. + y : polars.Series + Series of JSON objects. + + Returns + ------- + polars.Series + Series of JSON objects. + + """ + return plt.Series([ + None if a == 0 and b is None else dict( + x=a * 2 if a is not None else None, + y=b['foo'] if b is not None else None, + ) for a, b in zip(x, y) + ]) + + +@udf +def json_object_polars_masked( + x: Masked[plt.IntSeries], + y: Masked[plt.JSONSeries], +) -> Masked[plt.JSONSeries]: + """ + Create a masked polars Series of JSON objects from masked int and JSON series. + + Parameters + ---------- + x : Masked[polars.Series] + Masked series of integers. + y : Masked[polars.Series] + Masked series of JSON objects. + + Returns + ------- + Masked[polars.Series] + Masked series of JSON objects. + + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + plt.Series([ + dict( + x=a * 2 if a is not None else 0, + y=b['foo'] if b is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + x_nulls & y_nulls, + ) + + +@udf +def json_list_polars(x: plt.IntSeries, y: plt.JSONSeries) -> plt.JSONSeries: + """ + Create a polars Series of JSON objects from int and JSON series, using list indexing. + + Parameters + ---------- + x : polars.Series + Series of integers. + y : polars.Series + Series of JSON lists. + + Returns + ------- + polars.Series + Series of JSON objects. 
+ + """ + return plt.Series([ + None if a == 0 and b is None else dict( + x=a * 2 if a is not None else None, + y=b[0] if b is not None else None, + ) + for a, b in zip(x, y) + ]) + + +@udf +def json_list_polars_masked( + x: Masked[plt.IntSeries], + y: Masked[plt.JSONSeries], +) -> Masked[plt.JSONSeries]: + """ + Create a masked polars Series of JSON objects from masked int and JSON series. + + Parameters + ---------- + x : Masked[polars.Series] + Masked series of integers. + y : Masked[polars.Series] + Masked series of JSON lists. + + Returns + ------- + Masked[polars.Series] + Masked series of JSON objects. + + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + plt.Series([ + dict( + x=a * 2 if a is not None else 0, + y=b[0] if b is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + x_nulls & y_nulls, + ) + + +@udf +def json_object_polars_tvf( + x: plt.IntSeries, + y: plt.JSONSeries, +) -> Table[plt.IntSeries, plt.JSONSeries]: + """ + Return a table of int and JSON arrays based on input series. + + Parameters + ---------- + x : polars.Series + Series of integers. + y : polars.Series + Series of JSON objects. + + Returns + ------- + Table[polars.Series, polars.Series] + Table containing int and JSON series. + + """ + return Table( + plt.Series([x[0] * i for i in range(5)]), + plt.Series([ + dict(x=x[0] * i, y=y[0]['foo'] if y[0] is not None else None) + for i in range(5) + ]), + ) + + +@udf +def json_object_polars_tvf_masked( + x: Masked[plt.IntSeries], + y: Masked[plt.JSONSeries], +) -> Table[Masked[plt.IntSeries], Masked[plt.JSONSeries]]: + """ + Return a table of masked int and JSON arrays based on input series. + + Parameters + ---------- + x : Masked[polars.Series] + Masked series of integers. + y : Masked[polars.Series] + Masked series of JSON objects. + + Returns + ------- + Table[Masked[polars.Series], Masked[polars.Series]] + Table containing masked int and JSON series. + + """ + (x_data, _), (y_data, _) = x, y + return Table( + Masked( + plt.Series([ + 0 if x_data[0] == 20 else x_data[0] * i for i in range(5) + ]), + plt.Series([False, False, True, False, False]), + ), + Masked( + plt.Series([ + dict( + x=x_data[0] * i, y=y_data[0]['foo'] + if i != 4 and y_data[0] is not None else None, + ) + for i in range(5) + ]), + plt.Series([False, False, False, False, True]), + ), + ) + + +# pyarrow + + +@udf +def json_object_pyarrow( + x: pat.IntArray, + y: pat.JSONArray, +) -> pat.JSONArray: + """ + Create a pyarrow array of JSON objects from int and JSON arrays. + + Parameters + ---------- + x : pyarrow.Array + Array of integers. + y : pyarrow.Array + Array of JSON objects. + + Returns + ------- + pyarrow.Array + Array of JSON objects. + + """ + return pat.array([ + None if a == 0 and b.as_py() is None else dict( + x=pc.multiply(a, 2) if a is not None else None, + y=json.loads(b.as_py())['foo'] if b.as_py() is not None else None, + ) for a, b in zip(x, y) + ]) + + +@udf +def json_object_pyarrow_masked( + x: Masked[pat.IntArray], + y: Masked[pat.JSONArray], +) -> Masked[pat.JSONArray]: + """ + Create a masked pyarrow array of JSON objects from masked int and JSON arrays. + + Parameters + ---------- + x : Masked[pyarrow.Array] + Masked array of integers. + y : Masked[pyarrow.Array] + Masked array of JSON objects. + + Returns + ------- + Masked[pyarrow.Array] + Masked array of JSON objects. 
+ + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + pat.array([ + dict( + x=pc.multiply(a, 2) if a is not None else 0, + y=json.loads(b.as_py())['foo'] if b.as_py() is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + pc.and_(x_nulls, y_nulls), + ) + + +@udf +def json_list_pyarrow(x: pat.IntArray, y: pat.JSONArray) -> pat.JSONArray: + """ + Create a pyarrow array of JSON objects from int and JSON arrays. + + Parameters + ---------- + x : pyarrow.Array + Array of integers. + y : pyarrow.Array + Array of JSON lists. + + Returns + ------- + pyarrow.Array + Array of JSON objects. + + """ + return pat.array([ + None if a == 0 and b is None else dict( + x=pc.multiply(a, 2) if a is not None else None, + y=json.loads(b.as_py())[0] if b.as_py() is not None else None, + ) + for a, b in zip(x, y) + ]) + + +@udf +def json_list_pyarrow_masked( + x: Masked[pat.IntArray], + y: Masked[pat.JSONArray], +) -> Masked[pat.JSONArray]: + """ + Create a masked pyarrow array of JSON objects from masked int and JSON arrays. + + Parameters + ---------- + x : Masked[pyarrow.Array] + Masked array of integers. + y : Masked[pyarrow.Array] + Masked array of JSON lists. + + Returns + ------- + Masked[pyarrow.Array] + Masked array of JSON objects. + + """ + (x_data, x_nulls), (y_data, y_nulls) = x, y + return Masked( + pat.array([ + dict( + x=pc.multiply(a, 2) if a is not None else 0, + y=json.loads(b.as_py())[0] if b.as_py() is not None else None, + ) for a, b in zip(x_data, y_data) + ]), + pc.and_(x_nulls, y_nulls), + ) + + +@udf +def json_object_pyarrow_tvf( + x: pat.IntArray, + y: pat.JSONArray, +) -> Table[pat.IntArray, pat.JSONArray]: + """ + Return a table of int and JSON arrays based on input arrays. + + Parameters + ---------- + x : pyarrow.Array + Array of integers. + y : pyarrow.Array + Array of JSON objects. + + Returns + ------- + Table[pyarrow.Array, pyarrow.Array] + Table containing int and JSON arrays. + + """ + return Table( + pat.array([pc.multiply(x[0], i) for i in range(5)]), + pat.array([ + dict( + x=pc.multiply(x[0], i), + y=json.loads(y[0].as_py())['foo'] if y[0].as_py() is not None else None, + ) + for i in range(5) + ]), + ) + + +@udf +def json_object_pyarrow_tvf_masked( + x: Masked[pat.IntArray], + y: Masked[pat.JSONArray], +) -> Table[Masked[pat.IntArray], Masked[pat.JSONArray]]: + """ + Return a table of masked int and JSON arrays based on input arrays. + + Parameters + ---------- + x : Masked[pyarrow.Array] + Masked array of integers. + y : Masked[pyarrow.Array] + Masked array of JSON objects. + + Returns + ------- + Table[Masked[pyarrow.Array], Masked[pyarrow.Array]] + Table containing masked int and JSON arrays. + + """ + (x_data, _), (y_data, _) = x, y + return Table( + Masked( + pat.array([ + 0 if x_data[0] == 20 else pc.multiply(x_data[0], i) for i in range(5) + ]), + pat.array([False, False, True, False, False]), + ), + Masked( + pat.array([ + dict( + x=pc.multiply(x_data[0], i), + y=json.loads(str(y_data[0]))['foo'] + if i != 4 and y_data[0].as_py() is not None else None, + ) + for i in range(5) + ]), + pat.array([False, False, False, False, True]), + ), + ) + + +@udf +def json_object_list(x: List[int], y: List[JSON]) -> List[JSON]: + """ + Create a list of JSON objects from int and JSON lists. + + Parameters + ---------- + x : list of int + List of integers. + y : list of JSON + List of JSON objects or arrays. + + Returns + ------- + list of JSON + List of JSON objects. 
+ + """ + return [dict(x=x * 2, y=y['foo']) for x, y in zip(x, y)] # type: ignore + + +@udf +def json_list_list(x: List[int], y: List[JSON]) -> List[JSON]: + """ + Create a list of JSON objects from int and JSON lists. + + Parameters + ---------- + x : list of int + List of integers. + y : list of JSON + List of JSON objects or arrays. + + Returns + ------- + list of JSON + List of JSON objects. + + """ + return [dict(x=x * 2, y=y[0] if isinstance(y, list) else None) for x, y in zip(x, y)] + + +@udf +def json_object_list_tvf( + x: List[int], y: List[JSON], +) -> Table[List[Tuple[int, JSON]]]: + """ + Return a table of transformed values from lists of int and JSON objects. + + Parameters + ---------- + x : list of int + List of integers. + y : list of JSON + List of JSON objects or arrays. + + Returns + ------- + Table[List[Tuple[int, JSON]]] + Table containing transformed values. + + """ + out: List[Tuple[int, JSON]] = [] + for i in range(5): + out.append(( + x[0] * i, + dict(x=x[0] * i, y=y[0]['foo'] if isinstance(y[0], dict) else None), + )) + return Table(out) + + +@udf +def json_object_nonvector(x: int, y: JSON) -> JSON: + """ + Extract and transform values from a JSON object. + + Parameters + ---------- + x : int + An integer value. + y : JSON + A JSON object. + + Returns + ------- + JSON + A JSON object with transformed values. + + """ + if not isinstance(y, dict): + raise ValueError('Expected dict for JSON object') + return dict(x=x * 2, y=y['foo']) + + +@udf +def json_list_nonvector(x: int, y: JSON) -> JSON: + """ + Extract and transform values from a JSON array. + + Parameters + ---------- + x : int + An integer value. + y : JSON + A JSON array. + + Returns + ------- + JSON + A JSON object with transformed values. + + """ + if not isinstance(y, list): + raise ValueError('Expected list for JSON array') + return dict(x=x * 2, y=y[0]) + + +@udf +def json_object_nonvector_tvf( + x: int, y: JSON, +) -> Table[List[Tuple[int, JSON]]]: + """ + Return a table of transformed values from a JSON object. + + Parameters + ---------- + x : int + An integer value. + y : JSON + A JSON object. + + Returns + ------- + Table[List[Tuple[int, JSON]]] + Table containing transformed values. 
+ + """ + out: List[Tuple[int, JSON]] = [] + for i in range(5): + out.append((x * i, dict(x=x * i, y=y['foo'] if isinstance(y, dict) else None))) + return Table(out) diff --git a/singlestoredb/tests/test.sql b/singlestoredb/tests/test.sql index ab3cf955..61073f96 100644 --- a/singlestoredb/tests/test.sql +++ b/singlestoredb/tests/test.sql @@ -677,4 +677,44 @@ INSERT INTO i64_vectors VALUES(2, '[4, 5, 6]'); INSERT INTO i64_vectors VALUES(3, '[-1, -4, 8]'); +COMMIT; + +CREATE TABLE `json_data` ( + id INT NOT NULL PRIMARY KEY, + x INT NOT NULL, + y JSON NOT NULL +); +INSERT INTO json_data VALUES (1, 101, '{"foo": 10, "bar": 2.75, "baz": "hello"}'); +INSERT INTO json_data VALUES (2, 121, '{"foo": 105, "bar": 3.5, "baz": "goodbye"}'); +INSERT INTO json_data VALUES (3, 151, '{"foo": 50, "bar": 7.63, "baz": ""}'); + +CREATE TABLE `json_data_with_nulls` ( + id INT NOT NULL PRIMARY KEY, + x INT NULL, + y JSON NULL +); +INSERT INTO json_data_with_nulls VALUES (1, NULL, '{"foo": 10, "bar": 2.75, "baz": "hello"}'); +INSERT INTO json_data_with_nulls VALUES (2, 121, NULL); +INSERT INTO json_data_with_nulls VALUES (3, 151, '{"foo": 50, "bar": 7.63, "baz": ""}'); +INSERT INTO json_data_with_nulls VALUES (4, NULL, NULL); + +CREATE TABLE `json_list_data` ( + id INT NOT NULL PRIMARY KEY, + x INT NOT NULL, + y JSON NOT NULL +); +INSERT INTO json_list_data VALUES (1, 101, '["foo", "bar", "baz"]'); +INSERT INTO json_list_data VALUES (2, 121, '["foo", "bar", "baz"]'); +INSERT INTO json_list_data VALUES (3, 151, '["foo", "bar", "baz"]'); + +CREATE TABLE `json_list_data_with_nulls` ( + id INT NOT NULL PRIMARY KEY, + x INT NULL, + y JSON NULL +); +INSERT INTO json_list_data_with_nulls VALUES (1, NULL, '["foo", "bar", "baz"]'); +INSERT INTO json_list_data_with_nulls VALUES (2, 121, NULL); +INSERT INTO json_list_data_with_nulls VALUES (3, 151, '["foo", "bar", "baz"]'); +INSERT INTO json_list_data_with_nulls VALUES (4, NULL, NULL); + COMMIT; diff --git a/singlestoredb/tests/test_ext_func.py b/singlestoredb/tests/test_ext_func.py index d3e680e5..bb0d3461 100755 --- a/singlestoredb/tests/test_ext_func.py +++ b/singlestoredb/tests/test_ext_func.py @@ -2,10 +2,13 @@ # type: ignore """Test SingleStoreDB external functions.""" import os +import random import socket +import string import subprocess import time import unittest +from contextlib import contextmanager import requests @@ -101,6 +104,23 @@ class TestExtFunc(unittest.TestCase): http_host = '127.0.0.1' http_port = 0 + @contextmanager + def temp_table(self, table_name=None, schema=None): + """Context manager for creating and cleaning up temporary tables.""" + if table_name is None: + table_name = 'temp_' + ''.join(random.choices(string.ascii_lowercase, k=8)) + + try: + if schema: + self.cur.execute(f'CREATE TABLE {table_name} {schema}') + yield table_name + finally: + try: + self.cur.execute(f'DROP TABLE IF EXISTS {table_name}') + except Exception: + # Ignore cleanup errors + pass + @classmethod def setUpClass(cls): sql_file = os.path.join(os.path.dirname(__file__), 'test.sql') @@ -1470,3 +1490,321 @@ def test_vec_function_ints_masked2(self): assert desc[1].name == 'res2' assert desc[1].type_code == ft.SHORT assert desc[1].null_ok is True + + def _test_json_object_vector(self, vector_type: str): + self.cur.execute( + f'select json_object_{vector_type}(x, y) as res ' + 'from json_data order by id', + ) + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.JSON + assert desc[0].null_ok is False if 
@@ -1470,3 +1490,321 @@ def test_vec_function_ints_masked2(self):
         assert desc[1].name == 'res2'
         assert desc[1].type_code == ft.SHORT
         assert desc[1].null_ok is True
+
+    def _test_json_object_vector(self, vector_type: str):
+        self.cur.execute(
+            f'select json_object_{vector_type}(x, y) as res '
+            'from json_data order by id',
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 1
+        assert desc[0].name == 'res'
+        assert desc[0].type_code == ft.JSON
+        assert desc[0].null_ok is (False if vector_type != 'pyarrow' else True)
+
+        assert [tuple(x) for x in self.cur] == [
+            ({'x': 202, 'y': 10},),
+            ({'x': 242, 'y': 105},),
+            ({'x': 302, 'y': 50},),
+        ]
+
+        if vector_type == 'pyarrow':
+            self.cur.execute(
+                f'select json_object_{vector_type}(x, y) as res '
+                'from json_data_with_nulls order by id',
+            )
+
+            desc = self.cur.description
+            assert len(desc) == 1
+            assert desc[0].name == 'res'
+            assert desc[0].type_code == ft.JSON
+            assert desc[0].null_ok is True
+
+            assert [tuple(x) for x in self.cur] == [
+                ({'x': None, 'y': 10},),
+                ({'x': 242, 'y': None},),
+                ({'x': 302, 'y': 50},),
+                ({'x': None, 'y': None},),
+            ]
+
+        else:
+            with self.assertRaises(self.conn.OperationalError):
+                self.cur.execute(
+                    f'select json_object_{vector_type}(x, y) as res '
+                    'from json_data_with_nulls order by id',
+                )
+
+        # Masks are not used with Python lists
+        if vector_type in ['list', 'nonvector']:
+            return
+
+        self.cur.execute(
+            f'select json_object_{vector_type}_masked(x, y) as res '
+            'from json_data order by id',
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 1
+        assert desc[0].name == 'res'
+        assert desc[0].type_code == ft.JSON
+        assert desc[0].null_ok is True
+
+        assert [tuple(x) for x in self.cur] == [
+            ({'x': 202, 'y': 10},),
+            ({'x': 242, 'y': 105},),
+            ({'x': 302, 'y': 50},),
+        ]
+
+        self.cur.execute(
+            f'select json_object_{vector_type}_masked(x, y) as res '
+            'from json_data_with_nulls order by id',
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 1
+        assert desc[0].name == 'res'
+        assert desc[0].type_code == ft.JSON
+        assert desc[0].null_ok is True
+
+        if vector_type == 'pyarrow':
+            assert [tuple(x) for x in self.cur] == [
+                ({'x': None, 'y': 10},),
+                ({'x': 242, 'y': None},),
+                ({'x': 302, 'y': 50},),
+                (None,),
+            ]
+        else:
+            assert [tuple(x) for x in self.cur] == [
+                ({'x': 0, 'y': 10},),
+                ({'x': 242, 'y': None},),
+                ({'x': 302, 'y': 50},),
+                (None,),
+            ]
+
+    def test_json_object_numpy(self):
+        self._test_json_object_vector('numpy')
+
+    def test_json_object_pandas(self):
+        self._test_json_object_vector('pandas')
+
+    def test_json_object_polars(self):
+        self._test_json_object_vector('polars')
+
+    def test_json_object_pyarrow(self):
+        self._test_json_object_vector('pyarrow')
+
+    def test_json_object_list(self):
+        self._test_json_object_vector('list')
+
+    def test_json_object_nonvector(self):
+        self._test_json_object_vector('nonvector')
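One Python subtlety the `null_ok` checks have to respect: comparisons bind tighter than conditional expressions, so without parentheses `a is False if cond else True` parses as `(a is False) if cond else True`, and the `else` branch asserts the constant `True` instead of anything about `a`. The parenthesized form used in these tests compares against the intended value. A two-line demonstration:

    null_ok, vector_type = False, 'pyarrow'
    # Unparenthesized: the else-branch evaluates to the constant True (vacuous).
    print(null_ok is False if vector_type != 'pyarrow' else True)    # True
    # Parenthesized: the intended expected value is actually compared.
    print(null_ok is (False if vector_type != 'pyarrow' else True))  # False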
+
+    def _test_json_list_vector(self, vector_type: str):
+        self.cur.execute(
+            f'select json_list_{vector_type}(x, y) as res '
+            'from json_list_data order by id',
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 1
+        assert desc[0].name == 'res'
+        assert desc[0].type_code == ft.JSON
+        assert desc[0].null_ok is (False if vector_type != 'pyarrow' else True)
+
+        assert [tuple(x) for x in self.cur] == [
+            ({'x': 202, 'y': 'foo'},),
+            ({'x': 242, 'y': 'foo'},),
+            ({'x': 302, 'y': 'foo'},),
+        ]
+
+        # Pyarrow supports nulls, but the return object should not be null
+        if vector_type == 'pyarrow':
+            self.cur.execute(
+                f'select json_list_{vector_type}(x, y) as res '
+                'from json_list_data_with_nulls order by id',
+            )
+
+            desc = self.cur.description
+            assert len(desc) == 1
+            assert desc[0].name == 'res'
+            assert desc[0].type_code == ft.JSON
+            assert desc[0].null_ok is True
+
+            assert [tuple(x) for x in self.cur] == [
+                ({'x': None, 'y': 'foo'},),
+                ({'x': 242, 'y': None},),
+                ({'x': 302, 'y': 'foo'},),
+                ({'x': None, 'y': None},),
+            ]
+
+        else:
+            with self.assertRaises(self.conn.OperationalError):
+                self.cur.execute(
+                    f'select json_list_{vector_type}(x, y) as res '
+                    'from json_list_data_with_nulls order by id',
+                )
+
+        # Masks are not used with Python lists
+        if vector_type in ['list', 'nonvector']:
+            return
+
+        self.cur.execute(
+            f'select json_list_{vector_type}_masked(x, y) as res '
+            'from json_list_data order by id',
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 1
+        assert desc[0].name == 'res'
+        assert desc[0].type_code == ft.JSON
+        assert desc[0].null_ok is True
+
+        assert [tuple(x) for x in self.cur] == [
+            ({'x': 202, 'y': 'foo'},),
+            ({'x': 242, 'y': 'foo'},),
+            ({'x': 302, 'y': 'foo'},),
+        ]
+
+        self.cur.execute(
+            f'select json_list_{vector_type}_masked(x, y) as res '
+            'from json_list_data_with_nulls order by id',
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 1
+        assert desc[0].name == 'res'
+        assert desc[0].type_code == ft.JSON
+        assert desc[0].null_ok is True
+
+        if vector_type == 'pyarrow':
+            assert [tuple(x) for x in self.cur] == [
+                ({'x': None, 'y': 'foo'},),
+                ({'x': 242, 'y': None},),
+                ({'x': 302, 'y': 'foo'},),
+                (None,),
+            ]
+        else:
+            assert [tuple(x) for x in self.cur] == [
+                ({'x': 0, 'y': 'foo'},),
+                ({'x': 242, 'y': None},),
+                ({'x': 302, 'y': 'foo'},),
+                (None,),
+            ]
+
+    def test_json_list_numpy(self):
+        self._test_json_list_vector('numpy')
+
+    def test_json_list_pandas(self):
+        self._test_json_list_vector('pandas')
+
+    def test_json_list_polars(self):
+        self._test_json_list_vector('polars')
+
+    def test_json_list_pyarrow(self):
+        self._test_json_list_vector('pyarrow')
+
+    def test_json_list_list(self):
+        self._test_json_list_vector('list')
+
+    def test_json_list_nonvector(self):
+        self._test_json_list_vector('nonvector')
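The TVF tests that follow exercise the row expansion performed by the `json_object_*_tvf` UDFs: one scalar/JSON input pair fans out into five output rows. Calling the Python function from the earlier hunk directly shows the shape the server-side assertions check; this sketch assumes the `@udf`-decorated function remains directly callable in Python:

    # Direct call mirroring json_object_list_tvf from the UDF module above.
    rows = json_object_list_tvf([10], [{'foo': 'bar', 'baz': 2.75}])
    # Expected contents, matching the `select * ... order by a` assertions below:
    #   (0,  {'x': 0,  'y': 'bar'})
    #   (10, {'x': 10, 'y': 'bar'})
    #   (20, {'x': 20, 'y': 'bar'})
    #   (30, {'x': 30, 'y': 'bar'})
    #   (40, {'x': 40, 'y': 'bar'})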
+
+    def _test_json_object_vector_tvf(self, vector_type: str):
+
+        int_data = 10
+        json_data = '{"foo": "bar", "baz": 2.75}'
+
+        self.cur.execute(
+            f'select * from json_object_{vector_type}_tvf(%s, %s) order by a',
+            (int_data, json_data),
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 2
+        assert desc[0].name == 'a'
+        assert desc[1].name == 'b'
+        assert desc[0].type_code in [ft.LONGLONG, ft.LONG]
+        assert desc[1].type_code == ft.JSON
+        assert desc[0].null_ok is (False if vector_type != 'pyarrow' else True)
+        assert desc[1].null_ok is (False if vector_type != 'pyarrow' else True)
+
+        assert [tuple(x) for x in self.cur] == [
+            (0, {'x': 0, 'y': 'bar'}),
+            (10, {'x': 10, 'y': 'bar'}),
+            (20, {'x': 20, 'y': 'bar'}),
+            (30, {'x': 30, 'y': 'bar'}),
+            (40, {'x': 40, 'y': 'bar'}),
+        ]
+
+        if vector_type == 'pyarrow':
+            self.cur.execute(
+                f'select * from json_object_{vector_type}_tvf(10, NULL)',
+            )
+
+            desc = self.cur.description
+            assert len(desc) == 2
+            assert desc[0].name == 'a'
+            assert desc[1].name == 'b'
+            assert desc[0].type_code == ft.LONG
+            assert desc[1].type_code == ft.JSON
+            assert desc[0].null_ok is True
+            assert desc[1].null_ok is True
+
+            assert [tuple(x) for x in self.cur] == [
+                (0, {'x': 0, 'y': None}),
+                (10, {'x': 10, 'y': None}),
+                (20, {'x': 20, 'y': None}),
+                (30, {'x': 30, 'y': None}),
+                (40, {'x': 40, 'y': None}),
+            ]
+
+        else:
+            with self.assertRaises(self.conn.OperationalError):
+                self.cur.execute(
+                    f'select * from json_object_{vector_type}_tvf(10, NULL)',
+                )
+
+        # Masks are not used with Python lists
+        if vector_type in ['list', 'nonvector']:
+            return
+
+        self.cur.execute(
+            f'select * from json_object_{vector_type}_tvf_masked(%s, %s) order by a',
+            (int_data, json_data),
+        )
+
+        desc = self.cur.description
+        assert len(desc) == 2
+        assert desc[0].name == 'a'
+        assert desc[1].name == 'b'
+        assert desc[0].type_code == ft.LONG
+        assert desc[1].type_code == ft.JSON
+        assert desc[0].null_ok is True
+        assert desc[1].null_ok is True
+
+        assert [tuple(x) for x in self.cur] == [
+            (None, {'x': 20, 'y': 'bar'}),
+            (0, {'x': 0, 'y': 'bar'}),
+            (10, {'x': 10, 'y': 'bar'}),
+            (30, {'x': 30, 'y': 'bar'}),
+            (40, None),
+        ]
+
+    def test_json_object_numpy_tvf(self):
+        self._test_json_object_vector_tvf('numpy')
+
+    def test_json_object_pandas_tvf(self):
+        self._test_json_object_vector_tvf('pandas')
+
+    def test_json_object_polars_tvf(self):
+        self._test_json_object_vector_tvf('polars')
+
+    def test_json_object_pyarrow_tvf(self):
+        self._test_json_object_vector_tvf('pyarrow')
+
+    def test_json_object_list_tvf(self):
+        self._test_json_object_vector_tvf('list')
+
+    def test_json_object_nonvector_tvf(self):
+        self._test_json_object_vector_tvf('nonvector')
diff --git a/singlestoredb/tests/test_ext_func_data.py b/singlestoredb/tests/test_ext_func_data.py
index 0acae21f..4a3832f3 100755
--- a/singlestoredb/tests/test_ext_func_data.py
+++ b/singlestoredb/tests/test_ext_func_data.py
@@ -31,20 +31,20 @@ BINARY = -254
 
 col_spec = [
-    ('tiny', TINYINT),
-    ('unsigned_tiny', UNSIGNED_TINYINT),
-    ('short', SMALLINT),
-    ('unsigned_short', UNSIGNED_SMALLINT),
-    ('long', INT),
-    ('unsigned_long', UNSIGNED_INT),
-    ('float', FLOAT),
-    ('double', DOUBLE),
-    ('longlong', BIGINT),
-    ('unsigned_longlong', UNSIGNED_BIGINT),
-    ('int24', MEDIUMINT),
-    ('unsigned_int24', UNSIGNED_MEDIUMINT),
-    ('string', STRING),
-    ('binary', BINARY),
+    ('tiny', TINYINT, None),
+    ('unsigned_tiny', UNSIGNED_TINYINT, None),
+    ('short', SMALLINT, None),
+    ('unsigned_short', UNSIGNED_SMALLINT, None),
+    ('long', INT, None),
+    ('unsigned_long', UNSIGNED_INT, None),
+    ('float', FLOAT, None),
+    ('double', DOUBLE, None),
+    ('longlong', BIGINT, None),
+    ('unsigned_longlong', UNSIGNED_BIGINT, None),
+    ('int24', MEDIUMINT, None),
+    ('unsigned_int24', UNSIGNED_MEDIUMINT, None),
+    ('string', STRING, None),
+    ('binary', BINARY, None),
 ]
 
 col_types = [x[1] for x in col_spec]
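The column specs above are now `(name, type_code, transformer)` triples rather than `(name, type_code)` pairs: the third slot carries an optional per-value callable that the loaders and dumpers apply, with `None` meaning pass-through. A hedged sketch of what a non-`None` slot could look like; `json.loads` on a string column is an illustration, not a spec taken from this suite:

    import json

    # 3-tuple column specs: (column name, type code, optional transformer).
    plain_spec = ('tiny', TINYINT, None)            # None: values pass through unchanged
    decoded_spec = ('payload', STRING, json.loads)  # callable: applied to each value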
@@ -268,7 +268,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_numpy_accel(self):
         dump_res = rowdat_1._dump_numpy_accel(
-            col_types, numpy_row_ids, numpy_data,
+            col_spec, numpy_row_ids, numpy_data,
         )
 
         load_res = rowdat_1._load_numpy_accel(col_spec, dump_res)
@@ -293,7 +293,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_numpy(self):
         dump_res = rowdat_1._dump_numpy(
-            col_types, numpy_row_ids, numpy_data,
+            col_spec, numpy_row_ids, numpy_data,
         )
 
         load_res = rowdat_1._load_numpy(col_spec, dump_res)
@@ -386,7 +386,7 @@ def test_numpy_accel_limits(self, name, dtype, data, res):
             # Accelerated
             with self.assertRaises(res, msg=f'Expected {res} for {data} in {dtype}'):
                 rowdat_1._dump_numpy_accel(
-                    [dtype], numpy_row_ids, [(arr, None)],
+                    [('x', dtype, None)], numpy_row_ids, [(arr, None)],
                 )
 
             # Pure Python
@@ -395,23 +395,23 @@ def test_numpy_accel_limits(self, name, dtype, data, res):
             else:
                 with self.assertRaises(res, msg=f'Expected {res} for {data} in {dtype}'):
                     rowdat_1._dump_numpy(
-                        [dtype], numpy_row_ids, [(arr, None)],
+                        [('x', dtype, None)], numpy_row_ids, [(arr, None)],
                     )
 
         else:
             # Accelerated
             dump_res = rowdat_1._dump_numpy_accel(
-                [dtype], numpy_row_ids, [(arr, None)],
+                [('x', dtype, None)], numpy_row_ids, [(arr, None)],
             )
-            load_res = rowdat_1._load_numpy_accel([('x', dtype)], dump_res)
+            load_res = rowdat_1._load_numpy_accel([('x', dtype, None)], dump_res)
 
             assert load_res[1][0][0] == res, \
                 f'Expected {res} for {data}, but got {load_res[1][0][0]} in {dtype}'
 
             # Pure Python
             dump_res = rowdat_1._dump_numpy(
-                [dtype], numpy_row_ids, [(arr, None)],
+                [('x', dtype, None)], numpy_row_ids, [(arr, None)],
             )
-            load_res = rowdat_1._load_numpy([('x', dtype)], dump_res)
+            load_res = rowdat_1._load_numpy([('x', dtype, None)], dump_res)
 
             assert load_res[1][0][0] == res, \
                 f'Expected {res} for {data}, but got {load_res[1][0][0]} in {dtype}'
@@ -787,9 +787,9 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
 
         # Accelerated
         dump_res = rowdat_1._dump_numpy_accel(
-            [dtype], numpy_row_ids, [(data, None)],
+            [('x', dtype, None)], numpy_row_ids, [(data, None)],
         )
-        load_res = rowdat_1._load_numpy_accel([('x', dtype)], dump_res)
+        load_res = rowdat_1._load_numpy_accel([('x', dtype, None)], dump_res)
 
         if name == 'double from float32':
             assert load_res[1][0][0].dtype is res.dtype
@@ -799,9 +799,9 @@ def test_numpy_accel_casts(self, name, dtype, data, res):
 
         # Pure Python
         dump_res = rowdat_1._dump_numpy(
-            [dtype], numpy_row_ids, [(data, None)],
+            [('x', dtype, None)], numpy_row_ids, [(data, None)],
         )
-        load_res = rowdat_1._load_numpy([('x', dtype)], dump_res)
+        load_res = rowdat_1._load_numpy([('x', dtype, None)], dump_res)
 
         if name == 'double from float32':
             assert load_res[1][0][0].dtype is res.dtype
@@ -811,7 +811,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_python(self):
         dump_res = rowdat_1._dump(
-            col_types, py_row_ids, py_col_data,
+            col_spec, py_row_ids, py_col_data,
         )
 
         load_res = rowdat_1._load(col_spec, dump_res)
@@ -823,7 +823,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_python_accel(self):
         dump_res = rowdat_1._dump_accel(
-            col_types, py_row_ids, py_col_data,
+            col_spec, py_row_ids, py_col_data,
         )
 
         load_res = rowdat_1._load_accel(col_spec, dump_res)
@@ -835,7 +835,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_polars(self):
         dump_res = rowdat_1._dump_polars(
-            col_types, polars_row_ids, polars_data,
+            col_spec, polars_row_ids, polars_data,
        )
 
         load_res = rowdat_1._load_polars(col_spec, dump_res)
@@ -860,7 +860,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_polars_accel(self):
         dump_res = rowdat_1._dump_polars_accel(
-            col_types, polars_row_ids, polars_data,
+            col_spec, polars_row_ids, polars_data,
         )
 
         load_res = rowdat_1._load_polars_accel(col_spec, dump_res)
@@ -885,7 +885,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_pandas(self):
         dump_res = rowdat_1._dump_pandas(
-            col_types, pandas_row_ids, pandas_data,
+            col_spec, pandas_row_ids, pandas_data,
         )
 
         load_res = rowdat_1._load_pandas(col_spec, dump_res)
@@ -910,7 +910,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_pandas_accel(self):
         dump_res = rowdat_1._dump_pandas_accel(
-            col_types, pandas_row_ids, pandas_data,
+            col_spec, pandas_row_ids, pandas_data,
         )
 
         load_res = rowdat_1._load_pandas_accel(col_spec, dump_res)
@@ -935,7 +935,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_pyarrow(self):
         dump_res = rowdat_1._dump_arrow(
-            col_types, pyarrow_row_ids, pyarrow_data,
+            col_spec, pyarrow_row_ids, pyarrow_data,
         )
 
         load_res = rowdat_1._load_arrow(col_spec, dump_res)
@@ -960,7 +960,7 @@ class TestRowdat1(unittest.TestCase):
 
     def test_pyarrow_accel(self):
         dump_res = rowdat_1._dump_arrow_accel(
-            col_types, pyarrow_row_ids, pyarrow_data,
+            col_spec, pyarrow_row_ids, pyarrow_data,
         )
 
         load_res = rowdat_1._load_arrow_accel(col_spec, dump_res)
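All of these serializer tests follow the same round-trip pattern, now keyed on the three-element `col_spec`: dump the column data to the wire format, load it back with the same spec, and compare. A condensed sketch of that pattern using the pure-Python pair; the unpacking of the load result into row ids and columns is an assumption inferred from the `load_res[1][0][0]` indexing in these tests, not a documented return signature:

    # Round-trip sketch using the module-level fixtures from this file.
    dump_res = rowdat_1._dump(col_spec, py_row_ids, py_col_data)
    row_ids, cols = rowdat_1._load(col_spec, dump_res)
    # Row ids should survive the trip; cols holds one entry per column spec.
    assert list(row_ids) == list(py_row_ids)
    assert len(cols) == len(col_spec)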
@@ -988,7 +988,7 @@ class TestJSON(unittest.TestCase):
 
     def test_numpy(self):
         dump_res = jsonx.dump_numpy(
-            col_types, numpy_row_ids, numpy_data,
+            col_spec, numpy_row_ids, numpy_data,
         )
 
         import pprint
         pprint.pprint(json.loads(dump_res))
@@ -1015,7 +1015,7 @@ class TestJSON(unittest.TestCase):
 
     def test_python(self):
         dump_res = jsonx.dump(
-            col_types, py_row_ids, py_col_data,
+            col_spec, py_row_ids, py_col_data,
         )
 
         load_res = jsonx.load(col_spec, dump_res)
@@ -1027,7 +1027,7 @@ class TestJSON(unittest.TestCase):
 
     def test_polars(self):
         dump_res = jsonx.dump_polars(
-            col_types, polars_row_ids, polars_data,
+            col_spec, polars_row_ids, polars_data,
         )
 
         load_res = jsonx.load_polars(col_spec, dump_res)
@@ -1052,7 +1052,7 @@ class TestJSON(unittest.TestCase):
 
     def test_pandas(self):
         dump_res = rowdat_1._dump_pandas(
-            col_types, pandas_row_ids, pandas_data,
+            col_spec, pandas_row_ids, pandas_data,
         )
 
         load_res = rowdat_1._load_pandas(col_spec, dump_res)
@@ -1077,7 +1077,7 @@ class TestJSON(unittest.TestCase):
 
     def test_pyarrow(self):
         dump_res = rowdat_1._dump_arrow(
-            col_types, pyarrow_row_ids, pyarrow_data,
+            col_spec, pyarrow_row_ids, pyarrow_data,
         )
 
         load_res = rowdat_1._load_arrow(col_spec, dump_res)
diff --git a/singlestoredb/tests/test_udf.py b/singlestoredb/tests/test_udf.py
index ebf0f60c..443aa05d 100755
--- a/singlestoredb/tests/test_udf.py
+++ b/singlestoredb/tests/test_udf.py
@@ -14,8 +14,8 @@
 import numpy as np
 import pydantic
 
-from ..functions import dtypes as dt
 from ..functions import signature as sig
+from ..functions import sql_types as dt
 from ..functions import Table
 from ..functions import udf
@@ -735,3 +735,62 @@ def test_dtypes(self):
 #        assert dt.VECTOR(8, dt.I16, nullable=False) == 'VECTOR(8, I16) NOT NULL'
 #        assert dt.VECTOR(8, dt.I32, nullable=False) == 'VECTOR(8, I32) NOT NULL'
 #        assert dt.VECTOR(8, dt.I64, nullable=False) == 'VECTOR(8, I64) NOT NULL'
+
+    def test_json_types(self):
+        """Test JSON type handling for parameters and returns."""
+
+        # Test JSON type aliases
+        from ..functions.typing import JSON
+
+        def alias_json_func(data: JSON) -> JSON:
+            return data
+
+        sql = to_sql(alias_json_func)
+        self.assertIn('JSON NOT NULL', sql)
+
+        def alias_vector_json_func(data: List[JSON]) -> List[JSON]:
+            return data
+
+        sql = to_sql(alias_vector_json_func)
+        self.assertIn('JSON NOT NULL', sql)
+
+        # Test typing package JSON aliases
+        from ..functions.typing import numpy as npt
+        from ..functions.typing import pandas as pdt
+        from ..functions.typing import polars as plt
+        from ..functions.typing import pyarrow as pat
+
+        # Test numpy JSONArray
+        def numpy_json_func(data: npt.JSONArray) -> npt.JSONArray:
+            return data
+
+        sql = to_sql(numpy_json_func)
+        self.assertIn('JSON NOT NULL', sql)
+        self.assertIn('RETURNS JSON NOT NULL', sql)
+
+        # Test pandas JSONSeries
+        def pandas_json_func(data: pdt.JSONSeries) -> pdt.StringSeries:
+            import pandas as pd
+            return pd.Series(['result'])
+
+        sql = to_sql(pandas_json_func)
+        self.assertIn('JSON NOT NULL', sql)
+        self.assertIn('RETURNS TEXT NOT NULL', sql)
+
+        # Test polars JSONSeries
+        def polars_json_func(data: plt.JSONSeries) -> plt.Int32Series:
+            import polars as pl
+            return pl.Series([1], dtype=pl.Int32)
+
+        sql = to_sql(polars_json_func)
+        self.assertIn('JSON NOT NULL', sql)
+        self.assertIn('RETURNS INT NOT NULL', sql)
+
+        # Test pyarrow JSONArray
+        def arrow_json_func(data: pat.JSONArray) -> pat.JSONArray:
+            import pyarrow as pa
+            return pa.array([{'result': 'success'}])
+
+        sql = to_sql(arrow_json_func)
+        self.assertIn('JSON NULL', sql)
+        self.assertIn('RETURNS JSON NULL', sql)
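To close the loop on what `test_json_types` checks: annotating a parameter or return with the `JSON` alias surfaces as a `JSON NOT NULL` type in the generated SQL, while the pyarrow aliases are nullable, hence `JSON NULL`. A minimal, hedged sketch of the pattern, reusing the suite's `to_sql` helper and mirroring the vector UDFs from the earlier hunks:

    from typing import List

    from singlestoredb.functions import udf
    from singlestoredb.functions.typing import JSON

    @udf
    def double_foo(x: List[int], y: List[JSON]) -> List[JSON]:
        # Same shape as the vector UDFs above: one JSON object out per row in.
        return [dict(x=a * 2, y=b['foo']) for a, b in zip(x, y)]

    # Per the assertions above, to_sql(double_foo) is expected to declare both
    # the parameter and the RETURNS clause as 'JSON NOT NULL'.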