Commit 974e4ec

Merge pull request #562 from skucherlapati/claude_3
Adds native support for Claude models
2 parents edadecd + 805ed48

4 files changed (+179, −1 lines)

dsp/modules/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+from .anthropic import Claude
 from .azure_openai import AzureOpenAI
 from .bedrock import *
 from .cache_utils import *
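
With this re-export, the new wrapper is importable directly as `from dsp.modules import Claude`, without going through the `dsp.modules.anthropic` submodule.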

dsp/modules/anthropic.py

Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
import logging
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
    import anthropic
    anthropic_rate_limit = anthropic.RateLimitError
except ImportError:
    anthropic_rate_limit = Exception


logger = logging.getLogger(__name__)

BASE_URL = "https://api.anthropic.com/v1/messages"


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/."""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )


def giveup_hdlr(details):
    """Decide when to give up on retries; backoff passes the raised exception here."""
    # Keep retrying only when the error message mentions rate limits.
    if "rate limits" in details.message:
        return False
    return True


class Claude(LM):
    """Wrapper around Anthropic's API."""

    def __init__(
        self,
        model: str = "claude-instant-1.2",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model)

        try:
            from anthropic import Anthropic
        except ImportError as err:
            raise ImportError("Claude requires `pip install anthropic`.") from err

        self.provider = "anthropic"
        self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
        self.api_base = BASE_URL if api_base is None else api_base

        self.kwargs = {
            "temperature": kwargs.get("temperature", 0.0),
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "top_k": kwargs.get("top_k", 1),
            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
            **kwargs,
        }
        self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []
        self.client = Anthropic(api_key=api_key)

    def log_usage(self, response):
        """Log the total tokens from the Anthropic API response."""
        usage_data = response.usage
        if usage_data:
            total_tokens = usage_data.input_tokens + usage_data.output_tokens
            logger.info(f"{total_tokens}")

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}
        # caching mechanism requires hashable kwargs
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # The Messages API has no `n` parameter; multiple generations are
        # handled by looping in __call__.
        kwargs.pop("n")
        response = self.client.messages.create(**kwargs)

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        (anthropic_rate_limit),
        max_time=1000,
        max_tries=8,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Retrieves completions from Anthropic, retrying rate-limit errors with exponential backoff."""
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Retrieves completions from Anthropic.

        Args:
            prompt (str): prompt to send to Anthropic
            only_completed (bool, optional): return only completed responses,
                ignoring completions truncated due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the
                returned probabilities. Defaults to False.

        Returns:
            list[str]: list of completion choices
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # Per https://docs.anthropic.com/claude/reference/messages-examples,
        # max_tokens can be used deliberately to request shorter responses,
        # so stop_reason == "max_tokens" is not a reliable indicator of an
        # incomplete response unless truncation was not the user's intent.
        # if only_completed and response.stop_reason != "end_turn":
        #     choices = []

        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            response = self.request(prompt, **kwargs)
            # TODO: Log LLM usage instead of hardcoded OpenAI usage
            # if dsp.settings.log_openai_usage:
            #     self.log_usage(response)
            if only_completed and response.stop_reason == "max_tokens":
                continue
            # Accumulate across the n calls instead of overwriting each pass.
            completions.extend(c.text for c in response.content)
        return completions

poetry.lock

Lines changed: 36 additions & 1 deletion
Generated lockfile; diff not rendered.

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
+anthropic = ["anthropic~=0.18.0"]
 chromadb = ["chromadb~=0.4.14"]
 qdrant = ["qdrant-client~=1.6.2", "fastembed~=0.1.0"]
 marqo = ["marqo"]

@@ -84,6 +85,7 @@ tqdm = "^4.66.1"
 datasets = "^2.14.6"
 requests = "^2.31.0"
 optuna = "^3.4.0"
+anthropic = { version = "^0.18.0", optional = true }
 chromadb = { version = "^0.4.14", optional = true }
 fastembed = { version = "^0.1.0", optional = true }
 marqo = { version = "*", optional = true }
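
With this extra declared, the SDK can be pulled in on demand, e.g. `pip install "dspy-ai[anthropic]"` (assuming the published package name) or `poetry install --extras anthropic` from a checkout.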
