|
28 | 28 |
|
29 | 29 | import pytest |
30 | 30 | import sqlalchemy as sa |
31 | | -from sqlalchemy.orm import Session |
| 31 | +from sqlalchemy.orm import Session, sessionmaker |
32 | 32 | from sqlalchemy.sql import select |
33 | 33 |
|
34 | | -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 |
35 | | -from sqlalchemy_cratedb.type import FloatVector |
| 34 | +try: |
| 35 | + from sqlalchemy.orm import declarative_base |
| 36 | +except ImportError: |
| 37 | + from sqlalchemy.ext.declarative import declarative_base |
36 | 38 |
|
37 | 39 | from crate.client.cursor import Cursor |
38 | 40 |
|
| 41 | +from sqlalchemy_cratedb import SA_VERSION, SA_1_4 |
| 42 | +from sqlalchemy_cratedb import FloatVector, knn_match |
39 | 43 | from sqlalchemy_cratedb.type.vector import from_db, to_db |
40 | 44 |
|
41 | 45 | fake_cursor = MagicMock(name="fake_cursor") |
@@ -102,6 +106,14 @@ def test_sql_select(self): |
102 | 106 | "SELECT testdrive.data FROM testdrive", select(self.table.c.data) |
103 | 107 | ) |
104 | 108 |
|
| 109 | + def test_sql_match(self): |
| 110 | + query = self.session.query(self.table.c.name) \ |
| 111 | + .filter(knn_match(self.table.c.data, [42.42, 43.43], 3)) |
| 112 | + self.assertSQL( |
| 113 | + "SELECT testdrive.name AS testdrive_name FROM testdrive WHERE KNN_MATCH(testdrive.data, ?, ?)", |
| 114 | + query |
| 115 | + ) |
| 116 | + |
105 | 117 |
|
106 | 118 | def test_from_db_success(): |
107 | 119 | """ |
@@ -201,3 +213,37 @@ def test_float_vector_as_generic(): |
201 | 213 | fv = FloatVector(3) |
202 | 214 | assert isinstance(fv.as_generic(), sa.ARRAY) |
203 | 215 | assert fv.python_type is list |
| 216 | + |
| 217 | + |
| 218 | +def test_float_vector_integration(): |
| 219 | + """ |
| 220 | + An integration test for `FLOAT_VECTOR` and `KNN_SEARCH`. |
| 221 | + """ |
| 222 | + np = pytest.importorskip("numpy") |
| 223 | + |
| 224 | + engine = sa.create_engine(f"crate://") |
| 225 | + session = sessionmaker(bind=engine)() |
| 226 | + Base = declarative_base() |
| 227 | + |
| 228 | + # Define DDL. |
| 229 | + class SearchIndex(Base): |
| 230 | + __tablename__ = 'search' |
| 231 | + name = sa.Column(sa.String, primary_key=True) |
| 232 | + embedding = sa.Column(FloatVector(3)) |
| 233 | + |
| 234 | + Base.metadata.drop_all(engine, checkfirst=True) |
| 235 | + Base.metadata.create_all(engine, checkfirst=True) |
| 236 | + |
| 237 | + # Insert record. |
| 238 | + foo_item = SearchIndex(name="foo", embedding=[42.42, 43.43, 44.44]) |
| 239 | + session.add(foo_item) |
| 240 | + session.commit() |
| 241 | + session.execute(sa.text("REFRESH TABLE search")) |
| 242 | + |
| 243 | + # Query record. |
| 244 | + query = session.query(SearchIndex.embedding) \ |
| 245 | + .filter(knn_match(SearchIndex.embedding, [42.42, 43.43, 41.41], 3)) |
| 246 | + result = query.first() |
| 247 | + |
| 248 | + # Compare outcome. |
| 249 | + assert np.array_equal(result.embedding, np.array([42.42, 43.43, 44.44], np.float32)) |
0 commit comments