Commit ddd3418

Author: Krista Opsahl-Ong (committed)
Merge branch 'main' of https://github.com/klopsahlong/dspy_prompt_opt into main
2 parents: 9b960f0 + 98bf5c3

File tree

6 files changed: +172 -43 lines changed

DSPy_Assert.pdf (236 KB)

Binary file not shown.

dspy/primitives/assertions.py

Lines changed: 2 additions & 2 deletions
@@ -169,7 +169,7 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def suggest_backtrack_handler(func, max_backtracks=2):
+def suggest_backtrack_handler(func, bypass_suggest=True, max_backtracks=2):
     """Handler for backtracking suggestion.
 
     Re-run the latest predictor up to `max_backtracks` times,
@@ -198,7 +198,7 @@ def wrapper(*args, **kwargs):
 
             # if last backtrack: ignore suggestion errors
             if i == max_backtracks:
-                result = bypass_suggest_handler(func)(*args, **kwargs)
+                result = bypass_suggest_handler(func)(*args, **kwargs) if bypass_suggest else None
                 break
 
             else:

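In practice, the new bypass_suggest flag controls what happens once all backtracks are exhausted: with bypass_suggest=True (the previous behaviour) the program is re-run once more with suggestion errors ignored, while bypass_suggest=False makes the handler give up and return None instead of an unvalidated result. A minimal sketch of pre-configuring the handler; the functools.partial wiring below is illustrative and not part of this commit:

from functools import partial

from dspy.primitives.assertions import suggest_backtrack_handler

# Previous behaviour: after max_backtracks failed retries, run once more
# with suggestions bypassed and keep whatever result that produces.
lenient_handler = partial(suggest_backtrack_handler, bypass_suggest=True, max_backtracks=2)

# New option: after max_backtracks failed retries, return None rather than
# an unvalidated result.
strict_handler = partial(suggest_backtrack_handler, bypass_suggest=False, max_backtracks=2)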
dspy/retrieve/pinecone_rm.py

Lines changed: 99 additions & 18 deletions
@@ -3,6 +3,7 @@
 Author: Dhar Rawal (@drawal1)
 """
 
+from dsp.utils import dotdict
 from typing import Optional, List, Union
 import openai
 import dspy
@@ -30,6 +31,7 @@ class PineconeRM(dspy.Retrieve):
         pinecone_index_name (str): The name of the Pinecone index to query against.
         pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
         pinecone_env (str, optional): The Pinecone environment. Defaults to None.
+        local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
         openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
         openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
         openai_org (str, optional): The organization for OpenAI. Defaults to None.
@@ -57,36 +59,62 @@ def __init__(
         pinecone_index_name: str,
         pinecone_api_key: Optional[str] = None,
         pinecone_env: Optional[str] = None,
-        openai_embed_model: str = "text-embedding-ada-002",
+        local_embed_model: Optional[str] = None,
+        openai_embed_model: Optional[str] = "text-embedding-ada-002",
         openai_api_key: Optional[str] = None,
         openai_org: Optional[str] = None,
         k: int = 3,
     ):
-        self._openai_embed_model = openai_embed_model
+        if local_embed_model is not None:
+            try:
+                import torch
+                from transformers import AutoModel, AutoTokenizer
+            except ImportError as exc:
+                raise ModuleNotFoundError(
+                    "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM."
+                ) from exc
+
+            self._local_embed_model = AutoModel.from_pretrained(local_embed_model)
+            self._local_tokenizer = AutoTokenizer.from_pretrained(local_embed_model)
+            self.use_local_model = True
+            self.device = torch.device(
+                'cuda:0' if torch.cuda.is_available() else
+                'mps' if torch.backends.mps.is_available()
+                else 'cpu'
+            )
+        elif openai_embed_model is not None:
+            self._openai_embed_model = openai_embed_model
+            self.use_local_model = False
+            # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
+            if openai_api_key:
+                openai.api_key = openai_api_key
+            if openai_org:
+                openai.organization = openai_org
+        else:
+            raise ValueError(
+                "Either local_embed_model or openai_embed_model must be provided."
+            )
+
         self._pinecone_index = self._init_pinecone(
             pinecone_index_name, pinecone_api_key, pinecone_env
         )
 
-        # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
-        if openai_api_key:
-            openai.api_key = openai_api_key
-        if openai_org:
-            openai.organization = openai_org
-
         super().__init__(k=k)
 
     def _init_pinecone(
         self,
         index_name: str,
         api_key: Optional[str] = None,
         environment: Optional[str] = None,
+        dimension: Optional[int] = None,
+        distance_metric: Optional[str] = None,
     ) -> pinecone.Index:
         """Initialize pinecone and return the loaded index.
 
         Args:
-            index_name (str): The name of the index to load.
+            index_name (str): The name of the index to load. If the index is not does not exist, it will be created.
             api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
-            environment (str, optional): The environment (ie. `us-west1-gcp`. Defaults to env PINECONE_ENVIRONMENT.
+            environment (str, optional): The environment (ie. `us-west1-gcp` or `gcp-starter`. Defaults to env PINECONE_ENVIRONMENT.
 
         Raises:
             ValueError: If api_key or environment is not provided and not set as an environment variable.
@@ -103,14 +131,46 @@ def _init_pinecone(
             kwargs["environment"] = environment
         pinecone.init(**kwargs)
 
-        return pinecone.Index(index_name)
+        active_indexes = pinecone.list_indexes()
+        if index_name not in active_indexes:
+            if dimension is None and distance_metric is None:
+                raise ValueError(
+                    "dimension and distance_metric must be provided since the index provided does not exist."
+                )
 
+            pinecone.create_index(
+                name=index_name,
+                dimension=dimension,
+                metric=distance_metric,
+            )
+
+        return pinecone.Index(index_name)
+
+    def _mean_pooling(
+        self,
+        model_output,
+        attention_mask
+    ):
+        try:
+            import torch
+        except ImportError as exc:
+            raise ModuleNotFoundError(
+                "You need to install torch to use a local embedding model with PineconeRM."
+            ) from exc
+
+        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
     @backoff.on_exception(
         backoff.expo,
         (openai.error.RateLimitError, openai.error.ServiceUnavailableError),
         max_time=15,
     )
-    def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
+    def _get_embeddings(
+        self,
+        queries: List[str]
+    ) -> List[List[float]]:
         """Return query vector after creating embedding using OpenAI
 
         Args:
@@ -119,10 +179,30 @@ def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
         Returns:
             List[List[float]]: List of embeddings corresponding to each query.
         """
-        embedding = openai.Embedding.create(
-            input=queries, model=self._openai_embed_model
-        )
-        return [embedding["embedding"] for embedding in embedding["data"]]
+        try:
+            import torch
+        except ImportError as exc:
+            raise ModuleNotFoundError(
+                "You need to install torch to use a local embedding model with PineconeRM."
+            ) from exc
+
+        if not self.use_local_model:
+            embedding = openai.Embedding.create(
+                input=queries, model=self._openai_embed_model
+            )
+            return [embedding["embedding"] for embedding in embedding["data"]]
+
+        # Use local model
+        encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            model_output = self._local_embed_model(**encoded_input.to(self.device))
+
+        embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
+        normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return normalized_embeddings.cpu().numpy().tolist()
+
+        # we need a pooling strategy to get a single vector representation of the input
+        # so the default is to take the mean of the hidden states
 
     def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
         """Search with pinecone for self.k top passages for query
@@ -149,9 +229,10 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
 
             # Sort results by score
             sorted_results = sorted(
-                results_dict["matches"], key=lambda x: x["score"], reverse=True
+                results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True
            )
            passages = [result["metadata"]["text"] for result in sorted_results]
+            passages = [dotdict({"long_text": passage for passage in passages})]
            return dspy.Prediction(passages=passages)
 
        # For multiple queries, query each and return the highest scoring passages
@@ -170,4 +251,4 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
         sorted_passages = sorted(
             passage_scores.items(), key=lambda x: x[1], reverse=True
         )[: self.k]
-        return dspy.Prediction(passages=[passage for passage, _ in sorted_passages])
+        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])

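Taken together, the pinecone_rm.py changes let PineconeRM embed queries either with OpenAI or with a local Hugging Face model (token embeddings are mean-pooled and L2-normalised before querying the index), let _init_pinecone create a missing index when a dimension and distance metric are supplied, and wrap returned passages as dotdict objects with a long_text field. A hedged usage sketch; the index name, model choice, and query string are assumptions, not part of the diff:

from dspy.retrieve.pinecone_rm import PineconeRM

# Local-model path: requires torch and transformers to be installed.
retriever = PineconeRM(
    pinecone_index_name="my-index",   # assumed to exist already; see note below
    local_embed_model="sentence-transformers/all-mpnet-base-v2",
    k=3,
)

# OpenAI path (previous behaviour): omit local_embed_model and rely on the
# OPENAI_API_KEY / OPENAI_ORGANIZATION env vars, or pass them explicitly.
# retriever = PineconeRM(pinecone_index_name="my-index",
#                        openai_embed_model="text-embedding-ada-002")

results = retriever.forward("What is retrieval-augmented generation?")
print(results.passages)

Note that although _init_pinecone can now create a missing index, __init__ still calls it without dimension or distance_metric, so constructing PineconeRM against a non-existent index raises the new ValueError rather than creating the index.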
dspy/retrieve/qdrant_rm.py

Lines changed: 8 additions & 6 deletions
@@ -1,6 +1,7 @@
 from collections import defaultdict
 from typing import List, Union
 import dspy
+from typing import Optional
 
 try:
     from qdrant_client import QdrantClient
@@ -21,7 +22,7 @@ class QdrantRM(dspy.Retrieve):
     Args:
         qdrant_collection_name (str): The name of the Qdrant collection.
         qdrant_client (QdrantClient): A QdrantClient instance.
-        k (int, optional): The number of top passages to retrieve. Defaults to 3.
+        k (int, optional): The default number of top passages to retrieve. Defaults to 3.
 
     Examples:
         Below is a code snippet that shows how to use Qdrant as the default retriver:
@@ -51,12 +52,12 @@ def __init__(
 
         super().__init__(k=k)
 
-    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction:
         """Search with Qdrant for self.k top passages for query
 
         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
-
+            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
         Returns:
             dspy.Prediction: An object containing the retrieved passages.
         """
@@ -66,9 +67,10 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
             else query_or_queries
         )
         queries = [q for q in queries if q] # Filter empty queries
-
+
+        k = k if k is not None else self.k
         batch_results = self._qdrant_client.query_batch(
-            self._qdrant_collection_name, query_texts=queries, limit=self.k)
+            self._qdrant_collection_name, query_texts=queries, limit=k)
 
         passages = defaultdict(float)
         for batch in batch_results:
@@ -77,5 +79,5 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
             passages[result.document] += result.score
 
         sorted_passages = sorted(
-            passages.items(), key=lambda x: x[1], reverse=True)[:self.k]
+            passages.items(), key=lambda x: x[1], reverse=True)[:k]
         return dspy.Prediction(passages=[passage for passage, _ in sorted_passages])

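The qdrant_rm.py change makes retrieval depth overridable per call: forward now takes k and falls back to self.k when it is None. A brief sketch, assuming an existing collection named "docs" and an in-memory client (both illustrative, not part of the diff):

from qdrant_client import QdrantClient
from dspy.retrieve.qdrant_rm import QdrantRM

client = QdrantClient(":memory:")   # illustrative; any configured client works
retriever = QdrantRM(qdrant_collection_name="docs", qdrant_client=client, k=3)

default_depth = retriever.forward("vector databases", k=None)  # falls back to self.k (3)
deeper = retriever.forward("vector databases", k=10)           # per-call override

As written, k has no default value in the new signature, so callers must pass it explicitly (k=None recovers the old behaviour); declaring it as k: Optional[int] = None would keep existing call sites working and match the updated docstring.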