Skip to content

Commit 5953d79

Browse files
committed
Vector: Add support for CrateDB's FLOAT_VECTOR data type: FloatVector
https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
1 parent 07bba7b commit 5953d79

File tree

6 files changed

+180
-2
lines changed

6 files changed

+180
-2
lines changed

CHANGES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33

44
## Unreleased
55

6+
- Added support for CrateDB's [FLOAT_VECTOR] data type. For SQLAlchemy
7+
column definitions, you can use it like `FloatVector(dimensions=1024)`.
8+
9+
[FLOAT_VECTOR]: https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
10+
611

712
## 2023/09/29 0.34.0
813

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ dependencies = [
9191
"verlib2==0.2",
9292
]
9393
[project.optional-dependencies]
94+
all = [
95+
"sqlalchemy-cratedb[vector]",
96+
]
9497
develop = [
9598
"black<24",
9699
"mypy<1.9",
@@ -114,6 +117,9 @@ test = [
114117
"pytest-cov<5",
115118
"pytest-mock<4",
116119
]
120+
vector = [
121+
"numpy",
122+
]
117123
[project.urls]
118124
changelog = "https://github.com/crate-workbench/sqlalchemy-cratedb/blob/main/CHANGES.md"
119125
documentation = "https://github.com/crate-workbench/sqlalchemy-cratedb"

src/sqlalchemy_cratedb/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@ def visit_ARRAY(self, type_, **kw):
238238
def visit_OBJECT(self, type_, **kw):
239239
return "OBJECT"
240240

241+
def visit_FLOAT_VECTOR(self, type_, **kw):
242+
dimensions = type_.dimensions
243+
if dimensions is None:
244+
raise ValueError("FloatVector must be initialized with dimension size")
245+
return f"FLOAT_VECTOR({dimensions})"
246+
241247

242248
class CrateCompiler(compiler.SQLCompiler):
243249

src/sqlalchemy_cratedb/dialect.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from crate.client.exceptions import TimezoneUnawareException
3535
from .sa_version import SA_VERSION, SA_1_4, SA_2_0
36-
from .type import ObjectArray, ObjectType
36+
from .type import FloatVector, ObjectArray, ObjectType
3737

3838
TYPES_MAP = {
3939
"boolean": sqltypes.Boolean,
@@ -51,7 +51,8 @@
5151
"float": sqltypes.Float,
5252
"real": sqltypes.Float,
5353
"string": sqltypes.String,
54-
"text": sqltypes.String
54+
"text": sqltypes.String,
55+
"float_vector": FloatVector,
5556
}
5657
try:
5758
# SQLAlchemy >= 1.1
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .array import ObjectArray
22
from .geo import Geopoint, Geoshape
33
from .object import ObjectType
4+
from .vector import FloatVector
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""
2+
## About
3+
SQLAlchemy data type implementation for CrateDB's `FLOAT_VECTOR` type.
4+
5+
## References
6+
- https://crate.io/docs/crate/reference/en/master/general/ddl/data-types.html#float-vector
7+
- https://crate.io/docs/crate/reference/en/master/general/builtins/scalar-functions.html#scalar-knn-match
8+
9+
## Details
10+
The implementation is based on SQLAlchemy's `TypeDecorator`, and also
11+
offers compiler support.
12+
13+
## Notes
14+
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
15+
-- https://github.com/crate/crate/blob/5.5.1/server/src/main/java/io/crate/types/FloatVectorType.java#L55
16+
17+
On the other hand, pgvector use a comparator to apply different similarity
18+
functions as operators, see `pgvector.sqlalchemy.Vector.comparator_factory`.
19+
20+
<->: l2/euclidean_distance
21+
<#>: max_inner_product
22+
<=>: cosine_distance
23+
24+
## Backlog
25+
- The type implementation might want to be accompanied by corresponding support
26+
for the `KNN_MATCH` function, similar to what the dialect already offers for
27+
fulltext search through its `Match` predicate.
28+
29+
## Origin
30+
This module is based on the corresponding pgvector implementation
31+
by Andrew Kane. Thank you.
32+
33+
The MIT License (MIT)
34+
Copyright (c) 2021-2023 Andrew Kane
35+
https://github.com/pgvector/pgvector-python
36+
"""
37+
import typing as t
38+
39+
if t.TYPE_CHECKING:
40+
import numpy.typing as npt
41+
42+
import sqlalchemy as sa
43+
44+
__all__ = ["FloatVector"]
45+
46+
47+
def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
48+
import numpy as np
49+
50+
# from `pgvector.utils`
51+
# could be ndarray if already cast by lower-level driver
52+
if value is None or isinstance(value, np.ndarray):
53+
return value
54+
55+
return np.array(value, dtype=np.float32)
56+
57+
58+
def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
59+
import numpy as np
60+
61+
# from `pgvector.utils`
62+
if value is None:
63+
return value
64+
65+
if isinstance(value, np.ndarray):
66+
if value.ndim != 1:
67+
raise ValueError("expected ndim to be 1")
68+
69+
if not np.issubdtype(value.dtype, np.integer) and not np.issubdtype(value.dtype, np.floating):
70+
raise ValueError("dtype must be numeric")
71+
72+
value = value.tolist()
73+
74+
if dim is not None and len(value) != dim:
75+
raise ValueError("expected %d dimensions, not %d" % (dim, len(value)))
76+
77+
return value
78+
79+
80+
class FloatVector(sa.TypeDecorator[t.Sequence[float]]):
81+
82+
"""
83+
An improved implementation of the `FloatVector` data type for CrateDB,
84+
compared to the previous implementation on behalf of the LangChain adapter.
85+
86+
The previous implementation, based on SQLAlchemy's `UserDefinedType`, didn't
87+
respect the `python_type` property on backward/reverse resolution of types.
88+
This was observed on Meltano's database connector machinery doing a
89+
type cast, which led to a `NotImplementedError`.
90+
91+
typing.cast(type, sql_type.python_type) => NotImplementedError
92+
93+
The `UserDefinedType` approach is easier to implement, because it doesn't
94+
need compiler support.
95+
96+
To get full SQLAlchemy type support, including support for forward- and
97+
backward resolution / type casting, the custom data type should derive
98+
from SQLAlchemy's `TypeEngine` base class instead.
99+
100+
When deriving from `TypeEngine`, you will need to set the `__visit_name__`
101+
attribute, and add a corresponding visitor method to the `CrateTypeCompiler`,
102+
in this case, `visit_FLOAT_VECTOR`.
103+
104+
Now, rendering a DDL succeeds. However, when reflecting the DDL schema back,
105+
it doesn't work until you will establish a corresponding reverse type mapping.
106+
107+
By invoking `SELECT DISTINCT(data_type) FROM information_schema.columns;`,
108+
you will find out that the internal type name is `float_vector`, so you
109+
announce it to the dialect using `TYPES_MAP["float_vector"] = FloatVector`.
110+
111+
Still not there: `NotImplementedError: Default TypeEngine.as_generic() heuristic
112+
method was unsuccessful for target_cratedb.sqlalchemy.vector.FloatVector. A
113+
custom as_generic() method must be implemented for this type class.`
114+
115+
So, as it signals that the type implementation also needs an `as_generic`
116+
property, let's supply one, returning `sqltypes.ARRAY`.
117+
118+
It looks like, in exchange to those improvements, the `get_col_spec`
119+
method is not needed any longer.
120+
121+
TODO: Would it be a good idea to derive from SQLAlchemy's
122+
`ARRAY` right away, to get a few of the features without
123+
the need to redefine them?
124+
125+
Please note the outcome of this analysis and the corresponding implementation
126+
has been derived from empirical observations, and from the feeling that we also
127+
lack corresponding support on the other special data types of CrateDB (ARRAY and
128+
OBJECT) within the SQLAlchemy dialect, i.e. "that something must be wrong or
129+
incomplete". In this spirit, it is advisable to review and improve their
130+
implementations correspondingly.
131+
"""
132+
133+
cache_ok = False
134+
135+
__visit_name__ = "FLOAT_VECTOR"
136+
137+
_is_array = True
138+
139+
zero_indexes = False
140+
141+
impl = sa.ARRAY
142+
143+
def __init__(self, dimensions: int = None):
144+
super().__init__(sa.FLOAT, dimensions=dimensions)
145+
146+
def as_generic(self):
147+
return sa.ARRAY
148+
149+
def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
150+
def process(value: t.Iterable) -> t.Optional[t.List]:
151+
return to_db(value, self.dimensions)
152+
153+
return process
154+
155+
def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
156+
def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
157+
return from_db(value)
158+
159+
return process

0 commit comments

Comments
 (0)