|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": 97, |
| 15 | + "execution_count": null, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
21 | 21 | }, |
22 | 22 | { |
23 | 23 | "cell_type": "code", |
24 | | - "execution_count": 98, |
| 24 | + "execution_count": null, |
25 | 25 | "metadata": {}, |
26 | 26 | "outputs": [], |
27 | 27 | "source": [ |
|
31 | 31 | }, |
32 | 32 | { |
33 | 33 | "cell_type": "code", |
34 | | - "execution_count": 99, |
| 34 | + "execution_count": null, |
35 | 35 | "metadata": {}, |
36 | 36 | "outputs": [], |
37 | 37 | "source": [ |
|
42 | 42 | }, |
43 | 43 | { |
44 | 44 | "cell_type": "code", |
45 | | - "execution_count": 100, |
| 45 | + "execution_count": null, |
46 | 46 | "metadata": {}, |
47 | 47 | "outputs": [], |
48 | 48 | "source": [ |
|
52 | 52 | }, |
53 | 53 | { |
54 | 54 | "cell_type": "code", |
55 | | - "execution_count": 101, |
| 55 | + "execution_count": null, |
56 | 56 | "metadata": {}, |
57 | 57 | "outputs": [], |
58 | 58 | "source": [ |
|
73 | 73 | }, |
74 | 74 | { |
75 | 75 | "cell_type": "code", |
76 | | - "execution_count": 102, |
| 76 | + "execution_count": null, |
77 | 77 | "metadata": {}, |
78 | 78 | "outputs": [], |
79 | 79 | "source": [ |
|
146 | 146 | }, |
147 | 147 | { |
148 | 148 | "cell_type": "code", |
149 | | - "execution_count": 103, |
| 149 | + "execution_count": null, |
150 | 150 | "metadata": {}, |
151 | 151 | "outputs": [], |
152 | 152 | "source": [ |
|
265 | 265 | " .format(index_name=index_name_quoted, table_name=table_name_quoted, column_name=column_name_quoted, with_clause=with_clause)\n" |
266 | 266 | ] |
267 | 267 | }, |
| 268 | + { |
| 269 | + "attachments": {}, |
| 270 | + "cell_type": "markdown", |
| 271 | + "metadata": {}, |
| 272 | + "source": [ |
| 273 | + "# Query Params" |
| 274 | + ] |
| 275 | + }, |
| 276 | + { |
| 277 | + "cell_type": "code", |
| 278 | + "execution_count": null, |
| 279 | + "metadata": {}, |
| 280 | + "outputs": [], |
| 281 | + "source": [ |
| 282 | + "#| export\n", |
| 283 | + "\n", |
| 284 | + "class QueryParams:\n", |
| 285 | + "    \"\"\"PostgreSQL settings rendered as `SET LOCAL` statements to run ahead of a search query.\"\"\"\n", |
| 286 | + "\n", |
| 287 | + "    def __init__(self, params: Dict[str, Any]) -> None:\n", |
| 288 | + "        # Maps each setting name to the value it should be set to.\n", |
| 289 | + "        self.params = params\n", |
| 290 | + "\n", |
| 291 | + "    @staticmethod\n", |
| 292 | + "    def _render_value(value: Any) -> str:\n", |
| 293 | + "        # Numbers (bool included, as an int subclass) are emitted verbatim via str().\n", |
| 294 | + "        if isinstance(value, (int, float)):\n", |
| 295 | + "            return str(value)\n", |
| 296 | + "        # Any other value becomes a quoted SQL string literal with single quotes doubled,\n", |
| 297 | + "        # so its text cannot terminate the statement early.\n", |
| 298 | + "        return \"'\" + str(value).replace(\"'\", \"''\") + \"'\"\n", |
| 299 | + "\n", |
| 300 | + "    def get_statements(self) -> List[str]:\n", |
| 301 | + "        \"\"\"\n", |
| 302 | + "        Returns one `SET LOCAL <name> = <value>` statement per entry in `params`.\n", |
| 303 | + "\n", |
| 304 | + "        Raises ValueError when a name is not a dot-separated list of identifiers.\n", |
| 305 | + "        \"\"\"\n", |
| 306 | + "        statements = []\n", |
| 307 | + "        for key, value in self.params.items():\n", |
| 308 | + "            # The name is spliced into the SQL text, so it is validated rather than escaped.\n", |
| 309 | + "            if not (isinstance(key, str) and all(part.isidentifier() for part in key.split(\".\"))):\n", |
| 310 | + "                raise ValueError(\"invalid setting name: {!r}\".format(key))\n", |
| 311 | + "            statements.append(\"SET LOCAL \" + key + \" = \" + self._render_value(value))\n", |
| 312 | + "        return statements\n", |
| 313 | + "\n", |
| 314 | + "class TimescaleVectorIndexParams(QueryParams):\n", |
| 315 | + "    \"\"\"Query parameters that set `tsv.query_search_list_size`.\"\"\"\n", |
| 316 | + "\n", |
| 317 | + "    def __init__(self, search_list_size: int) -> None:\n", |
| 318 | + "        super().__init__({\"tsv.query_search_list_size\": search_list_size})\n", |
| 319 | + "\n", |
| 320 | + "class IvfflatIndexParams(QueryParams):\n", |
| 321 | + "    \"\"\"Query parameters that set `ivfflat.probes`.\"\"\"\n", |
| 322 | + "\n", |
| 323 | + "    def __init__(self, probes: int) -> None:\n", |
| 324 | + "        super().__init__({\"ivfflat.probes\": probes})\n", |
| 325 | + "\n", |
| 326 | + "class HNSWIndexParams(QueryParams):\n", |
| 327 | + "    \"\"\"Query parameters that set `hnsw.ef_search`.\"\"\"\n", |
| 328 | + "\n", |
| 329 | + "    def __init__(self, ef_search: int) -> None:\n", |
| 330 | + "        super().__init__({\"hnsw.ef_search\": ef_search})" |
| 302 | + ] |
| 303 | + }, |
268 | 304 | { |
269 | 305 | "attachments": {}, |
270 | 306 | "cell_type": "markdown", |
|
275 | 311 | }, |
276 | 312 | { |
277 | 313 | "cell_type": "code", |
278 | | - "execution_count": 104, |
| 314 | + "execution_count": null, |
279 | 315 | "metadata": {}, |
280 | 316 | "outputs": [], |
281 | 317 | "source": [ |
|
290 | 326 | }, |
291 | 327 | { |
292 | 328 | "cell_type": "code", |
293 | | - "execution_count": 105, |
| 329 | + "execution_count": null, |
294 | 330 | "metadata": {}, |
295 | 331 | "outputs": [], |
296 | 332 | "source": [ |
|
388 | 424 | }, |
389 | 425 | { |
390 | 426 | "cell_type": "code", |
391 | | - "execution_count": 106, |
| 427 | + "execution_count": null, |
392 | 428 | "metadata": {}, |
393 | 429 | "outputs": [], |
394 | 430 | "source": [ |
|
534 | 570 | }, |
535 | 571 | { |
536 | 572 | "cell_type": "code", |
537 | | - "execution_count": 107, |
| 573 | + "execution_count": null, |
538 | 574 | "metadata": {}, |
539 | 575 | "outputs": [], |
540 | 576 | "source": [ |
|
848 | 884 | }, |
849 | 885 | { |
850 | 886 | "cell_type": "code", |
851 | | - "execution_count": 108, |
| 887 | + "execution_count": null, |
852 | 888 | "metadata": {}, |
853 | 889 | "outputs": [ |
854 | 890 | { |
|
876 | 912 | "Generates a query to create the tables, indexes, and extensions needed to store the vector data." |
877 | 913 | ] |
878 | 914 | }, |
879 | | - "execution_count": 108, |
| 915 | + "execution_count": null, |
880 | 916 | "metadata": {}, |
881 | 917 | "output_type": "execute_result" |
882 | 918 | } |
|
895 | 931 | }, |
896 | 932 | { |
897 | 933 | "cell_type": "code", |
898 | | - "execution_count": 109, |
| 934 | + "execution_count": null, |
899 | 935 | "metadata": {}, |
900 | 936 | "outputs": [], |
901 | 937 | "source": [ |
|
1128 | 1164 | " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n", |
1129 | 1165 | " predicates: Optional[Predicates] = None,\n", |
1130 | 1166 | " uuid_time_filter: Optional[UUIDTimeRange] = None,\n", |
| 1167 | + " query_params: Optional[QueryParams] = None\n", |
1131 | 1168 | " ): \n", |
1132 | 1169 | " \"\"\"\n", |
1133 | 1170 | " Retrieves similar records using a similarity query.\n", |
|
1149 | 1186 | " \"\"\"\n", |
1150 | 1187 | " (query, params) = self.builder.search_query(\n", |
1151 | 1188 | " query_embedding, limit, filter, predicates, uuid_time_filter)\n", |
1152 | | - " async with await self.connect() as pool:\n", |
1153 | | - " return await pool.fetch(query, *params)" |
| 1189 | + " if query_params is not None:\n", |
| 1190 | + " async with await self.connect() as pool:\n", |
| 1191 | + " async with pool.transaction():\n", |
| 1192 | + " #Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588\n", |
| 1193 | + " statements = query_params.get_statements()\n", |
| 1194 | + " for statement in statements:\n", |
| 1195 | + " await pool.execute(statement)\n", |
| 1196 | + " return await pool.fetch(query, *params)\n", |
| 1197 | + " else:\n", |
| 1198 | + " async with await self.connect() as pool:\n", |
| 1199 | + " return await pool.fetch(query, *params)" |
1154 | 1200 | ] |
1155 | 1201 | }, |
1156 | 1202 | { |
1157 | 1203 | "cell_type": "code", |
1158 | | - "execution_count": 110, |
| 1204 | + "execution_count": null, |
1159 | 1205 | "metadata": {}, |
1160 | 1206 | "outputs": [ |
1161 | 1207 | { |
|
1183 | 1229 | "Creates necessary tables." |
1184 | 1230 | ] |
1185 | 1231 | }, |
1186 | | - "execution_count": 110, |
| 1232 | + "execution_count": null, |
1187 | 1233 | "metadata": {}, |
1188 | 1234 | "output_type": "execute_result" |
1189 | 1235 | } |
|
1194 | 1240 | }, |
1195 | 1241 | { |
1196 | 1242 | "cell_type": "code", |
1197 | | - "execution_count": 111, |
| 1243 | + "execution_count": null, |
1198 | 1244 | "metadata": {}, |
1199 | 1245 | "outputs": [ |
1200 | 1246 | { |
|
1222 | 1268 | "Creates necessary tables." |
1223 | 1269 | ] |
1224 | 1270 | }, |
1225 | | - "execution_count": 111, |
| 1271 | + "execution_count": null, |
1226 | 1272 | "metadata": {}, |
1227 | 1273 | "output_type": "execute_result" |
1228 | 1274 | } |
|
1233 | 1279 | }, |
1234 | 1280 | { |
1235 | 1281 | "cell_type": "code", |
1236 | | - "execution_count": 112, |
| 1282 | + "execution_count": null, |
1237 | 1283 | "metadata": {}, |
1238 | 1284 | "outputs": [ |
| 1285 | + { |
| 1286 | + "name": "stderr", |
| 1287 | + "output_type": "stream", |
| 1288 | + "text": [ |
| 1289 | + "/Users/cevian/.pyenv/versions/3.11.4/envs/nbdev_env/lib/python3.11/site-packages/fastcore/docscrape.py:225: UserWarning: potentially wrong underline length... \n", |
| 1290 | + "Returns \n", |
| 1291 | + "-------- in \n", |
| 1292 | + "Retrieves similar records using a similarity query.\n", |
| 1293 | + "...\n", |
| 1294 | + " else: warn(msg)\n" |
| 1295 | + ] |
| 1296 | + }, |
1239 | 1297 | { |
1240 | 1298 | "data": { |
1241 | 1299 | "text/markdown": [ |
|
1248 | 1306 | "> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", |
1249 | 1307 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n", |
1250 | 1308 | "> ne, predicates:Optional[__main__.Predicates]=None,\n", |
1251 | | - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", |
| 1309 | + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", |
| 1310 | + "> query_params:Optional[__main__.QueryParams]=None)\n", |
1252 | 1311 | "\n", |
1253 | 1312 | "Retrieves similar records using a similarity query.\n", |
1254 | 1313 | "\n", |
|
1259 | 1318 | "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", |
1260 | 1319 | "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n", |
1261 | 1320 | "| uuid_time_filter | Optional | None | |\n", |
| 1321 | + "| query_params | Optional | None | |\n", |
1262 | 1322 | "| **Returns** | **List: List of similar records.** | | |" |
1263 | 1323 | ], |
1264 | 1324 | "text/plain": [ |
|
1271 | 1331 | "> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", |
1272 | 1332 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n", |
1273 | 1333 | "> ne, predicates:Optional[__main__.Predicates]=None,\n", |
1274 | | - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", |
| 1334 | + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", |
| 1335 | + "> query_params:Optional[__main__.QueryParams]=None)\n", |
1275 | 1336 | "\n", |
1276 | 1337 | "Retrieves similar records using a similarity query.\n", |
1277 | 1338 | "\n", |
|
1282 | 1343 | "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", |
1283 | 1344 | "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n", |
1284 | 1345 | "| uuid_time_filter | Optional | None | |\n", |
| 1346 | + "| query_params | Optional | None | |\n", |
1285 | 1347 | "| **Returns** | **List: List of similar records.** | | |" |
1286 | 1348 | ] |
1287 | 1349 | }, |
1288 | | - "execution_count": 112, |
| 1350 | + "execution_count": null, |
1289 | 1351 | "metadata": {}, |
1290 | 1352 | "output_type": "execute_result" |
1291 | 1353 | } |
|
1296 | 1358 | }, |
1297 | 1359 | { |
1298 | 1360 | "cell_type": "code", |
1299 | | - "execution_count": 117, |
| 1361 | + "execution_count": null, |
1300 | 1362 | "metadata": {}, |
1301 | 1363 | "outputs": [], |
1302 | 1364 | "source": [ |
|
1317 | 1379 | }, |
1318 | 1380 | { |
1319 | 1381 | "cell_type": "code", |
1320 | | - "execution_count": 118, |
| 1382 | + "execution_count": null, |
1321 | 1383 | "metadata": {}, |
1322 | 1384 | "outputs": [], |
1323 | 1385 | "source": [ |
|
1564 | 1626 | "assert len(rec) == 0\n", |
1565 | 1627 | "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(end_date=specific_datetime+timedelta(seconds=1), time_delta=timedelta(days=7)))\n", |
1566 | 1628 | "assert len(rec) == 1\n", |
| 1629 | + "rec = await vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(10))\n", |
| 1630 | + "assert len(rec) == 2\n", |
| 1631 | + "rec = await vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(100))\n", |
| 1632 | + "assert len(rec) == 2\n", |
1567 | 1633 | "await vec.drop_table()\n", |
1568 | 1634 | "await vec.close()" |
1569 | 1635 | ] |
|
1883 | 1949 | " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n", |
1884 | 1950 | " predicates: Optional[Predicates] = None,\n", |
1885 | 1951 | " uuid_time_filter: Optional[UUIDTimeRange] = None,\n", |
| 1952 | + " query_params: Optional[QueryParams] = None,\n", |
1886 | 1953 | " ):\n", |
1887 | 1954 | " \"\"\"\n", |
1888 | 1955 | " Retrieves similar records using a similarity query.\n", |
|
1910 | 1977 | " (query, params) = self.builder.search_query(\n", |
1911 | 1978 | " query_embedding_np, limit, filter, predicates, uuid_time_filter)\n", |
1912 | 1979 | " query, params = self._translate_to_pyformat(query, params)\n", |
| 1980 | + "\n", |
| 1981 | + " if query_params is not None:\n", |
| 1982 | + " prefix = \"; \".join(query_params.get_statements())\n", |
| 1983 | + " query = f\"{prefix}; {query}\"\n", |
| 1984 | + " \n", |
1913 | 1985 | " with self.connect() as conn:\n", |
1914 | 1986 | " with conn.cursor() as cur:\n", |
1915 | 1987 | " cur.execute(query, params)\n", |
|
2021 | 2093 | "> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", |
2022 | 2094 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n", |
2023 | 2095 | "> e, predicates:Optional[__main__.Predicates]=None,\n", |
2024 | | - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", |
| 2096 | + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", |
| 2097 | + "> query_params:Optional[__main__.QueryParams]=None)\n", |
2025 | 2098 | "\n", |
2026 | 2099 | "Retrieves similar records using a similarity query.\n", |
2027 | 2100 | "\n", |
|
2032 | 2105 | "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", |
2033 | 2106 | "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n", |
2034 | 2107 | "| uuid_time_filter | Optional | None | |\n", |
| 2108 | + "| query_params | Optional | None | |\n", |
2035 | 2109 | "| **Returns** | **List: List of similar records.** | | |" |
2036 | 2110 | ], |
2037 | 2111 | "text/plain": [ |
|
2044 | 2118 | "> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", |
2045 | 2119 | "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n", |
2046 | 2120 | "> e, predicates:Optional[__main__.Predicates]=None,\n", |
2047 | | - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", |
| 2121 | + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", |
| 2122 | + "> query_params:Optional[__main__.QueryParams]=None)\n", |
2048 | 2123 | "\n", |
2049 | 2124 | "Retrieves similar records using a similarity query.\n", |
2050 | 2125 | "\n", |
|
2055 | 2130 | "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", |
2056 | 2131 | "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n", |
2057 | 2132 | "| uuid_time_filter | Optional | None | |\n", |
| 2133 | + "| query_params | Optional | None | |\n", |
2058 | 2134 | "| **Returns** | **List: List of similar records.** | | |" |
2059 | 2135 | ] |
2060 | 2136 | }, |
|
2314 | 2390 | "assert len(rec) == 0\n", |
2315 | 2391 | "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(end_date=specific_datetime+timedelta(seconds=1), time_delta=timedelta(days=7)))\n", |
2316 | 2392 | "assert len(rec) == 1\n", |
| 2393 | + "rec = vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(10))\n", |
| 2394 | + "assert len(rec) == 2\n", |
| 2395 | + "rec = vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(100))\n", |
| 2396 | + "assert len(rec) == 2\n", |
2317 | 2397 | "vec.drop_table()\n", |
2318 | 2398 | "vec.close()" |
2319 | 2399 | ] |
|
0 commit comments