|
18 | 18 | Union, |
19 | 19 | ) |
20 | 20 |
|
| 21 | +from redisvl.query.query import VectorQuery |
21 | 22 | from redisvl.redis.utils import convert_bytes, make_dict |
22 | 23 | from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper |
23 | 24 |
|
|
34 | 35 | from redis.commands.search.indexDefinition import IndexDefinition |
35 | 36 |
|
36 | 37 | from redisvl.exceptions import ( |
| 38 | + QueryValidationError, |
37 | 39 | RedisModuleVersionError, |
38 | 40 | RedisSearchError, |
39 | 41 | RedisVLError, |
|
46 | 48 | BaseVectorQuery, |
47 | 49 | CountQuery, |
48 | 50 | FilterQuery, |
49 | | - HybridQuery, |
50 | 51 | ) |
51 | 52 | from redisvl.query.filter import FilterExpression |
52 | 53 | from redisvl.redis.connection import ( |
53 | 54 | RedisConnectionFactory, |
54 | 55 | convert_index_info_to_schema, |
55 | 56 | ) |
56 | | -from redisvl.redis.utils import convert_bytes |
57 | 57 | from redisvl.schema import IndexSchema, StorageType |
58 | | -from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric |
| 58 | +from redisvl.schema.fields import ( |
| 59 | + VECTOR_NORM_MAP, |
| 60 | + VectorDistanceMetric, |
| 61 | + VectorIndexAlgorithm, |
| 62 | +) |
59 | 63 | from redisvl.utils.log import get_logger |
60 | 64 |
|
61 | 65 | logger = get_logger(__name__) |
@@ -194,6 +198,15 @@ def _storage(self) -> BaseStorage: |
194 | 198 | index_schema=self.schema |
195 | 199 | ) |
196 | 200 |
|
| 201 | + def _validate_query(self, query: BaseQuery) -> None: |
| 202 | + """Validate a query.""" |
| 203 | + if isinstance(query, VectorQuery): |
| 204 | + field = self.schema.fields[query._vector_field_name] |
| 205 | + if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore |
| 206 | + raise QueryValidationError( |
| 207 | + "Vector field using 'flat' algorithm does not support EF_RUNTIME query parameter." |
| 208 | + ) |
| 209 | + |
197 | 210 | @property |
198 | 211 | def name(self) -> str: |
199 | 212 | """The name of the Redis search index.""" |
@@ -592,6 +605,27 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int: |
592 | 605 | else: |
593 | 606 | return self._redis_client.delete(keys) # type: ignore |
594 | 607 |
|
| 608 | + def drop_documents(self, ids: Union[str, List[str]]) -> int: |
| 609 | + """Remove documents from the index by their document IDs. |
| 610 | +
|
| 611 | + This method converts document IDs to Redis keys automatically by applying |
| 612 | + the index's key prefix and separator configuration. |
| 613 | +
|
| 614 | + Args: |
| 615 | + ids (Union[str, List[str]]): The document ID or IDs to remove from the index. |
| 616 | +
|
| 617 | + Returns: |
| 618 | + int: Count of documents deleted from Redis. |
| 619 | + """ |
| 620 | + if isinstance(ids, list): |
| 621 | + if not ids: |
| 622 | + return 0 |
| 623 | + keys = [self.key(id) for id in ids] |
| 624 | + return self._redis_client.delete(*keys) # type: ignore |
| 625 | + else: |
| 626 | + key = self.key(ids) |
| 627 | + return self._redis_client.delete(key) # type: ignore |
| 628 | + |
595 | 629 | def expire_keys( |
596 | 630 | self, keys: Union[str, List[str]], ttl: int |
597 | 631 | ) -> Union[int, List[int]]: |
@@ -816,6 +850,10 @@ def batch_query( |
816 | 850 |
|
817 | 851 | def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: |
818 | 852 | """Execute a query and process results.""" |
| 853 | + try: |
| 854 | + self._validate_query(query) |
| 855 | + except QueryValidationError as e: |
| 856 | + raise QueryValidationError(f"Invalid query: {str(e)}") from e |
819 | 857 | results = self.search(query.query, query_params=query.params) |
820 | 858 | return process_results(results, query=query, schema=self.schema) |
821 | 859 |
|
@@ -1236,6 +1274,28 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int: |
1236 | 1274 | else: |
1237 | 1275 | return await client.delete(keys) |
1238 | 1276 |
|
| 1277 | + async def drop_documents(self, ids: Union[str, List[str]]) -> int: |
| 1278 | + """Remove documents from the index by their document IDs. |
| 1279 | +
|
| 1280 | + This method converts document IDs to Redis keys automatically by applying |
| 1281 | + the index's key prefix and separator configuration. |
| 1282 | +
|
| 1283 | + Args: |
| 1284 | + ids (Union[str, List[str]]): The document ID or IDs to remove from the index. |
| 1285 | +
|
| 1286 | + Returns: |
| 1287 | + int: Count of documents deleted from Redis. |
| 1288 | + """ |
| 1289 | + client = await self._get_client() |
| 1290 | + if isinstance(ids, list): |
| 1291 | + if not ids: |
| 1292 | + return 0 |
| 1293 | + keys = [self.key(id) for id in ids] |
| 1294 | + return await client.delete(*keys) |
| 1295 | + else: |
| 1296 | + key = self.key(ids) |
| 1297 | + return await client.delete(key) |
| 1298 | + |
1239 | 1299 | async def expire_keys( |
1240 | 1300 | self, keys: Union[str, List[str]], ttl: int |
1241 | 1301 | ) -> Union[int, List[int]]: |
@@ -1356,9 +1416,10 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]: |
1356 | 1416 | async def _aggregate( |
1357 | 1417 | self, aggregation_query: AggregationQuery |
1358 | 1418 | ) -> List[Dict[str, Any]]: |
1359 | | - """Execute an aggretation query and processes the results.""" |
| 1419 | + """Execute an aggregation query and processes the results.""" |
1360 | 1420 | results = await self.aggregate( |
1361 | | - aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined] |
| 1421 | + aggregation_query, |
| 1422 | + query_params=aggregation_query.params, # type: ignore[attr-defined] |
1362 | 1423 | ) |
1363 | 1424 | return process_aggregate_results( |
1364 | 1425 | results, |
@@ -1486,6 +1547,10 @@ async def batch_query( |
1486 | 1547 |
|
1487 | 1548 | async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: |
1488 | 1549 | """Asynchronously execute a query and process results.""" |
| 1550 | + try: |
| 1551 | + self._validate_query(query) |
| 1552 | + except QueryValidationError as e: |
| 1553 | + raise QueryValidationError(f"Invalid query: {str(e)}") from e |
1489 | 1554 | results = await self.search(query.query, query_params=query.params) |
1490 | 1555 | return process_results(results, query=query, schema=self.schema) |
1491 | 1556 |
|
|
0 commit comments