Skip to content

Commit a048a51

Browse files
fix(dspy): updating syntax for embeddings method call which will be deprecated
1 parent 8775674 commit a048a51

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

docs/api/retrieval_model_clients/SnowflakeRM.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
sidebar_position:
2+
sidebar_position:
33
---
44

55
# retrieve.SnowflakeRM
@@ -20,6 +20,7 @@ SnowflakeRM(
2020
```
2121

2222
**Parameters:**
23+
2324
- `snowflake_table_name (str)`: The name of the Snowflake table containing embeddings.
2425
- `snowflake_credentials (dict)`: The connection parameters needed to initialize a Snowflake Snowpark Session.
2526
- `k (int, optional)`: The number of top passages to retrieve. Defaults to 3.
@@ -34,10 +35,12 @@ SnowflakeRM(
3435
Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated via the default `e5-base-v2` model or the specified `embedding_model`.
3536

3637
**Parameters:**
38+
3739
- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
3840
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.
3941

4042
**Returns:**
43+
4144
- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]`
4245

4346
### Quickstart
@@ -53,14 +56,14 @@ from dspy.retrieve.snowflake_rm import SnowflakeRM
5356
import os
5457

5558
connection_parameters = {
56-
59+
5760
"account": os.getenv('SNOWFLAKE_ACCOUNT'),
5861
"user": os.getenv('SNOWFLAKE_USER'),
5962
"password": os.getenv('SNOWFLAKE_PASSWORD'),
6063
"role": os.getenv('SNOWFLAKE_ROLE'),
6164
"warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
6265
"database": os.getenv('SNOWFLAKE_DATABASE'),
63-
"schema": os.getenv('SNOWFLAKE_SCHEMA')}
66+
"schema": os.getenv('SNOWFLAKE_SCHEMA')}
6467

6568
retriever_model = SnowflakeRM(
6669
snowflake_table_name="<YOUR_SNOWFLAKE_TABLE_NAME>",
@@ -74,4 +77,3 @@ results = retriever_model("Explore the meaning of life", k=5)
7477
for result in results:
7578
print("Document:", result.long_text, "\n")
7679
```
77-

dspy/retrieve/snowflake_rm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
except ImportError:
1313
raise ImportError(
14-
"The snowflake-snowpark-python library is required to use SnowflakeRM. Install it with dspy-ai[snowflake]"
14+
"The snowflake-snowpark-python library is required to use SnowflakeRM. Install it with dspy-ai[snowflake]",
1515
)
1616

1717

@@ -93,23 +93,22 @@ def _top_k_similar_chunks(self, query_embeddings, k):
9393
lit(query_embeddings).cast(VectorType(float, len(query_embeddings))),
9494
).as_("dist"),
9595
)
96-
.sort("dist",ascending=False)
96+
.sort("dist", ascending=False)
9797
.limit(k)
9898
)
9999

100100
return top_k.select(doc_table_key).to_pandas().values
101101

102102
@classmethod
103103
def _init_cortex(cls, credentials: dict) -> None:
104-
105104
session = Session.builder.configs(credentials).create()
106-
session.query_tag = {"origin":"sf_sit", "name":"dspy", "version":{"major":1, "minor":0}}
105+
session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}
107106

108107
return session
109108

110109
def _get_embeddings(self, query: str) -> list[float]:
111110
# create embeddings for the query
112-
embed = snow_fn.builtin("snowflake.cortex.embed_text")
111+
embed = snow_fn.builtin("snowflake.cortex.embed_text_768")
113112
cortex_embed_args = embed(snow_fn.lit(self.embeddings_model), snow_fn.lit(query))
114113

115114
return self.client.range(1).withColumn("complete_cal", cortex_embed_args).collect()[0].COMPLETE_CAL

0 commit comments

Comments
 (0)