@@ -149,8 +149,8 @@ def _call(
149149 self ._print_request (prompt , params )
150150
151151 try :
152- response = (
153- self .completion_with_retry (prompts = [ prompt ] , ** params )
152+ completion = (
153+ self .completion_with_retry (prompt = prompt , ** params )
154154 if self .task == Task .TEXT_GENERATION
155155 else self .completion_with_retry (input = prompt , ** params )
156156 )
@@ -164,8 +164,8 @@ def _call(
164164 )
165165 raise
166166
167- completion = self ._process_response (response , params .get ("num_generations" , 1 ))
168- self ._print_response (completion , response )
167+ # completion = self._process_response(response, params.get("num_generations", 1))
168+ # self._print_response(completion, response)
169169 return completion
170170
171171 def _process_response (self , response : Any , num_generations : int = 1 ) -> str :
@@ -178,7 +178,7 @@ def _process_response(self, response: Any, num_generations: int = 1) -> str:
178178 else [gen .text for gen in response .data .generated_texts [0 ]]
179179 )
180180
def _completion_with_retry_v1(self, **kwargs: Any):
    """Call the legacy ``GenerativeAiClient`` (v1) service.

    Dispatches on ``self.task``: ``Task.TEXT_GENERATION`` goes through
    ``generate_text``; any other task goes through ``summarize_text``.

    Returns:
        For text generation: the single generated string when
        ``num_generations`` is 1 (the default), otherwise a list of
        generated strings. For summarization: the summary string.
    """
    from oci.generative_ai.models import (
        GenerateTextDetails,
        OnDemandServingMode,
        SummarizeTextDetails,
    )

    # TODO: Add retry logic for OCI
    # Convert the ``model`` parameter to OCI ``ServingMode``.
    # Note that ``ServingMode`` is not JSON serializable.
    kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
    if self.task == Task.TEXT_GENERATION:
        # BUG FIX: the prompt -> prompts list conversion must happen only
        # on the generation path. The summarize path is invoked with
        # ``input=prompt`` (no ``prompt`` key), so doing the pop
        # unconditionally raised KeyError for summarization requests.
        kwargs["prompts"] = [kwargs.pop("prompt")]
        response = self.client.generate_text(
            GenerateTextDetails(**kwargs), **self.endpoint_kwargs
        )
        if kwargs.get("num_generations", 1) == 1:
            completion = response.data.generated_texts[0][0].text
        else:
            completion = [gen.text for gen in response.data.generated_texts[0]]
    else:
        response = self.client.summarize_text(
            SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
        )
        completion = response.data.summary
    self._print_response(completion, response)
    return completion
208+
def _completion_with_retry_v2(self, **kwargs: Any):
    """Call the ``GenerativeAiInferenceClient`` (v2) service.

    Selects the vendor-specific inference-request class from the model
    name prefix ("cohere" or "llama"); for llama models, Cohere-only
    parameters are dropped before building the request.

    Returns:
        For text generation: the single generated string when
        ``num_generations`` is 1 (the default), otherwise a list of
        generated strings. For summarization: the summary string.

    Raises:
        ValueError: if ``self.model`` matches no supported prefix.
    """
    from oci.generative_ai_inference.models import (
        CohereLlmInferenceRequest,
        GenerateTextDetails,
        LlamaLlmInferenceRequest,
        OnDemandServingMode,
        SummarizeTextDetails,
    )

    request_class_mapping = {
        "cohere": CohereLlmInferenceRequest,
        "llama": LlamaLlmInferenceRequest,
    }

    request_class = None
    for prefix, oci_request_class in request_class_mapping.items():
        if self.model.startswith(prefix):
            request_class = oci_request_class
            break  # first matching prefix wins; no need to keep scanning
    if not request_class:
        raise ValueError(f"Model {self.model} is not supported.")

    if self.model.startswith("llama"):
        # Llama inference requests do not accept these Cohere-only params.
        kwargs.pop("truncate", None)
        kwargs.pop("stop_sequences", None)

    serving_mode = OnDemandServingMode(model_id=self.model)
    if self.task == Task.TEXT_GENERATION:
        compartment_id = kwargs.pop("compartment_id")
        inference_request = request_class(**kwargs)
        response = self.client.generate_text(
            GenerateTextDetails(
                compartment_id=compartment_id,
                serving_mode=serving_mode,
                inference_request=inference_request,
            ),
            **self.endpoint_kwargs,
        )
        # BUG FIX: the v2 response nests results under
        # ``inference_response``; the multi-generation branch previously
        # read ``response.data.generated_texts`` (the v1 shape), which
        # does not exist here. Both branches now use the same container.
        generated = response.data.inference_response.generated_texts
        if kwargs.get("num_generations", 1) == 1:
            completion = generated[0].text
        else:
            completion = [gen.text for gen in generated]
    else:
        response = self.client.summarize_text(
            SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
            **self.endpoint_kwargs,
        )
        completion = response.data.summary
    self._print_response(completion, response)
    return completion
259+
def completion_with_retry(self, **kwargs: Any) -> Any:
    """Route a completion request to the matching client implementation.

    The legacy ``GenerativeAiClient`` is served by the v1 path; any other
    client type is assumed to be the newer inference client (v2).
    """
    is_legacy_client = type(self.client).__name__ == "GenerativeAiClient"
    handler = (
        self._completion_with_retry_v1
        if is_legacy_client
        else self._completion_with_retry_v2
    )
    return handler(**kwargs)
200264
201265 def batch_completion (
202266 self ,
0 commit comments