
Commit f6cabb3

making predictions fast

1 parent 1006f58 commit f6cabb3

File tree

6 files changed: +184 −15 lines


dsp/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,9 +8,11 @@
 from .databricks import *
 from .google import *
 from .gpt3 import *
+from .groq_client import *
 from .hf import HFModel
 from .hf_client import Anyscale, HFClientTGI, Together
 from .ollama import *
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
+
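With the wildcard re-export above, the new client becomes importable from dsp.modules (and, as the dspy/__init__.py change below confirms, as dsp.GroqLM). A quick sanity check:

import dsp
from dsp.modules import GroqLM  # re-exported by the line added above

assert dsp.GroqLM is GroqLM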

dsp/modules/groq.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

dsp/modules/groq_client.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
import logging
import os
import json
from typing import Any, Literal, Optional
import backoff
from groq import Groq, AsyncGroq
import groq
import functools


import dsp
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.modules.lm import LM


# Configure logging: token counts are appended to groq_usage.log.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[logging.FileHandler("groq_usage.log")],
)


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )


class GroqLM(LM):
    """Wrapper around Groq's API.

    Args:
        model (str, optional): Groq-supported LLM model to use. Defaults to "mixtral-8x7b-32768".
        api_key (Optional[str], optional): API provider authentication token. Defaults to None.
        **kwargs: Additional arguments to pass to the API provider.
    """

    def __init__(
        self,
        model: str = "mixtral-8x7b-32768",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model)

        if api_key:
            self.api_key = api_key
            self.client = Groq(api_key=api_key)

        self.kwargs = {
            "temperature": 0.0,
            "max_tokens": 150,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "n": 1,
            **kwargs,
        }
        # Only adopt the requested model if the API actually lists it.
        models = self.client.models.list().data
        if models is not None:
            if model in [m.id for m in models]:
                self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []

    def log_usage(self, response):
        """Log the total tokens from the Groq API response."""
        # The response is a completion object, not a dict, so read attributes.
        usage_data = getattr(response, "usage", None)
        if usage_data:
            total_tokens = getattr(usage_data, "total_tokens", None)
            logging.info(f"{total_tokens}")

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}

        kwargs["messages"] = [{"role": "user", "content": prompt}]
        response = self.chat_request(**kwargs)

        # Record the exchange so it can be inspected later.
        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }

        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        groq.RateLimitError,
        max_time=1000,
        on_backoff=backoff_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of model completions while retrying on rate limits."""
        if "model_type" in kwargs:
            del kwargs["model_type"]

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice) -> str:
        return choice.message.content

    def chat_request(self, **kwargs):
        """Sends a chat completion request to the Groq API."""
        response = self.client.chat.completions.create(**kwargs)
        return response

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from the model.

        Args:
            prompt (str): prompt to send to the model
            only_completed (bool, optional): return only completed responses and ignore completions truncated by length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[dict[str, Any]]: list of completion choices
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"
        response = self.request(prompt, **kwargs)

        # Guard the lookup in case the setting is not configured.
        if getattr(dsp.settings, "log_openai_usage", False):
            self.log_usage(response)

        choices = response.choices

        completions = [self._get_choice_text(c) for c in choices]
        # Unreachable while the assert above pins return_sorted to False;
        # kept for when sorted completions are supported.
        if return_sorted and kwargs.get("n", 1) > 1:
            scored_completions = []

            for c in choices:
                tokens, logprobs = (
                    c["logprobs"]["tokens"],
                    c["logprobs"]["token_logprobs"],
                )

                if "<|endoftext|>" in tokens:
                    index = tokens.index("<|endoftext|>") + 1
                    tokens, logprobs = tokens[:index], logprobs[:index]

                avglog = sum(logprobs) / len(logprobs)
                scored_completions.append((avglog, self._get_choice_text(c)))

            scored_completions = sorted(scored_completions, reverse=True)
            completions = [c for _, c in scored_completions]

        return completions
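For reference, a minimal usage sketch of the new client (hypothetical setup: it assumes a GROQ_API_KEY environment variable holds a valid key; the model name is the class default):

import os

from dsp.modules.groq_client import GroqLM

# Hypothetical: reads the API key from the environment.
lm = GroqLM(model="mixtral-8x7b-32768", api_key=os.environ["GROQ_API_KEY"])

# __call__ returns a list of completion strings; extra kwargs override
# the defaults set in __init__ (temperature, max_tokens, and so on).
completions = lm("What is the capital of France?", max_tokens=32)
print(completions[0])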

dspy/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,10 +19,12 @@
 Pyserini = dsp.PyseriniRetriever
 Clarifai = dsp.ClarifaiLLM
 Google = dsp.Google
+GROQ = dsp.GroqLM
 
 HFClientTGI = dsp.HFClientTGI
 HFClientVLLM = HFClientVLLM
 
+
 Anyscale = dsp.Anyscale
 Together = dsp.Together
 HFModel = dsp.HFModel
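A sketch of wiring the new alias into a DSPy program (the key below is a placeholder; dspy.settings.configure is the standard way to set the default LM):

import dspy

# GROQ is the alias added above; substitute a real API key.
groq_lm = dspy.GROQ(model="mixtral-8x7b-32768", api_key="...")
dspy.settings.configure(lm=groq_lm)

# Subsequent DSPy modules route their calls through the Groq client.
qa = dspy.Predict("question -> answer")
print(qa(question="What is 2 + 2?").answer)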

poetry.lock

Lines changed: 13 additions & 13 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ python = ">=3.9,<3.12"
 pydantic = "2.5.0"
 backoff = "^2.2.1"
 joblib = "^1.3.2"
-openai = ">=0.28.1,<2.0.0"
+openai = "0.28.1"
 pandas = "^2.1.1"
 regex = "^2023.10.3"
 ujson = "^5.8.0"
