
 import backoff
 import openai
-import openai.error
-from openai.openai_object import OpenAIObject

 from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
 from dsp.modules.lm import LM

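+# openai>=1.0 reorganized the SDK; detect the installed major version so that
+# both the legacy (0.x) and current (1.x) APIs are supported below.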
+try:
+    OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
+except Exception:
+    OPENAI_LEGACY = True
+
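+# The retryable error classes live in openai.error on 0.x; on 1.x they sit at
+# the package top level and ServiceUnavailableError no longer exists.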
+try:
+    from openai.openai_object import OpenAIObject
+    import openai.error
+    ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
+except Exception:
+    ERRORS = (openai.RateLimitError, openai.APIError)
+    OpenAIObject = dict
+

 def backoff_hdlr(details):
     """Handler from https://pypi.org/project/backoff/"""
@@ -36,13 +47,19 @@ def __init__(
         model: str = "gpt-3.5-turbo-instruct",
         api_key: Optional[str] = None,
         api_provider: Literal["openai", "azure"] = "openai",
+        api_base: Optional[str] = None,
         model_type: Literal["chat", "text"] = None,
         **kwargs,
     ):
         super().__init__(model)
         self.provider = "openai"

-        default_model_type = "chat" if ('gpt-3.5' in model or 'turbo' in model or 'gpt-4' in model) and ('instruct' not in model) else "text"
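+        # Name-based heuristic: gpt-3.5/gpt-4/turbo models default to the
+        # chat endpoint, except the *-instruct completions models.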
+        default_model_type = (
+            "chat"
+            if ("gpt-3.5" in model or "turbo" in model or "gpt-4" in model)
+            and ("instruct" not in model)
+            else "text"
+        )
         self.model_type = model_type if model_type else default_model_type

         if api_provider == "azure":
@@ -58,8 +75,8 @@ def __init__(
         if api_key:
             openai.api_key = api_key

-        if kwargs.get("api_base"):
-            openai.api_base = kwargs["api_base"]
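+        # openai>=1.0 renamed the module-level api_base setting to base_url.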
+        if api_base:
+            openai.base_url = api_base

         self.kwargs = {
             "temperature": 0.0,
@@ -70,29 +87,27 @@ def __init__(
7087 "n" : 1 ,
7188 ** kwargs ,
7289 } # TODO: add kwargs above for </s>
73-
90+
7491 if api_provider != "azure" :
7592 self .kwargs ["model" ] = model
7693 self .history : list [dict [str , Any ]] = []
7794
78- def _openai_client ():
95+ def _openai_client (self ):
7996 return openai
8097
81- def basic_request (self , prompt : str , ** kwargs ) -> OpenAIObject :
98+ def basic_request (self , prompt : str , ** kwargs ):
8299 raw_kwargs = kwargs
83100
84101 kwargs = {** self .kwargs , ** kwargs }
85102 if self .model_type == "chat" :
86103 # caching mechanism requires hashable kwargs
87104 kwargs ["messages" ] = [{"role" : "user" , "content" : prompt }]
-            kwargs = {
-                "stringify_request": json.dumps(kwargs)
-            }
-            response = cached_gpt3_turbo_request(**kwargs)
-
+            kwargs = {"stringify_request": json.dumps(kwargs)}
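+            # dispatch to the cached chat helper matching the installed SDK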
+            response = chat_request(**kwargs)
+
         else:
             kwargs["prompt"] = prompt
-            response = cached_gpt3_request(**kwargs)
+            response = completions_request(**kwargs)

         history = {
             "prompt": prompt,
@@ -106,15 +121,15 @@ def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:

     @backoff.on_exception(
         backoff.expo,
-        (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError),
+        ERRORS,
         max_time=1000,
         on_backoff=backoff_hdlr,
     )
-    def request(self, prompt: str, **kwargs) -> OpenAIObject:
+    def request(self, prompt: str, **kwargs):
         """Retrieves GPT-3 completions, handling rate limiting and caching."""
         if "model_type" in kwargs:
             del kwargs["model_type"]
-
+
         return self.basic_request(prompt, **kwargs)

     def _get_choice_text(self, choice: dict[str, Any]) -> str:
@@ -150,6 +165,7 @@ def __call__(
         # kwargs = {**kwargs, "logprobs": 5}

         response = self.request(prompt, **kwargs)
+
         choices = response["choices"]

         completed_choices = [c for c in choices if c["finish_reason"] != "length"]
@@ -158,7 +174,6 @@ def __call__(
             choices = completed_choices

         completions = [self._get_choice_text(c) for c in choices]
-
         if return_sorted and kwargs.get("n", 1) > 1:
             scored_completions = []

@@ -181,31 +196,57 @@ def __call__(
         return completions


+
 @CacheMemory.cache
 def cached_gpt3_request_v2(**kwargs):
     return openai.Completion.create(**kwargs)

-
 @functools.lru_cache(maxsize=None if cache_turn_on else 0)
 @NotebookCacheMemory.cache
 def cached_gpt3_request_v2_wrapped(**kwargs):
     return cached_gpt3_request_v2(**kwargs)

-
-cached_gpt3_request = cached_gpt3_request_v2_wrapped
-
-
 @CacheMemory.cache
 def _cached_gpt3_turbo_request_v2(**kwargs) -> OpenAIObject:
     if "stringify_request" in kwargs:
         kwargs = json.loads(kwargs["stringify_request"])
     return cast(OpenAIObject, openai.ChatCompletion.create(**kwargs))

-
 @functools.lru_cache(maxsize=None if cache_turn_on else 0)
 @NotebookCacheMemory.cache
 def _cached_gpt3_turbo_request_v2_wrapped(**kwargs) -> OpenAIObject:
     return _cached_gpt3_turbo_request_v2(**kwargs)

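+# Equivalents for openai>=1.0, which replaced openai.Completion.create with
+# openai.completions.create and openai.ChatCompletion.create with
+# openai.chat.completions.create.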
+@CacheMemory.cache
+def v1_cached_gpt3_request_v2(**kwargs):
+    return openai.completions.create(**kwargs)
+
+@functools.lru_cache(maxsize=None if cache_turn_on else 0)
+@NotebookCacheMemory.cache
+def v1_cached_gpt3_request_v2_wrapped(**kwargs):
+    return v1_cached_gpt3_request_v2(**kwargs)
+
+@CacheMemory.cache
+def v1_cached_gpt3_turbo_request_v2(**kwargs):
+    if "stringify_request" in kwargs:
+        kwargs = json.loads(kwargs["stringify_request"])
+    return openai.chat.completions.create(**kwargs)
+
+@functools.lru_cache(maxsize=None if cache_turn_on else 0)
+@NotebookCacheMemory.cache
+def v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs):
+    return v1_cached_gpt3_turbo_request_v2(**kwargs)
+
+
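+# Public entry points: route each request to the helper that matches the
+# installed SDK, normalizing the return value to a dict-like object.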
+def chat_request(**kwargs):
+    if OPENAI_LEGACY:
+        return _cached_gpt3_turbo_request_v2_wrapped(**kwargs)
+
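+    # A v1 response is a pydantic model; model_dump() converts it to a plain
+    # dict so callers can keep indexing response["choices"].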
+    return v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs).model_dump()
+
+def completions_request(**kwargs):
+    if OPENAI_LEGACY:
+        return cached_gpt3_request_v2_wrapped(**kwargs)

-cached_gpt3_turbo_request = _cached_gpt3_turbo_request_v2_wrapped
+    return v1_cached_gpt3_request_v2_wrapped(**kwargs).model_dump()
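
A minimal usage sketch, assuming this file defines dsp's GPT3 LM wrapper (the class name sits outside the hunks shown); the model, key, and prompt are placeholders:

    from dsp.modules.gpt3 import GPT3

    # chat_request/completions_request dispatch on OPENAI_LEGACY, so the same
    # call works against both openai<1.0 and openai>=1.0.
    lm = GPT3(model="gpt-3.5-turbo", api_key="sk-...")
    completions = lm("Say hello in one word.")  # __call__ returns a list of strings
    print(completions[0])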