|
8 | 8 | from typing import Any, Dict, List, Optional |
9 | 9 |
|
10 | 10 | import requests |
11 | | -from oci.auth import signers |
12 | 11 | from langchain.callbacks.manager import CallbackManagerForLLMRun |
| 12 | +from langchain.pydantic_v1 import root_validator |
| 13 | +from langchain.utils import get_from_dict_or_env |
| 14 | +from oci.auth import signers |
13 | 15 |
|
14 | 16 | from ads.llm.langchain.plugins.base import BaseLLM |
15 | 17 | from ads.llm.langchain.plugins.contant import ( |
|
23 | 25 | class ModelDeploymentLLM(BaseLLM): |
24 | 26 | """Base class for LLM deployed on OCI Model Deployment.""" |
25 | 27 |
|
26 | | - endpoint: str |
| 28 | + endpoint: str = "" |
27 | 29 | """The uri of the endpoint from the deployed Model Deployment model.""" |
28 | 30 |
|
29 | 31 | best_of: int = 1 |
30 | 32 | """Generates best_of completions server-side and returns the "best" |
31 | 33 | (the one with the highest log probability per token). |
32 | 34 | """ |
33 | 35 |
|
    @root_validator()
    def validate_environment(  # pylint: disable=no-self-argument
        cls, values: Dict
    ) -> Dict:
        """Fetch endpoint from environment variable or arguments.

        Resolves the model-deployment endpoint from, in order of precedence,
        the ``endpoint`` keyword argument and the ``OCI_LLM_ENDPOINT``
        environment variable.

        Parameters
        ----------
        values : Dict
            Raw field values collected by pydantic before model creation.

        Returns
        -------
        Dict
            The same ``values`` dict with ``endpoint`` populated.
        """
        # NOTE(review): get_from_dict_or_env raises when the endpoint is found
        # in neither the arguments nor the environment — presumably intentional,
        # making the endpoint effectively required even though the field
        # declares a "" default. Confirm this is the desired contract.
        values["endpoint"] = get_from_dict_or_env(
            values,
            "endpoint",
            "OCI_LLM_ENDPOINT",
        )
        return values
34 | 48 | @property |
35 | 49 | def _default_params(self) -> Dict[str, Any]: |
36 | 50 | """Default parameters for the model.""" |
@@ -73,7 +87,7 @@ def _call( |
73 | 87 | run_manager: Optional[CallbackManagerForLLMRun] = None, |
74 | 88 | **kwargs: Any, |
75 | 89 | ) -> str: |
76 | | - """Call out to OCI Data Science Model Deployment TGI endpoint. |
| 90 | + """Call out to OCI Data Science Model Deployment endpoint. |
77 | 91 |
|
78 | 92 | Parameters |
79 | 93 | ---------- |
@@ -203,8 +217,11 @@ class ModelDeploymentTGI(ModelDeploymentLLM): |
203 | 217 | """ |
204 | 218 |
|
205 | 219 | watermark = True |
| 220 | + """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_. |
| 221 | + Defaults to True.""" |
206 | 222 |
|
207 | 223 | return_full_text = False |
| 224 | + """Whether to prepend the prompt to the generated text. Defaults to False.""" |
208 | 225 |
|
209 | 226 | @property |
210 | 227 | def _llm_type(self) -> str: |
@@ -241,6 +258,7 @@ class ModelDeploymentVLLM(ModelDeploymentLLM): |
241 | 258 | """VLLM deployed on OCI Model Deployment""" |
242 | 259 |
|
243 | 260 | model: str |
| 261 | + """Name of the model.""" |
244 | 262 |
|
245 | 263 | n: int = 1 |
246 | 264 | """Number of output sequences to return for the given prompt.""" |
|
0 commit comments