33
44# Copyright (c) 2023 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+ """Chat model for OCI data science model deployment endpoint."""
67
7-
8+ import importlib
89import json
910import logging
1011from operator import itemgetter
1112from typing import (
1213 Any ,
1314 AsyncIterator ,
15+ Callable ,
1416 Dict ,
1517 Iterator ,
1618 List ,
1719 Literal ,
1820 Optional ,
21+ Sequence ,
1922 Type ,
2023 Union ,
21- Sequence ,
22- Callable ,
2324)
2425
2526from langchain_core .callbacks import (
3334 generate_from_stream ,
3435)
3536from langchain_core .messages import AIMessageChunk , BaseMessage , BaseMessageChunk
36- from langchain_core .tools import BaseTool
3737from langchain_core .output_parsers import (
3838 JsonOutputParser ,
3939 PydanticOutputParser ,
4040)
4141from langchain_core .outputs import ChatGeneration , ChatGenerationChunk , ChatResult
4242from langchain_core .runnables import Runnable , RunnableMap , RunnablePassthrough
43+ from langchain_core .tools import BaseTool
4344from langchain_core .utils .function_calling import convert_to_openai_tool
44- from langchain_openai .chat_models .base import (
45- _convert_delta_to_message_chunk ,
46- _convert_message_to_dict ,
47- _convert_dict_to_message ,
48- )
45+ from pydantic import BaseModel , Field , model_validator
4946
50- from pydantic import BaseModel , Field
5147from ads .llm .langchain .plugins .llms .oci_data_science_model_deployment_endpoint import (
5248 DEFAULT_MODEL_NAME ,
5349 BaseOCIModelDeployment ,
@@ -63,23 +59,48 @@ def _is_pydantic_class(obj: Any) -> bool:
6359class ChatOCIModelDeployment (BaseChatModel , BaseOCIModelDeployment ):
6460 """OCI Data Science Model Deployment chat model integration.
6561
66- To use, you must provide the model HTTP endpoint from your deployed
67- chat model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict .
62+ Setup:
63+ Install ``oracle-ads`` and ``langchain-openai`` .
6864
69- To authenticate, `oracle-ads` has been used to automatically load
70- credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
65+ .. code-block:: bash
7166
72- Make sure to have the required policies to access the OCI Data
73- Science Model Deployment endpoint. See:
74- https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
67+ pip install -U oracle-ads langchain-openai
68+
69+ Use `ads.set_auth()` to configure authentication.
70+ For example, to use OCI resource_principal for authentication:
71+
72+ .. code-block:: python
73+
74+ import ads
75+ ads.set_auth("resource_principal")
76+
77+ For more details on authentication, see:
78+ https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
79+
80+ Make sure to have the required policies to access the OCI Data
81+ Science Model Deployment endpoint. See:
82+ https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm
83+
84+
85+ Key init args — completion params:
86+ endpoint: str
87+ The OCI model deployment endpoint.
88+ temperature: float
89+ Sampling temperature.
90+ max_tokens: Optional[int]
91+ Max number of tokens to generate.
92+
93+ Key init args — client params:
94+ auth: dict
95+ ADS auth dictionary for OCI authentication.
7596
7697 Instantiate:
7798 .. code-block:: python
7899
79100 from langchain_community.chat_models import ChatOCIModelDeployment
80101
81102 chat = ChatOCIModelDeployment(
82- endpoint="https://modeldeployment.us-ashburn-1 .oci.customer-oci.com/<ocid>/predict",
103+ endpoint="https://modeldeployment.<region> .oci.customer-oci.com/<ocid>/predict",
83104 model="odsc-llm",
84105 streaming=True,
85106 max_retries=3,
@@ -94,15 +115,27 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
94115 .. code-block:: python
95116
96117 messages = [
97- ("system", "You are a helpful translator. Translate the user sentence to French."),
118+ ("system", "Translate the user sentence to French."),
98119 ("human", "Hello World!"),
99120 ]
100121 chat.invoke(messages)
101122
102123 .. code-block:: python
103124
104125 AIMessage(
105- content='Bonjour le monde!',response_metadata={'token_usage': {'prompt_tokens': 40, 'total_tokens': 50, 'completion_tokens': 10},'model_name': 'odsc-llm','system_fingerprint': '','finish_reason': 'stop'},id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0')
126+ content='Bonjour le monde!',
127+ response_metadata={
128+ 'token_usage': {
129+ 'prompt_tokens': 40,
130+ 'total_tokens': 50,
131+ 'completion_tokens': 10
132+ },
133+ 'model_name': 'odsc-llm',
134+ 'system_fingerprint': '',
135+ 'finish_reason': 'stop'
136+ },
137+ id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0'
138+ )
106139
107140 Streaming:
108141 .. code-block:: python
@@ -112,18 +145,18 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
112145
113146 .. code-block:: python
114147
115- content='' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
116- content='\n ' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
117- content='B' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
118- content='on' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
119- content='j' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
120- content='our' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
121- content=' le' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
122- content=' monde' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
123- content='!' id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
124- content='' response_metadata={'finish_reason': 'stop'} id='run-23df02c6 -c43f-42de-87c6-8ad382e125c3 '
125-
126- Asyc :
148+ content='' id='run-02c6 -c43f-42de'
149+ content='\n ' id='run-02c6 -c43f-42de'
150+ content='B' id='run-02c6 -c43f-42de'
151+ content='on' id='run-02c6 -c43f-42de'
152+ content='j' id='run-02c6 -c43f-42de'
153+ content='our' id='run-02c6 -c43f-42de'
154+ content=' le' id='run-02c6 -c43f-42de'
155+ content=' monde' id='run-02c6 -c43f-42de'
156+ content='!' id='run-02c6 -c43f-42de'
157+ content='' response_metadata={'finish_reason': 'stop'} id='run-02c6 -c43f-42de'
158+
159+ Async :
127160 .. code-block:: python
128161
129162 await chat.ainvoke(messages)
@@ -133,7 +166,11 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
133166
134167 .. code-block:: python
135168
136- AIMessage(content='Bonjour le monde!', response_metadata={'finish_reason': 'stop'}, id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0')
169+ AIMessage(
170+ content='Bonjour le monde!',
171+ response_metadata={'finish_reason': 'stop'},
172+ id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0'
173+ )
137174
138175 Structured output:
139176 .. code-block:: python
@@ -147,19 +184,22 @@ class Joke(BaseModel):
147184
148185 structured_llm = chat.with_structured_output(Joke, method="json_mode")
149186 structured_llm.invoke(
150- "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
187+ "Tell me a joke about cats, "
188+ "respond in JSON with `setup` and `punchline` keys"
151189 )
152190
153191 .. code-block:: python
154192
155- Joke(setup='Why did the cat get stuck in the tree?',punchline='Because it was chasing its tail!')
193+ Joke(
194+ setup='Why did the cat get stuck in the tree?',
195+ punchline='Because it was chasing its tail!'
196+ )
156197
157198 See ``ChatOCIModelDeployment.with_structured_output()`` for more.
158199
159200 Customized Usage:
160-
161- You can inherit from base class and overwrite the `_process_response`, `_process_stream_response`,
162- `_construct_json_body` for satisfying customized needed.
201+ You can inherit from base class and overwrite the `_process_response`,
202+ `_process_stream_response`, `_construct_json_body` for customized usage.
163203
164204 .. code-block:: python
165205
@@ -180,12 +220,31 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
180220 }
181221
182222 chat = MyChatModel(
183- endpoint=f"https://modeldeployment.us-ashburn-1 .oci.customer-oci.com/{ocid}/predict",
223+ endpoint=f"https://modeldeployment.<region> .oci.customer-oci.com/{ocid}/predict",
184224 model="odsc-llm",
185225 )
186226
187227 chat.invoke("tell me a joke")
188228
229+ Response metadata:
230+ .. code-block:: python
231+
232+ ai_msg = chat.invoke(messages)
233+ ai_msg.response_metadata
234+
235+ .. code-block:: python
236+
237+ {
238+ 'token_usage': {
239+ 'prompt_tokens': 40,
240+ 'total_tokens': 50,
241+ 'completion_tokens': 10
242+ },
243+ 'model_name': 'odsc-llm',
244+ 'system_fingerprint': '',
245+ 'finish_reason': 'stop'
246+ }
247+
189248 """ # noqa: E501
190249
191250 model_kwargs : Dict [str , Any ] = Field (default_factory = dict )
@@ -198,6 +257,17 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
198257 """Stop words to use when generating. Model output is cut off
199258 at the first occurrence of any of these substrings."""
200259
260+ @model_validator (mode = "before" )
261+ @classmethod
262+ def validate_openai (cls , values : Any ) -> Any :
263+ """Checks if langchain_openai is installed."""
264+ if not importlib .util .find_spec ("langchain_openai" ):
265+ raise ImportError (
266+ "Could not import langchain_openai package. "
267+ "Please install it with `pip install langchain_openai`."
268+ )
269+ return values
270+
201271 @property
202272 def _llm_type (self ) -> str :
203273 """Return type of llm."""
@@ -552,6 +622,8 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
552622 converted messages and additional parameters.
553623
554624 """
625+ from langchain_openai .chat_models .base import _convert_message_to_dict
626+
555627 return {
556628 "messages" : [_convert_message_to_dict (m ) for m in messages ],
557629 ** params ,
@@ -578,6 +650,8 @@ def _process_stream_response(
578650 ValueError: If the response JSON is not well-formed or does not
579651 contain the expected structure.
580652 """
653+ from langchain_openai .chat_models .base import _convert_delta_to_message_chunk
654+
581655 try :
582656 choice = response_json ["choices" ][0 ]
583657 if not isinstance (choice , dict ):
@@ -616,6 +690,8 @@ def _process_response(self, response_json: dict) -> ChatResult:
616690 contain the expected structure.
617691
618692 """
693+ from langchain_openai .chat_models .base import _convert_dict_to_message
694+
619695 generations = []
620696 try :
621697 choices = response_json ["choices" ]
@@ -760,8 +836,9 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
760836 tool_choice : Optional [str ] = None
761837 """Whether to use tool calling.
762838 Defaults to None, tool calling is disabled.
763- Tool calling requires model support and vLLM to be configured with `--tool-call-parser`.
764- Set this to `auto` for the model to determine whether to make tool calls automatically.
839+ Tool calling requires model support and vLLM to be configured
840+ with `--tool-call-parser`.
841+ Set this to `auto` for the model to make tool calls automatically.
765842 Set this to `required` to force the model to always call one or more tools.
766843 """
767844
0 commit comments