11#!/usr/bin/env python
2- # -*- coding: utf-8 -*--
32
4- # Copyright (c) 2023 Oracle and/or its affiliates.
3+ # Copyright (c) 2024 Oracle and/or its affiliates.
54# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65"""Chat model for OCI data science model deployment endpoint."""
76
5049)
5150
5251logger = logging .getLogger (__name__ )
52+ DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
5353
5454
5555def _is_pydantic_class (obj : Any ) -> bool :
@@ -93,6 +93,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
9393 Key init args — client params:
9494 auth: dict
9595 ADS auth dictionary for OCI authentication.
96+ headers: Optional[Dict]
97+ The headers to be added to the Model Deployment request.
9698
9799 Instantiate:
98100 .. code-block:: python
@@ -109,6 +111,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
109111 "temperature": 0.2,
110112 # other model parameters ...
111113 },
114+ headers={
115+ "route": "/v1/chat/completions",
116+ # other request headers ...
117+ },
112118 )
113119
114120 Invocation:
@@ -257,6 +263,9 @@ def _construct_json_body(self, messages: list, params: dict) -> dict:
257263 """Stop words to use when generating. Model output is cut off
258264 at the first occurrence of any of these substrings."""
259265
266+ headers : Optional [Dict [str , Any ]] = {"route" : DEFAULT_INFERENCE_ENDPOINT_CHAT }
267+ """The headers to be added to the Model Deployment request."""
268+
260269 @model_validator (mode = "before" )
261270 @classmethod
262271 def validate_openai (cls , values : Any ) -> Any :
@@ -704,7 +713,7 @@ def _process_response(self, response_json: dict) -> ChatResult:
704713
705714 for choice in choices :
706715 message = _convert_dict_to_message (choice ["message" ])
707- generation_info = dict ( finish_reason = choice .get ("finish_reason" ))
716+ generation_info = {"finish_reason" : choice .get ("finish_reason" )}
708717 if "logprobs" in choice :
709718 generation_info ["logprobs" ] = choice ["logprobs" ]
710719
@@ -794,7 +803,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
794803 """Number of most likely tokens to consider at each step."""
795804
796805 min_p : Optional [float ] = 0.0
797- """Float that represents the minimum probability for a token to be considered.
806+ """Float that represents the minimum probability for a token to be considered.
798807 Must be in [0,1]. 0 to disable this."""
799808
800809 repetition_penalty : Optional [float ] = 1.0
@@ -818,7 +827,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
818827 the EOS token is generated."""
819828
820829 min_tokens : Optional [int ] = 0
821- """Minimum number of tokens to generate per output sequence before
830+ """Minimum number of tokens to generate per output sequence before
822831 EOS or stop_token_ids can be generated."""
823832
824833 stop_token_ids : Optional [List [int ]] = None
@@ -836,7 +845,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
836845 tool_choice : Optional [str ] = None
837846 """Whether to use tool calling.
838847 Defaults to None, tool calling is disabled.
839- Tool calling requires model support and the vLLM to be configured
848+ Tool calling requires model support and vLLM to be configured
840849 with `--tool-call-parser`.
841850 Set this to `auto` for the model to make tool calls automatically.
842851 Set this to `required` to force the model to always call one or more tools.
@@ -956,9 +965,9 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
956965 """Total probability mass of tokens to consider at each step."""
957966
958967 top_logprobs : Optional [int ] = None
959- """An integer between 0 and 5 specifying the number of most
960- likely tokens to return at each token position, each with an
961- associated log probability. logprobs must be set to true if
968+ """An integer between 0 and 5 specifying the number of most
969+ likely tokens to return at each token position, each with an
970+ associated log probability. logprobs must be set to true if
962971 this parameter is used."""
963972
964973 @property
0 commit comments