1+ import logging
2+ import os
3+ from typing import Any , Optional
4+
5+ import backoff
6+
7+ from dsp .modules .lm import LM
8+
try:
    import anthropic
    # Retry on Anthropic's rate-limit error when the SDK is available.
    anthropic_rate_limit = anthropic.RateLimitError
except ImportError:
    # Without the `anthropic` package installed, fall back to Exception so
    # the @backoff decorator below still has a valid exception type to
    # reference at import time (Claude.__init__ re-checks the dependency).
    anthropic_rate_limit = Exception


logger = logging.getLogger(__name__)

# Anthropic Messages API endpoint (default for `api_base`).
BASE_URL = "https://api.anthropic.com/v1/messages"
19+
20+
def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    # `details` is the dict the backoff library passes to on_backoff hooks;
    # it carries wait, tries, target and kwargs among other keys.
    message = (
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details)
    )
    print(message)
28+
29+
def giveup_hdlr(details):
    """Decide whether to stop retrying after an exception.

    Called by the @backoff decorator with the raised exception. Returns
    False (keep retrying) for rate-limit errors, True (give up) otherwise.

    Args:
        details: the exception instance raised by the wrapped call.

    Returns:
        bool: True to give up, False to keep backing off.
    """
    # Not every exception type exposes `.message` (plain Exception does not,
    # and the import-failure fallback above retries on bare Exception), so
    # fall back to str() instead of crashing with AttributeError here.
    message = getattr(details, "message", None) or str(details)
    if "rate limits" in message:
        return False
    return True
35+
36+
class Claude(LM):
    """Wrapper around Anthropic's Messages API. Supports both the Anthropic and Azure APIs."""

    def __init__(
        self,
        model: str = "claude-instant-1.2",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """Initialize the Claude client.

        Args:
            model: Anthropic model name.
            api_key: API key; falls back to the ANTHROPIC_API_KEY env var.
            api_base: override for the messages endpoint (defaults to BASE_URL).
            **kwargs: default generation parameters (temperature, max_tokens,
                top_p, top_k, n / num_generations, ...).

        Raises:
            ImportError: if the `anthropic` package is not installed.
        """
        super().__init__(model)

        try:
            from anthropic import Anthropic
        except ImportError as err:
            raise ImportError("Claude requires `pip install anthropic`.") from err

        self.provider = "anthropic"
        # Fall back to the environment variable when no key is supplied.
        self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
        self.api_base = BASE_URL if api_base is None else api_base

        self.kwargs = {
            "temperature": kwargs.get("temperature", 0.0),
            # The Messages API rejects max_tokens above 4096; clamp it.
            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "top_k": kwargs.get("top_k", 1),
            # "n" is a dsp-side convention for the number of generations;
            # "num_generations" is accepted as an alias. Both are popped so
            # the spread below cannot reintroduce them.
            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
            **kwargs,
        }
        self.kwargs["model"] = model
        # Chronological record of every request/response pair.
        self.history: list[dict[str, Any]] = []
        self.client = Anthropic(api_key=api_key)

    def log_usage(self, response):
        """Log the total tokens from the Anthropic API response."""
        usage_data = response.usage
        if usage_data:
            total_tokens = usage_data.input_tokens + usage_data.output_tokens
            # Lazy %-formatting: the string is only built if INFO is enabled.
            logger.info("%d", total_tokens)

    def basic_request(self, prompt: str, **kwargs):
        """Send one prompt to the Messages API and record the exchange in history.

        Args:
            prompt: user message content.
            **kwargs: per-call overrides merged over the instance defaults.

        Returns:
            The raw Anthropic Message response object.
        """
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}
        # caching mechanism requires hashable kwargs
        kwargs["messages"] = [{"role": "user", "content": prompt}]
        # "n" is dsp-internal; the Anthropic API does not accept it. Use a
        # default so a missing key cannot raise KeyError.
        kwargs.pop("n", None)
        response = self.client.messages.create(**kwargs)

        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        )

        return response

    @backoff.on_exception(
        backoff.expo,
        (anthropic_rate_limit),
        max_time=1000,
        max_tries=8,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Anthropic whilst handling API errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Retrieves completions from Anthropic.

        Args:
            prompt (str): prompt to send to Anthropic
            only_completed (bool, optional): return only completed responses and ignores completion due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[str]: list of completion choices
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # per eg here: https://docs.anthropic.com/claude/reference/messages-examples
        # max tokens can be used as a proxy to return smaller responses
        # so this cannot be a proper indicator for incomplete response unless it isnt the user-intent.
        # if only_completed and response.stop_reason != "end_turn":
        #     choices = []

        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            response = self.request(prompt, **kwargs)
            # TODO: Log llm usage instead of hardcoded openai usage
            # if dsp.settings.log_openai_usage:
            #     self.log_usage(response)
            if only_completed and response.stop_reason == "max_tokens":
                # Truncated by the token limit; treat as incomplete and skip.
                continue
            # Extend (not overwrite) so all n generations are returned, not
            # just the ones from the final request.
            completions.extend(c.text for c in response.content)
        return completions