Skip to content

Commit 3b4204a

Browse files
committed
revert(llm): remove custom HTTP headers patch now in langchain-nvidia-ai-endpoints v0.3.19 (#1503)
This reverts commit aafd733.
1 parent e92cf26 commit 3b4204a

File tree

2 files changed

+5
-506
lines changed

2 files changed

+5
-506
lines changed

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import inspect
1716
import logging
1817
from functools import wraps
19-
from typing import Any, Dict, List, Optional
18+
from typing import Any, List, Optional
2019

2120
from langchain_core.callbacks.manager import (
2221
AsyncCallbackManagerForLLMRun,
@@ -29,12 +28,12 @@
2928
from langchain_core.messages import BaseMessage
3029
from langchain_core.outputs import ChatResult
3130
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
32-
from pydantic import Field
31+
from pydantic.v1 import Field
3332

34-
log = logging.getLogger(__name__) # pragma: no cover
33+
log = logging.getLogger(__name__)
3534

3635

37-
def stream_decorator(func): # pragma: no cover
36+
def stream_decorator(func):
3837
@wraps(func)
3938
def wrapper(
4039
self,
@@ -80,52 +79,10 @@ async def wrapper(
8079

8180
# NOTE: this needs to have the same name as the original class,
8281
# otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
83-
class ChatNVIDIA(ChatNVIDIAOriginal): # pragma: no cover
82+
class ChatNVIDIA(ChatNVIDIAOriginal):
8483
streaming: bool = Field(
8584
default=False, description="Whether to use streaming or not"
8685
)
87-
custom_headers: Optional[Dict[str, str]] = Field(
88-
default=None, description="Custom HTTP headers to send with requests"
89-
)
90-
91-
def __init__(self, **kwargs: Any):
92-
super().__init__(**kwargs)
93-
if self.custom_headers:
94-
custom_headers_error = (
95-
"custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. "
96-
"Your version does not support the required client structure or "
97-
"extra_headers parameter. Please upgrade: "
98-
"pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0"
99-
)
100-
if not hasattr(self._client, "get_req"):
101-
raise RuntimeError(custom_headers_error)
102-
103-
sig = inspect.signature(self._client.get_req)
104-
if "extra_headers" not in sig.parameters:
105-
raise RuntimeError(custom_headers_error)
106-
107-
self._wrap_client_methods()
108-
109-
def _wrap_client_methods(self):
110-
original_get_req = self._client.get_req
111-
original_get_req_stream = self._client.get_req_stream
112-
113-
def wrapped_get_req(payload: dict = None, extra_headers: dict = None):
114-
payload = payload or {}
115-
extra_headers = extra_headers or {}
116-
merged_headers = {**extra_headers, **self.custom_headers}
117-
return original_get_req(payload=payload, extra_headers=merged_headers)
118-
119-
def wrapped_get_req_stream(payload: dict = None, extra_headers: dict = None):
120-
payload = payload or {}
121-
extra_headers = extra_headers or {}
122-
merged_headers = {**extra_headers, **self.custom_headers}
123-
return original_get_req_stream(
124-
payload=payload, extra_headers=merged_headers
125-
)
126-
127-
object.__setattr__(self._client, "get_req", wrapped_get_req)
128-
object.__setattr__(self._client, "get_req_stream", wrapped_get_req_stream)
12986

13087
@stream_decorator
13188
def _generate(

0 commit comments

Comments
 (0)