|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -import inspect |
17 | 16 | import logging |
18 | 17 | from functools import wraps |
19 | | -from typing import Any, Dict, List, Optional |
| 18 | +from typing import Any, List, Optional |
20 | 19 |
|
21 | 20 | from langchain_core.callbacks.manager import ( |
22 | 21 | AsyncCallbackManagerForLLMRun, |
|
29 | 28 | from langchain_core.messages import BaseMessage |
30 | 29 | from langchain_core.outputs import ChatResult |
31 | 30 | from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal |
32 | | -from pydantic import Field |
| 31 | +from pydantic.v1 import Field |
33 | 32 |
|
34 | | -log = logging.getLogger(__name__) # pragma: no cover |
| 33 | +log = logging.getLogger(__name__) |
35 | 34 |
|
36 | 35 |
|
37 | | -def stream_decorator(func): # pragma: no cover |
| 36 | +def stream_decorator(func): |
38 | 37 | @wraps(func) |
39 | 38 | def wrapper( |
40 | 39 | self, |
@@ -80,52 +79,10 @@ async def wrapper( |
80 | 79 |
|
81 | 80 | # NOTE: this needs to have the same name as the original class, |
82 | 81 | # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail. |
83 | | -class ChatNVIDIA(ChatNVIDIAOriginal): # pragma: no cover |
| 82 | +class ChatNVIDIA(ChatNVIDIAOriginal): |
84 | 83 | streaming: bool = Field( |
85 | 84 | default=False, description="Whether to use streaming or not" |
86 | 85 | ) |
87 | | - custom_headers: Optional[Dict[str, str]] = Field( |
88 | | - default=None, description="Custom HTTP headers to send with requests" |
89 | | - ) |
90 | | - |
91 | | - def __init__(self, **kwargs: Any): |
92 | | - super().__init__(**kwargs) |
93 | | - if self.custom_headers: |
94 | | - custom_headers_error = ( |
95 | | - "custom_headers requires langchain-nvidia-ai-endpoints >= 0.3.0. " |
96 | | - "Your version does not support the required client structure or " |
97 | | - "extra_headers parameter. Please upgrade: " |
98 | | - "pip install --upgrade langchain-nvidia-ai-endpoints>=0.3.0" |
99 | | - ) |
100 | | - if not hasattr(self._client, "get_req"): |
101 | | - raise RuntimeError(custom_headers_error) |
102 | | - |
103 | | - sig = inspect.signature(self._client.get_req) |
104 | | - if "extra_headers" not in sig.parameters: |
105 | | - raise RuntimeError(custom_headers_error) |
106 | | - |
107 | | - self._wrap_client_methods() |
108 | | - |
109 | | - def _wrap_client_methods(self): |
110 | | - original_get_req = self._client.get_req |
111 | | - original_get_req_stream = self._client.get_req_stream |
112 | | - |
113 | | - def wrapped_get_req(payload: dict = None, extra_headers: dict = None): |
114 | | - payload = payload or {} |
115 | | - extra_headers = extra_headers or {} |
116 | | - merged_headers = {**extra_headers, **self.custom_headers} |
117 | | - return original_get_req(payload=payload, extra_headers=merged_headers) |
118 | | - |
119 | | - def wrapped_get_req_stream(payload: dict = None, extra_headers: dict = None): |
120 | | - payload = payload or {} |
121 | | - extra_headers = extra_headers or {} |
122 | | - merged_headers = {**extra_headers, **self.custom_headers} |
123 | | - return original_get_req_stream( |
124 | | - payload=payload, extra_headers=merged_headers |
125 | | - ) |
126 | | - |
127 | | - object.__setattr__(self._client, "get_req", wrapped_get_req) |
128 | | - object.__setattr__(self._client, "get_req_stream", wrapped_get_req_stream) |
129 | 86 |
|
130 | 87 | @stream_decorator |
131 | 88 | def _generate( |
|
0 commit comments