 #!/usr/bin/env python
 # -*- coding: utf-8 -*--

-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

 import logging
 from langchain.callbacks.manager import CallbackManagerForLLMRun

 from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
-from ads.llm.langchain.plugins.contant import *
+from ads.llm.langchain.plugins.contant import Task

 logger = logging.getLogger(__name__)

@@ -32,7 +32,7 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
3232 """
3333
3434 task : str = "text_generation"
35- """Indicates the task ."""
35+ """Task can be either text_generation or text_summarization ."""
3636
3737 model : Optional [str ] = "cohere.command"
3838 """Model name to use."""
@@ -106,7 +106,7 @@ def _default_params(self) -> Dict[str, Any]:

     def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
         params = self._default_params
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
             return {**params}

         if self.stop is not None and stop is not None:
@@ -149,11 +149,7 @@ def _call(
         self._print_request(prompt, params)

         try:
-            completion = (
-                self.completion_with_retry(prompt=prompt, **params)
-                if self.task == Task.TEXT_GENERATION
-                else self.completion_with_retry(input=prompt, **params)
-            )
+            completion = self.completion_with_retry(prompt=prompt, **params)
         except Exception:
             logger.error(
                 "Error occur when invoking oci service api."
@@ -164,103 +160,95 @@ def _call(
             )
             raise

-        # completion = self._process_response(response, params.get("num_generations", 1))
-        # self._print_response(completion, response)
         return completion

-    def _process_response(self, response: Any, num_generations: int = 1) -> str:
-        if self.task == Task.SUMMARY_TEXT:
-            return response.data.summary
-
-        return (
-            response.data.generated_texts[0][0].text
-            if num_generations == 1
-            else [gen.text for gen in response.data.generated_texts[0]]
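+    # Shared helper for the Cohere and Llama generation paths: wraps the request
+    # kwargs in GenerateTextDetails, calls generate_text, and returns the
+    # model-specific inference_response.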
+    def _text_generation(self, request_class, serving_mode, **kwargs):
+        from oci.generative_ai_inference.models import (
+            GenerateTextDetails,
+            GenerateTextResult,
         )

-    def _completion_with_retry_v1(self, **kwargs: Any):
-        from oci.generative_ai.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
+        compartment_id = kwargs.pop("compartment_id")
+        inference_request = request_class(**kwargs)
+        response = self.client.generate_text(
+            GenerateTextDetails(
+                compartment_id=compartment_id,
+                serving_mode=serving_mode,
+                inference_request=inference_request,
+            ),
+            **self.endpoint_kwargs,
+        ).data
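+        # The bare annotation below is a type hint for readers and IDEs only; it has no runtime effect.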
+        response: GenerateTextResult
+        return response.inference_response
+
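+    # Cohere text generation: results come back in inference_response.generated_texts.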
+    def _cohere_completion(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import (
+            CohereLlmInferenceRequest,
+            CohereLlmInferenceResponse,
         )

-        # TODO: Add retry logic for OCI
-        # Convert the ``model`` parameter to OCI ``ServingMode``
-        # Note that "ServingMode` is not JSON serializable.
-        kwargs["prompts"] = [kwargs.pop("prompt")]
-        kwargs["serving_mode"] = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            response = self.client.generate_text(
-                GenerateTextDetails(**kwargs), **self.endpoint_kwargs
-            )
-            if kwargs.get("num_generations", 1) == 1:
-                completion = response.data.generated_texts[0][0].text
-            else:
-                completion = [gen.text for gen in response.data.generated_texts[0]]
+        response = self._text_generation(
+            CohereLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: CohereLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.generated_texts[0].text
         else:
-            response = self.client.summarize_text(
-                SummarizeTextDetails(**kwargs), **self.endpoint_kwargs
-            )
-            completion = response.data.summary
+            completion = [result.text for result in response.generated_texts]
         self._print_response(completion, response)
         return completion

-    def _completion_with_retry_v2(self, **kwargs: Any):
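+    # Llama models return generated text as "choices" rather than "generated_texts".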
+    def _llama_completion(self, serving_mode, **kwargs) -> str:
         from oci.generative_ai_inference.models import (
-            GenerateTextDetails,
-            OnDemandServingMode,
-            SummarizeTextDetails,
-            CohereLlmInferenceRequest,
             LlamaLlmInferenceRequest,
+            LlamaLlmInferenceResponse,
         )

-        request_class_mapping = {
-            "cohere": CohereLlmInferenceRequest,
-            "llama": LlamaLlmInferenceRequest,
-        }
+        # truncate and stop_sequences are not supported by the Llama inference request.
+        kwargs.pop("truncate", None)
+        kwargs.pop("stop_sequences", None)
+        # top_k must be >1 or -1, so drop it when it is 0.
+        if "top_k" in kwargs and kwargs["top_k"] == 0:
+            kwargs.pop("top_k")

-        request_class = None
-        for prefix, oci_request_class in request_class_mapping.items():
-            if self.model.startswith(prefix):
-                request_class = oci_request_class
-        if not request_class:
-            raise ValueError(f"Model {self.model} is not supported.")
-
-        if self.model.startswith("llama"):
-            kwargs.pop("truncate", None)
-            kwargs.pop("stop_sequences", None)
-
-        serving_mode = OnDemandServingMode(model_id=self.model)
-        if self.task == Task.TEXT_GENERATION:
-            compartment_id = kwargs.pop("compartment_id")
-            inference_request = request_class(**kwargs)
-            response = self.client.generate_text(
-                GenerateTextDetails(
-                    compartment_id=compartment_id,
-                    serving_mode=serving_mode,
-                    inference_request=inference_request,
-                ),
-                **self.endpoint_kwargs,
-            )
-            if kwargs.get("num_generations", 1) == 1:
-                completion = response.data.inference_response.generated_texts[0].text
-            else:
-                completion = [gen.text for gen in response.data.generated_texts]
+        # top_p must be 1 when temperature is 0
+        if kwargs.get("temperature") == 0:
+            kwargs["top_p"] = 1

+        response = self._text_generation(
+            LlamaLlmInferenceRequest, serving_mode, **kwargs
+        )
+        response: LlamaLlmInferenceResponse
+        if kwargs.get("num_generations", 1) == 1:
+            completion = response.choices[0].text
         else:
-            response = self.client.summarize_text(
-                SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
-                **self.endpoint_kwargs,
-            )
-            completion = response.data.summary
+            completion = [result.text for result in response.choices]
         self._print_response(completion, response)
         return completion

+    def _cohere_summarize(self, serving_mode, **kwargs) -> str:
+        from oci.generative_ai_inference.models import SummarizeTextDetails
+
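+        # SummarizeTextDetails expects the text under "input", so rename the "prompt" kwarg.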
+        kwargs["input"] = kwargs.pop("prompt")
+
+        response = self.client.summarize_text(
+            SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
+            **self.endpoint_kwargs,
+        )
+        return response.data.summary
+
     def completion_with_retry(self, **kwargs: Any) -> Any:
-        if self.client.__class__.__name__ == "GenerativeAiClient":
-            return self._completion_with_retry_v1(**kwargs)
-        return self._completion_with_retry_v2(**kwargs)
+        from oci.generative_ai_inference.models import OnDemandServingMode
+
+        serving_mode = OnDemandServingMode(model_id=self.model)
+
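+        # Summarization always routes to the Cohere summarize endpoint;
+        # text generation dispatches on the model family prefix.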
+        if self.task == Task.TEXT_SUMMARIZATION:
+            return self._cohere_summarize(serving_mode, **kwargs)
+        elif self.model.startswith("cohere"):
+            return self._cohere_completion(serving_mode, **kwargs)
+        elif self.model.startswith("meta.llama"):
+            return self._llama_completion(serving_mode, **kwargs)
+        raise ValueError(f"Model {self.model} is not supported.")

     def batch_completion(
         self,
@@ -299,9 +287,9 @@ def batch_completion(
             responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)

         """
-        if self.task == Task.SUMMARY_TEXT:
+        if self.task == Task.TEXT_SUMMARIZATION:
             raise NotImplementedError(
-                f"task={Task.SUMMARY_TEXT} does not support batch_completion. "
+                f"task={Task.TEXT_SUMMARIZATION} does not support batch_completion. "
             )

         return self._call(