@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-

-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


@@ -24,6 +23,7 @@

 import aiohttp
 import requests
+from langchain_community.utilities.requests import Requests
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -34,14 +34,13 @@
 from langchain_core.utils import get_from_dict_or_env
 from pydantic import Field, model_validator

-from langchain_community.utilities.requests import Requests
-
 logger = logging.getLogger(__name__)


 DEFAULT_TIME_OUT = 300
 DEFAULT_CONTENT_TYPE_JSON = "application/json"
 DEFAULT_MODEL_NAME = "odsc-llm"
+DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"


 class TokenExpiredError(Exception):
@@ -86,6 +85,9 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""

+    default_headers: Optional[Dict[str, Any]] = None
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
@@ -101,7 +103,7 @@ def validate_environment(cls, values: Dict) -> Dict:
                 "Please install it with `pip install oracle_ads`."
             ) from ex

-        if not values.get("auth", None):
+        if not values.get("auth"):
             values["auth"] = ads.common.auth.default_signer()

         values["endpoint"] = get_from_dict_or_env(
@@ -125,12 +127,12 @@ def _headers(
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
+        headers = self.default_headers or {}
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
             req = _req.prepare()
             req = signer(req)
-            headers = {}
             for key, value in req.headers.items():
                 headers[key] = value

@@ -140,7 +142,7 @@ def _headers(
                 )
             return headers

-        return (
+        headers.update(
             {
                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                 "enable-streaming": "true",
@@ -152,6 +154,8 @@ def _headers(
             }
         )

+        return headers
+
     def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
@@ -357,7 +361,7 @@ def _refresh_signer(self) -> bool:
             self.auth["signer"].refresh_security_token()
             return True
         return False
-
+
     @classmethod
     def is_lc_serializable(cls) -> bool:
         """Return whether this model can be serialized by LangChain."""
@@ -388,6 +392,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
                 model="odsc-llm",
                 streaming=True,
                 model_kwargs={"frequency_penalty": 1.0},
+                headers={
+                    "route": "/v1/completions",
+                    # other request headers ...
+                }
             )
             llm.invoke("tell me a joke.")

@@ -477,6 +485,25 @@ def _identifying_params(self) -> Dict[str, Any]:
             **self._default_params,
         }

+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         prompts: List[str],
@@ -712,9 +739,9 @@ def _process_response(self, response_json: dict) -> List[Generation]:
     def _generate_info(self, choice: dict) -> Any:
         """Extracts generation info from the response."""
         gen_info = {}
-        finish_reason = choice.get("finish_reason", None)
-        logprobs = choice.get("logprobs", None)
-        index = choice.get("index", None)
+        finish_reason = choice.get("finish_reason")
+        logprobs = choice.get("logprobs")
+        index = choice.get("index")
         if finish_reason:
             gen_info.update({"finish_reason": finish_reason})
         if logprobs is not None:
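
Net effect of the changes above: BaseOCIModelDeployment._headers now seeds the header dict from the new default_headers field, and the OCIModelDeploymentLLM override adds a "route" header that falls back to DEFAULT_INFERENCE_ENDPOINT ("/v1/completions"). Below is a minimal, self-contained sketch of that merge order; the two helper functions are illustrative stand-ins for those methods under the assumption of a plain (non-async, non-streaming) request, and they omit the signer, auth, and streaming handling of the real code.

DEFAULT_CONTENT_TYPE_JSON = "application/json"
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"


def base_headers(default_headers=None):
    # Stand-in for BaseOCIModelDeployment._headers (non-async path):
    # start from default_headers, then layer the standard fields on top.
    headers = dict(default_headers or {})
    headers.update({"Content-Type": DEFAULT_CONTENT_TYPE_JSON})
    return headers


def llm_headers(default_headers=None):
    # Stand-in for the OCIModelDeploymentLLM._headers override: "route"
    # defaults to /v1/completions, but whatever the base class returns
    # (including a "route" supplied via default_headers) overrides it,
    # because the **-unpacking comes after the default entry.
    return {"route": DEFAULT_INFERENCE_ENDPOINT, **base_headers(default_headers)}


print(llm_headers())
# {'route': '/v1/completions', 'Content-Type': 'application/json'}
print(llm_headers({"route": "/v1/chat/completions"}))
# user-supplied route wins: {'route': '/v1/chat/completions', 'Content-Type': 'application/json'}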