
Commit 822f932

Merge pull request stanfordnlp#859 from sfc-gh-alherrera/dspy-snowflake
DSPy Support For Snowflake LLMs
2 parents 5bc17d8 + 23cd780 commit 822f932

10 files changed: +605 -30 lines changed

README.md

Lines changed: 2 additions & 2 deletions
@@ -72,11 +72,11 @@ Or open our intro notebook in Google Colab: [<img align="center" src="https://co
 By default, DSPy installs the latest `openai` from pip. However, if you install an old version from before OpenAI changed their API (`openai~=0.28.1`), the library will use that just fine. Both are supported.
 
-For the optional (alphabetically sorted) [Chromadb](https://github.com/chroma-core/chroma), [Qdrant](https://github.com/qdrant/qdrant), [Marqo](https://github.com/marqo-ai/marqo), Pinecone, [Weaviate](https://github.com/weaviate/weaviate),
+For the optional (alphabetically sorted) [Chromadb](https://github.com/chroma-core/chroma), [Qdrant](https://github.com/qdrant/qdrant), [Marqo](https://github.com/marqo-ai/marqo), Pinecone, [Snowflake](https://github.com/snowflakedb/snowpark-python), [Weaviate](https://github.com/weaviate/weaviate),
 or [Milvus](https://github.com/milvus-io/milvus) retrieval integration(s), include the extra(s) below:
 
 ```
-pip install dspy-ai[chromadb] # or [qdrant] or [marqo] or [mongodb] or [pinecone] or [weaviate] or [milvus]
+pip install dspy-ai[chromadb] # or [qdrant] or [marqo] or [mongodb] or [pinecone] or [snowflake] or [weaviate] or [milvus]
 ```
 
 ## 2) Documentation
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@

---
sidebar_position:
---

# dspy.Snowflake

### Usage

```python
import dspy
import os

connection_parameters = {
    "account": os.getenv('SNOWFLAKE_ACCOUNT'),
    "user": os.getenv('SNOWFLAKE_USER'),
    "password": os.getenv('SNOWFLAKE_PASSWORD'),
    "role": os.getenv('SNOWFLAKE_ROLE'),
    "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
    "database": os.getenv('SNOWFLAKE_DATABASE'),
    "schema": os.getenv('SNOWFLAKE_SCHEMA'),
}

lm = dspy.Snowflake(model="mixtral-8x7b", credentials=connection_parameters)
```

### Constructor

The constructor inherits from the base class `LM` and verifies the `credentials` needed to use the Snowflake API.

```python
class Snowflake(LM):
    def __init__(
        self,
        model,
        credentials,
        **kwargs):
```

**Parameters:**
- `model` (_str_): model hosted by [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability).
- `credentials` (_dict_): connection parameters required to initialize a [Snowflake Snowpark session](https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session).

### Methods

Refer to [`dspy.Snowflake`](https://dspy-docs.vercel.app/api/language_model_clients/Snowflake) documentation.
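
Beyond constructing the client, a typical next step is to register it as the default LM for a DSPy program. The following is a minimal sketch, assuming the `connection_parameters` and `lm` from the Usage example above; the `BasicQA` signature and the question text are illustrative only.

```python
import dspy

# Register the Snowflake-backed LM as the default for DSPy modules.
dspy.settings.configure(lm=lm)

class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")

# dspy.Predict will route its completion calls through Snowflake Cortex.
generate_answer = dspy.Predict(BasicQA)
prediction = generate_answer(question="What is Snowflake Cortex?")
print(prediction.answer)
```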
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@

---
sidebar_position:
---

# retrieve.SnowflakeRM

### Constructor

Initialize an instance of the `SnowflakeRM` class, with the option to use `e5-base-v2`, `snowflake-arctic-embed-m`, or any other embeddings model supported by Snowflake Cortex.

```python
SnowflakeRM(
    snowflake_table_name: str,
    snowflake_credentials: dict,
    embeddings_field: str,
    embeddings_text_field: str,
    k: int = 3,
    embeddings_model: str = "e5-base-v2",
)
```

**Parameters:**

- `snowflake_table_name (str)`: The name of the Snowflake table containing the embeddings.
- `snowflake_credentials (dict)`: The connection parameters needed to initialize a Snowflake Snowpark session.
- `embeddings_field (str)`: The name of the column in the Snowflake table containing the embeddings.
- `embeddings_text_field (str)`: The name of the column in the Snowflake table containing the passages.
- `k (int, optional)`: The number of top passages to retrieve. Defaults to 3.
- `embeddings_model (str)`: The model used to convert text to embeddings. Defaults to `e5-base-v2`.

### Methods

#### `forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction`

Search the Snowflake table for the top `k` passages matching the given query or queries, using embeddings generated via the default `e5-base-v2` model or the specified `embeddings_model`.

**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict}]`.

### Quickstart

To support passage retrieval, `SnowflakeRM` assumes that a Snowflake table has been created and populated with the passages in the column named by `embeddings_text_field` and their embeddings in the column named by `embeddings_field`.

SnowflakeRM uses the `e5-base-v2` embeddings model by default, but any embeddings model supported by Snowflake Cortex can be specified.

#### Default Embeddings

```python
from dspy.retrieve.snowflake_rm import SnowflakeRM
import os

connection_parameters = {
    "account": os.getenv('SNOWFLAKE_ACCOUNT'),
    "user": os.getenv('SNOWFLAKE_USER'),
    "password": os.getenv('SNOWFLAKE_PASSWORD'),
    "role": os.getenv('SNOWFLAKE_ROLE'),
    "warehouse": os.getenv('SNOWFLAKE_WAREHOUSE'),
    "database": os.getenv('SNOWFLAKE_DATABASE'),
    "schema": os.getenv('SNOWFLAKE_SCHEMA'),
}

retriever_model = SnowflakeRM(
    snowflake_table_name="<YOUR_SNOWFLAKE_TABLE_NAME>",
    snowflake_credentials=connection_parameters,
    embeddings_field="<YOUR_EMBEDDINGS_COLUMN_NAME>",
    embeddings_text_field="<YOUR_PASSAGE_COLUMN_NAME>",
)

results = retriever_model("Explore the meaning of life", k=5)

for result in results:
    print("Document:", result.long_text, "\n")
```
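
The retriever can also be registered as DSPy's default retrieval model and used through `dspy.Retrieve`. This is a minimal sketch, assuming the `retriever_model` constructed in the Quickstart above; the query string is illustrative.

```python
import dspy

# Make SnowflakeRM the default retrieval model for DSPy programs.
dspy.settings.configure(rm=retriever_model)

# dspy.Retrieve delegates to the configured retrieval model.
retrieve = dspy.Retrieve(k=3)
top_passages = retrieve("Explore the meaning of life").passages

for passage in top_passages:
    print(passage, "\n")
```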

dsp/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,4 +23,6 @@
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
+from .snowflake import *
 from .watsonx import *
+
dsp/modules/snowflake.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@

"""Module for interacting with Snowflake Cortex."""
import json
from typing import Any

import backoff
from pydantic_core import PydanticCustomError

from dsp.modules.lm import LM

try:
    from snowflake.snowpark import Session
    from snowflake.snowpark import functions as snow_func

except ImportError:
    pass


def backoff_hdlr(details) -> None:
    """Handler from https://pypi.org/project/backoff ."""
    print(
        f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries ",
        f"calling function {details['target']} with kwargs",
        f"{details['kwargs']}",
    )


def giveup_hdlr(details) -> bool:
    """Wrapper function that decides when to give up on retry."""
    if "rate limits" in str(details):
        return False
    return True


class Snowflake(LM):
    """Wrapper around Snowflake's Cortex API.

    Currently supported models include 'snowflake-arctic', 'mistral-large', 'reka-flash', 'mixtral-8x7b',
    'llama2-70b-chat', 'mistral-7b', 'gemma-7b', 'llama3-8b', 'llama3-70b', 'reka-core'.
    """

    def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
        """Parameters
        ----------
        model : str
            Which pre-trained model from Snowflake to use.
            Choices include 'snowflake-arctic', 'mistral-large', 'reka-flash', 'mixtral-8x7b', 'llama2-70b-chat', 'mistral-7b', 'gemma-7b'.
            The full list of supported models is available here: https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#complete
        credentials: dict
            Snowflake credentials required to initialize the session.
            The full list of requirements can be found here: https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.Session
        **kwargs: dict
            Additional arguments to pass to the API provider.
        """
        super().__init__(model)

        self.model = model
        cortex_models = [
            "llama3-8b",
            "llama3-70b",
            "reka-core",
            "snowflake-arctic",
            "mistral-large",
            "reka-flash",
            "mixtral-8x7b",
            "llama2-70b-chat",
            "mistral-7b",
            "gemma-7b",
        ]

        if model in cortex_models:
            self.available_args = {
                "max_tokens",
                "temperature",
                "top_p",
            }
        else:
            raise PydanticCustomError(
                "model",
                'model name is not valid, got "{model_name}"',
                {"model_name": model},
            )

        self.client = self._init_cortex(credentials=credentials)
        self.provider = "Snowflake"
        self.history: list[dict[str, Any]] = []
        self.kwargs = {
            **self.kwargs,
            "temperature": 0.7,
            "max_output_tokens": 1024,
            "top_p": 1.0,
            "top_k": 1,
            **kwargs,
        }

    @classmethod
    def _init_cortex(cls, credentials: dict) -> "Session":
        session = Session.builder.configs(credentials).create()
        session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}

        return session

    def _prepare_params(
        self,
        parameters: Any,
    ) -> dict:
        # Map aliases to Cortex argument names, then filter down to the arguments in self.available_args.
        params_mapping = {"n": "candidate_count", "max_tokens": "max_output_tokens"}
        params = {params_mapping.get(k, k): v for k, v in parameters.items()}
        params = {**self.kwargs, **params}
        return {k: params[k] for k in set(params.keys()) & self.available_args}

    def _cortex_complete_request(self, prompt: str, **kwargs) -> dict:
        # Build a call to the SNOWFLAKE.CORTEX.COMPLETE function and evaluate it on a single-row DataFrame.
        complete = snow_func.builtin("snowflake.cortex.complete")
        cortex_complete_args = complete(
            snow_func.lit(self.model),
            snow_func.lit([{"role": "user", "content": prompt}]),
            snow_func.lit(kwargs),
        )
        raw_response = self.client.range(1).withColumn("complete_cal", cortex_complete_args).collect()

        if len(raw_response) > 0:
            return json.loads(raw_response[0].COMPLETE_CAL)

        else:
            return json.loads('{"choices": [{"messages": "None"}]}')

    def basic_request(self, prompt: str, **kwargs) -> list:
        raw_kwargs = kwargs
        kwargs = self._prepare_params(raw_kwargs)

        response = self._cortex_complete_request(prompt, **kwargs)

        # Record the exchange so DSPy can inspect prompts and completions later.
        history = {
            "prompt": prompt,
            "response": {
                "prompt": prompt,
                "choices": [{"text": c} for c in response["choices"]],
            },
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }

        self.history.append(history)

        return [i["text"]["messages"] for i in history["response"]["choices"]]

    @backoff.on_exception(
        backoff.expo,
        (Exception),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def _request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Snowflake Cortex whilst handling API errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        return self._request(prompt, **kwargs)
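
For orientation, a minimal sketch of calling this client directly rather than through `dspy.settings` (assumes `connection_parameters` holds a valid Snowpark credentials dict, as in the documentation examples above; the prompt text is illustrative):

```python
# Assumes connection_parameters is a dict of valid Snowflake Snowpark credentials.
from dsp.modules.snowflake import Snowflake

lm = Snowflake(model="mixtral-8x7b", credentials=connection_parameters)

# __call__ delegates to _request, which retries basic_request with exponential backoff
# and returns a list of completion strings.
completions = lm("Summarize what Snowflake Cortex provides in one sentence.", temperature=0.2)
print(completions[0])
```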

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 Google = dsp.Google
 GoogleVertexAI = dsp.GoogleVertexAI
 GROQ = dsp.GroqLM
+Snowflake = dsp.Snowflake
 Claude = dsp.Claude
 
 HFClientTGI = dsp.HFClientTGI
