
Commit 797cf94

Merge branch 'mongodb-atlas-rm'

2 parents 8b2f767 + 1864168

9 files changed: +1264 -79 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ Or open our intro notebook in Google Colab: [<img align="center" src="https://co
 
 > _Note: If you're looking for Demonstrate-Search-Predict (DSP), which is the previous version of DSPy, you can find it on the [v1](https://github.com/stanfordnlp/dspy/tree/v1) branch of this repo._
 
+By default, DSPy depends on `openai==0.28`. However, if you install `openai>=1.0`, the library will use that just fine. Both are supported.
 
 For the optional Pinecone, Qdrant, [chromadb](https://github.com/chroma-core/chroma), or [marqo](https://github.com/marqo-ai/marqo) retrieval integration(s), include the extra(s) below:
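The dual-version support this note documents hinges on detecting the installed openai major version. A quick way to check what your environment resolves to (a minimal sketch mirroring the OPENAI_LEGACY probe in dsp/modules/gpt3.py below; `openai.__version__` is present in both the 0.x and 1.x packages):

import openai

# Major version 0 means the legacy 0.28-style API; 1+ means the v1 client.
is_legacy = int(openai.__version__.split(".")[0]) == 0
print(openai.__version__, "->", "legacy 0.x API" if is_legacy else "v1 client API")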

dsp/evaluation/utils.py

Lines changed: 0 additions & 2 deletions

@@ -1,5 +1,3 @@
-from openai import InvalidRequestError
-from openai.error import APIError
 
 import dsp
 import tqdm
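Both removed imports fail under openai>=1.0: the openai.error submodule is gone, and InvalidRequestError was dropped from the package. For code that does need these classes across both versions, the usual compatibility shim looks like this (a sketch, not part of this commit):

try:
    # openai<1.0: error classes live in the openai.error submodule.
    from openai.error import APIError, RateLimitError
except ModuleNotFoundError:
    # openai>=1.0: the submodule is gone; the classes moved to the package root.
    from openai import APIError, RateLimitError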

dsp/modules/gpt3.py

Lines changed: 66 additions & 25 deletions

@@ -4,12 +4,23 @@
 
 import backoff
 import openai
-import openai.error
-from openai.openai_object import OpenAIObject
 
 from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
 from dsp.modules.lm import LM
 
+try:
+    OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
+except Exception:
+    OPENAI_LEGACY = True
+
+try:
+    from openai.openai_object import OpenAIObject
+    import openai.error
+    ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
+except Exception:
+    ERRORS = (openai.RateLimitError, openai.APIError)
+    OpenAIObject = dict
+
 
 def backoff_hdlr(details):
     """Handler from https://pypi.org/project/backoff/"""
@@ -36,13 +47,19 @@ def __init__(
         model: str = "gpt-3.5-turbo-instruct",
         api_key: Optional[str] = None,
         api_provider: Literal["openai", "azure"] = "openai",
+        api_base: Optional[str] = None,
         model_type: Literal["chat", "text"] = None,
         **kwargs,
     ):
         super().__init__(model)
         self.provider = "openai"
 
-        default_model_type = "chat" if ('gpt-3.5' in model or 'turbo' in model or 'gpt-4' in model) and ('instruct' not in model) else "text"
+        default_model_type = (
+            "chat"
+            if ("gpt-3.5" in model or "turbo" in model or "gpt-4" in model)
+            and ("instruct" not in model)
+            else "text"
+        )
         self.model_type = model_type if model_type else default_model_type
 
         if api_provider == "azure":
@@ -58,8 +75,8 @@ def __init__(
         if api_key:
             openai.api_key = api_key
 
-        if kwargs.get("api_base"):
-            openai.api_base = kwargs["api_base"]
+        if api_base:
+            openai.base_url = api_base
 
         self.kwargs = {
             "temperature": 0.0,
@@ -70,29 +87,27 @@ def __init__(
             "n": 1,
             **kwargs,
         }  # TODO: add kwargs above for </s>
-
+
         if api_provider != "azure":
             self.kwargs["model"] = model
         self.history: list[dict[str, Any]] = []
 
-    def _openai_client():
+    def _openai_client(self):
         return openai
 
-    def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
+    def basic_request(self, prompt: str, **kwargs):
         raw_kwargs = kwargs
 
         kwargs = {**self.kwargs, **kwargs}
         if self.model_type == "chat":
             # caching mechanism requires hashable kwargs
             kwargs["messages"] = [{"role": "user", "content": prompt}]
-            kwargs = {
-                "stringify_request": json.dumps(kwargs)
-            }
-            response = cached_gpt3_turbo_request(**kwargs)
-
+            kwargs = {"stringify_request": json.dumps(kwargs)}
+            response = chat_request(**kwargs)
+
         else:
             kwargs["prompt"] = prompt
-            response = cached_gpt3_request(**kwargs)
+            response = completions_request(**kwargs)
 
         history = {
             "prompt": prompt,
@@ -106,15 +121,15 @@ def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
 
     @backoff.on_exception(
         backoff.expo,
-        (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError),
+        ERRORS,
         max_time=1000,
         on_backoff=backoff_hdlr,
     )
-    def request(self, prompt: str, **kwargs) -> OpenAIObject:
+    def request(self, prompt: str, **kwargs):
         """Handles retreival of GPT-3 completions whilst handling rate limiting and caching."""
         if "model_type" in kwargs:
             del kwargs["model_type"]
-
+
         return self.basic_request(prompt, **kwargs)
 
     def _get_choice_text(self, choice: dict[str, Any]) -> str:
@@ -150,6 +165,7 @@ def __call__(
         # kwargs = {**kwargs, "logprobs": 5}
 
         response = self.request(prompt, **kwargs)
+
        choices = response["choices"]
 
         completed_choices = [c for c in choices if c["finish_reason"] != "length"]
@@ -158,7 +174,6 @@ def __call__(
             choices = completed_choices
 
         completions = [self._get_choice_text(c) for c in choices]
-
         if return_sorted and kwargs.get("n", 1) > 1:
             scored_completions = []
 
@@ -181,31 +196,57 @@
         return completions
 
 
+
 @CacheMemory.cache
 def cached_gpt3_request_v2(**kwargs):
     return openai.Completion.create(**kwargs)
 
-
 @functools.lru_cache(maxsize=None if cache_turn_on else 0)
 @NotebookCacheMemory.cache
 def cached_gpt3_request_v2_wrapped(**kwargs):
     return cached_gpt3_request_v2(**kwargs)
 
-
-cached_gpt3_request = cached_gpt3_request_v2_wrapped
-
-
 @CacheMemory.cache
 def _cached_gpt3_turbo_request_v2(**kwargs) -> OpenAIObject:
     if "stringify_request" in kwargs:
         kwargs = json.loads(kwargs["stringify_request"])
     return cast(OpenAIObject, openai.ChatCompletion.create(**kwargs))
 
-
 @functools.lru_cache(maxsize=None if cache_turn_on else 0)
 @NotebookCacheMemory.cache
 def _cached_gpt3_turbo_request_v2_wrapped(**kwargs) -> OpenAIObject:
     return _cached_gpt3_turbo_request_v2(**kwargs)
 
+@CacheMemory.cache
+def v1_cached_gpt3_request_v2(**kwargs):
+    return openai.completions.create(**kwargs)
+
+@functools.lru_cache(maxsize=None if cache_turn_on else 0)
+@NotebookCacheMemory.cache
+def v1_cached_gpt3_request_v2_wrapped(**kwargs):
+    return v1_cached_gpt3_request_v2(**kwargs)
+
+@CacheMemory.cache
+def v1_cached_gpt3_turbo_request_v2(**kwargs):
+    if "stringify_request" in kwargs:
+        kwargs = json.loads(kwargs["stringify_request"])
+    return openai.chat.completions.create(**kwargs)
+
+@functools.lru_cache(maxsize=None if cache_turn_on else 0)
+@NotebookCacheMemory.cache
+def v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs):
+    return v1_cached_gpt3_turbo_request_v2(**kwargs)
+
+
+
+def chat_request(**kwargs):
+    if OPENAI_LEGACY:
+        return _cached_gpt3_turbo_request_v2_wrapped(**kwargs)
+
+    return v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs).model_dump()
+
+def completions_request(**kwargs):
+    if OPENAI_LEGACY:
+        return cached_gpt3_request_v2_wrapped(**kwargs)
 
-cached_gpt3_turbo_request = _cached_gpt3_turbo_request_v2_wrapped
+    return v1_cached_gpt3_request_v2_wrapped(**kwargs).model_dump()
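Taken together: `api_base` becomes an explicit constructor argument (mapped onto `openai.base_url` under the v1 client), and `chat_request`/`completions_request` pick the legacy or v1 code path at call time via `OPENAI_LEGACY`. A minimal usage sketch, assuming the class is exposed as `dsp.GPT3` and pointing at a hypothetical OpenAI-compatible endpoint:

import dsp

lm = dsp.GPT3(
    model="gpt-3.5-turbo",                       # inferred model_type: "chat"
    api_key="sk-...",
    api_base="https://my-proxy.example.com/v1",  # hypothetical proxy URL
)

# __call__ -> request -> basic_request -> chat_request, which routes to the
# legacy ChatCompletion path or the v1 chat.completions path automatically.
completions = lm("What does this commit change?")
print(completions[0])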

dsp/modules/sentence_vectorizer.py

Lines changed: 13 additions & 2 deletions

@@ -109,6 +109,12 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
         return embeddings
 
 
+try:
+    OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
+except Exception:
+    OPENAI_LEGACY = True
+
+
 class OpenAIVectorizer(BaseSentenceVectorizer):
     '''
     This vectorizer uses OpenAI API to convert texts to embeddings. Changing `model` is not
@@ -124,6 +130,11 @@ def __init__(
         self.model = model
         self.embed_batch_size = embed_batch_size
 
+        if OPENAI_LEGACY:
+            self.Embedding = openai.Embedding
+        else:
+            self.Embedding = openai.embeddings
+
         if api_key:
             openai.api_key = api_key
 
@@ -138,7 +149,7 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
             end_idx = (cur_batch_idx + 1) * self.embed_batch_size
             cur_batch = text_to_vectorize[start_idx: end_idx]
             # OpenAI API call:
-            response = openai.Embedding.create(
+            response = self.Embedding.create(
                 model=self.model,
                 input=cur_batch
             )
@@ -147,4 +158,4 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
             embeddings_list.extend(cur_batch_embeddings)
 
         embeddings = np.array(embeddings_list, dtype=np.float32)
-        return embeddings
+        return embeddings
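Binding `self.Embedding` once in `__init__` keeps `__call__` version-agnostic: both `openai.Embedding` (0.x) and `openai.embeddings` (v1) expose a compatible `create(model=..., input=...)` method. The same dispatch pattern in isolation (a sketch; response field access is the part that differs between versions):

import openai

OPENAI_LEGACY = int(openai.__version__.split(".")[0]) == 0

# One reference, two backends: legacy module-level resource vs. v1 module-level client.
Embedding = openai.Embedding if OPENAI_LEGACY else openai.embeddings

response = Embedding.create(model="text-embedding-ada-002", input=["hello world"])
# Legacy returns a dict-like OpenAIObject; v1 returns a pydantic model.
vectors = (
    [d["embedding"] for d in response["data"]]
    if OPENAI_LEGACY
    else [d.embedding for d in response.data]
)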

dspy/retrieve/chromadb_rm.py

Lines changed: 13 additions & 30 deletions

@@ -11,6 +11,7 @@
 try:
     import chromadb
     from chromadb.config import Settings
+    from chromadb.utils import embedding_functions
 except ImportError:
     chromadb = None
 
@@ -70,17 +71,13 @@ def __init__(
 
         self._init_chromadb(collection_name, persist_directory)
 
-        # If not provided, defaults to env vars
-        if openai_api_key:
-            openai.api_key = openai_api_key
-        if openai_api_type:
-            openai.api_type = openai_api_type
-        if openai_api_base:
-            openai.api_base = openai_api_base
-        if openai_api_version:
-            openai.api_version = openai_api_version
-        if openai_api_provider:
-            self._openai_api_provider = openai_api_provider
+        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+            api_key=openai_api_key,
+            api_base=openai_api_base,
+            api_type=openai_api_type,
+            api_version=openai_api_version,
+            model_name=openai_embed_model,
+        )
 
         super().__init__(k=k)
 
@@ -111,7 +108,7 @@ def _init_chromadb(
 
     @backoff.on_exception(
         backoff.expo,
-        (openai.error.RateLimitError, openai.error.ServiceUnavailableError),
+        (openai.RateLimitError),
         max_time=15,
     )
     def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
@@ -124,24 +121,10 @@ def _get_embeddings(self, queries: List[str]) -> List[List[float]]:
             List[List[float]]: List of embeddings corresponding to each query.
         """
 
-        if self._openai_api_provider == "azure":
-            model_args = {
-                "engine": self._openai_embed_model,
-                "deployment_id": self._openai_embed_model,
-                "api_version": openai.api_version,
-                "api_base": openai.api_base,
-            }
-            embedding = openai.Embedding.create(
-                input=queries,
-                model=self._openai_embed_model,
-                **model_args,
-                api_provider=self._openai_api_provider
-            )
-        else:
-            embedding = openai.Embedding.create(
-                input=queries, model=self._openai_embed_model
-            )
-        return [embedding["embedding"] for embedding in embedding["data"]]
+        embedding = self.openai_ef._client.create(
+            input=queries, model=self._openai_embed_model
+        )
+        return [embedding.embedding for embedding in embedding.data]
 
     def forward(
         self, query_or_queries: Union[str, List[str]], k: Optional[int] = None
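Embedding configuration now flows through chromadb's own `OpenAIEmbeddingFunction` instead of mutating `openai` module globals, which also covers the Azure case the deleted branch handled by hand. A minimal usage sketch (parameter names are those visible in this diff; the collection and directory values are hypothetical):

from dspy.retrieve.chromadb_rm import ChromadbRM

retriever = ChromadbRM(
    collection_name="my_docs",                    # hypothetical collection
    persist_directory="./chroma_db",              # hypothetical local store
    openai_embed_model="text-embedding-ada-002",
    openai_api_key="sk-...",
    k=5,
)

# forward() embeds the query via the shared client and searches chromadb.
top_passages = retriever.forward("What is DSPy?")

One caveat: `_get_embeddings` reaches into `self.openai_ef._client`, a private attribute of chromadb's embedding function, so this coupling could break if chromadb changes its internals.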

dspy/retrieve/mongodb_atlas_rm.py

Lines changed: 19 additions & 17 deletions

@@ -1,7 +1,13 @@
-from typing import List, Union, Any
+from typing import List, Optional, Union, Any
 import dspy
 import os
-import openai
+from openai import (
+    OpenAI,
+    APITimeoutError,
+    InternalServerError,
+    RateLimitError,
+    UnprocessableEntityError,
+)
 import backoff
 
 try:
@@ -39,24 +45,25 @@ def build_vector_search_pipeline(
 class Embedder:
     def __init__(self, provider: str, model: str):
         if provider == "openai":
-            openai.api_key = os.getenv("OPENAI_API_KEY")
-            if not openai.api_key:
+            api_key = os.getenv("OPENAI_API_KEY")
+            if not api_key:
                 raise ValueError("Environment variable OPENAI_API_KEY must be set")
-            self.client = openai
+            self.client = OpenAI()
             self.model = model
 
     @backoff.on_exception(
         backoff.expo,
         (
-            openai.error.RateLimitError,
-            openai.error.ServiceUnavailableError,
-            openai.error.APIError,
+            APITimeoutError,
+            InternalServerError,
+            RateLimitError,
+            UnprocessableEntityError,
         ),
         max_time=15,
     )
     def __call__(self, queries) -> Any:
-        embedding = self.client.Embedding.create(input=queries, model=self.model)
-        return [embedding["embedding"] for embedding in embedding["data"]]
+        embedding = self.client.embeddings.create(input=queries, model=self.model)
+        return [result.embedding for result in embedding.data]
 
 
 class MongoDBAtlasRM(dspy.Retrieve):
@@ -98,13 +105,8 @@ def __init__(
 
         self.embedder = Embedder(provider=embedding_provider, model=embedding_model)
 
-    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
-        queries = (
-            [query_or_queries]
-            if isinstance(query_or_queries, str)
-            else query_or_queries
-        )
-        query_vector = self.embedder(queries)
+    def forward(self, query_or_queries: str) -> dspy.Prediction:
+        query_vector = self.embedder([query_or_queries])
         pipeline = build_vector_search_pipeline(
             index_name=self.index_name,
             query_vector=query_vector[0],
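The `Embedder` now holds a dedicated `OpenAI()` v1 client (which reads `OPENAI_API_KEY` from the environment itself), and the retry tuple is rebuilt from v1 exception classes, since `ServiceUnavailableError` no longer exists. A minimal usage sketch; every constructor argument except `embedding_provider` and `embedding_model` is an assumption, as the diff only shows `self.index_name` and the embedder wiring:

import os
from dspy.retrieve.mongodb_atlas_rm import MongoDBAtlasRM

os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # required by Embedder

retriever = MongoDBAtlasRM(
    db_name="my_db",                 # hypothetical
    collection_name="my_chunks",     # hypothetical
    index_name="vector_index",       # hypothetical Atlas Vector Search index
    embedding_provider="openai",
    embedding_model="text-embedding-ada-002",
)

# forward() now takes a single query string and embeds it as a one-item batch.
prediction = retriever.forward("How do I configure Atlas Vector Search?")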
