
Commit fac2774

fix(llm): Add async streaming support to ChatNVIDIA provider patch (#1504)
* feat(llm): Add async streaming support to ChatNVIDIA provider

  Enables stream_async() to work with ChatNVIDIA/NIM models by implementing an async streaming decorator and an _agenerate method. Prior to this fix, stream_async() would fail with NIM engine configurations.

* fix: ensure stream_async background task completes before exit (#1508)

  Wrap the returned iterator to await the background generation task in a finally block, preventing the "Task was destroyed but it is pending" warning. Add overloaded type signatures to provide accurate return types based on the include_generation_metadata parameter.
1 parent be88814 commit fac2774

File tree

3 files changed: +487 / -7 lines
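For context, a minimal sketch of the call pattern this commit fixes. The config values (engine `nim`, the model name) are illustrative, not taken from the diff; before this commit, `stream_async()` would fail for a NIM engine configuration, and after it, tokens stream incrementally.

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig

# Illustrative config: the engine/model values are examples, not from the
# diff. Any ChatNVIDIA/NIM-backed model hits the code path patched here.
YAML_CONFIG = """
models:
  - type: main
    engine: nim
    model: meta/llama-3.1-8b-instruct  # hypothetical model name
streaming: True
"""

async def main() -> None:
    rails = LLMRails(RailsConfig.from_content(yaml_content=YAML_CONFIG))
    # Before this commit, stream_async() failed for NIM engines; now the
    # patched ChatNVIDIA supports async streaming via _agenerate().
    async for chunk in rails.stream_async(prompt="Hello!"):
        print(chunk, end="", flush=True)

asyncio.run(main())
```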

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 46 additions & 3 deletions
```diff
@@ -17,12 +17,18 @@
 from functools import wraps
 from typing import Any, List, Optional

-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
-from pydantic.v1 import Field
+from pydantic import Field

 log = logging.getLogger(__name__)

@@ -49,6 +55,28 @@ def wrapper(
     return wrapper


+def async_stream_decorator(func):  # pragma: no cover
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):
@@ -65,6 +93,21 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
             messages=messages, stop=stop, run_manager=run_manager, **kwargs
         )
```
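The decorator added above routes `_agenerate` through `_astream` plus `agenerate_from_stream` whenever streaming is enabled. A minimal, self-contained sketch of that routing idea follows; the toy class and the `"".join` call are illustrative stand-ins for `ChatNVIDIA` and `agenerate_from_stream`, not the real API.

```python
import asyncio
from functools import wraps
from typing import Any, AsyncIterator, Optional

def async_stream_decorator(func):
    """Route to the streaming path when streaming is on, else call through."""

    @wraps(func)
    async def wrapper(self, prompt: str, stream: Optional[bool] = None, **kwargs: Any):
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            # Collapse the async stream into one result; this is the role
            # agenerate_from_stream() plays for ChatResult in the real patch.
            return "".join([chunk async for chunk in self._astream(prompt)])
        return await func(self, prompt, **kwargs)

    return wrapper

class ToyModel:
    streaming = True  # mirrors the model's `streaming` flag

    async def _astream(self, prompt: str) -> AsyncIterator[str]:
        for token in ("hello", " ", "world"):
            yield token

    @async_stream_decorator
    async def _agenerate(self, prompt: str) -> str:
        return "non-streaming result"

print(asyncio.run(ToyModel()._agenerate("hi")))  # -> hello world
```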

nemoguardrails/rails/llm/llmrails.py

Lines changed: 39 additions & 4 deletions
```diff
@@ -30,11 +30,13 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Tuple,
     Type,
     Union,
     cast,
+    overload,
 )

 from langchain_core.language_models import BaseChatModel
@@ -1255,15 +1257,39 @@ def _validate_streaming_with_output_rails(self) -> None:
                 "generate_async() instead of stream_async()."
             )

+    @overload
     def stream_async(
         self,
         prompt: Optional[str] = None,
         messages: Optional[List[dict]] = None,
         options: Optional[Union[dict, GenerationOptions]] = None,
         state: Optional[Union[dict, State]] = None,
-        include_generation_metadata: Optional[bool] = False,
+        include_generation_metadata: Literal[False] = False,
         generator: Optional[AsyncIterator[str]] = None,
     ) -> AsyncIterator[str]:
+        ...
+
+    @overload
+    def stream_async(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[List[dict]] = None,
+        options: Optional[Union[dict, GenerationOptions]] = None,
+        state: Optional[Union[dict, State]] = None,
+        include_generation_metadata: Literal[True] = ...,
+        generator: Optional[AsyncIterator[str]] = None,
+    ) -> AsyncIterator[Union[str, dict]]:
+        ...
+
+    def stream_async(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[List[dict]] = None,
+        options: Optional[Union[dict, GenerationOptions]] = None,
+        state: Optional[Union[dict, State]] = None,
+        include_generation_metadata: Optional[bool] = False,
+        generator: Optional[AsyncIterator[str]] = None,
+    ) -> AsyncIterator[Union[str, dict]]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""

         self._validate_streaming_with_output_rails()
```
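These overloads exist purely for type checkers: the `Literal` annotations on `include_generation_metadata` select the element type of the returned iterator. A small sketch of how mypy/pyright would narrow the chunks; the demo function names are hypothetical, and the assumption that dict chunks carry the generation metadata follows from the parameter name.

```python
from typing import Union

from nemoguardrails import LLMRails

async def tokens_only(rails: LLMRails) -> None:
    # include_generation_metadata defaults to Literal[False],
    # so the first overload applies: chunks are plain str.
    async for chunk in rails.stream_async(prompt="Hello"):
        text: str = chunk

async def tokens_with_metadata(rails: LLMRails) -> None:
    # Literal[True] selects the second overload: chunks are str | dict,
    # where dict chunks presumably carry the generation metadata.
    async for chunk in rails.stream_async(
        prompt="Hello", include_generation_metadata=True
    ):
        item: Union[str, dict] = chunk
```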
```diff
@@ -1328,15 +1354,24 @@ def task_done_callback(task):
             self.config.rails.output.streaming
             and self.config.rails.output.streaming.enabled
         ):
-            # returns an async generator
-            return self._run_output_rails_in_streaming(
+            base_iterator = self._run_output_rails_in_streaming(
                 streaming_handler=streaming_handler,
                 output_rails_streaming_config=self.config.rails.output.streaming,
                 messages=messages,
                 prompt=prompt,
             )
         else:
-            return streaming_handler
+            base_iterator = streaming_handler
+
+        async def wrapped_iterator():
+            try:
+                async for chunk in base_iterator:
+                    if chunk is not None:
+                        yield chunk
+            finally:
+                await task
+
+        return wrapped_iterator()

     def generate(
         self,
```
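The `wrapped_iterator` above awaits the background generation task in a `finally` block, so the task finishes even if the consumer stops iterating early; this is what prevents the "Task was destroyed but it is pending" warning. A self-contained sketch of the same pattern, using a hypothetical queue-based producer in place of the streaming handler:

```python
import asyncio

async def produce(queue: asyncio.Queue) -> None:
    # Stands in for the background generation task created by stream_async().
    for token in ("a", "b", "c"):
        await queue.put(token)
    await queue.put(None)  # sentinel: end of stream

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(produce(queue))

    async def wrapped_iterator():
        try:
            while (chunk := await queue.get()) is not None:
                yield chunk
        finally:
            # Awaiting here guarantees the producer completes before the
            # iterator is closed, even if the consumer breaks out early.
            await task

    async for chunk in wrapped_iterator():
        print(chunk)

asyncio.run(main())
```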
