Commit cd14a07

feat(streaming): support external async token generators (#1286)
* feat(streaming): support external async token generators

  Add ability to pass custom async token generators to `stream_async`, enabling integration with external LLMs or custom streaming sources. Update docs and add tests for output rails interaction and edge cases with external generators.
1 parent 9b0a6cd commit cd14a07

File tree

3 files changed: +336 -2 lines changed


docs/user-guides/advanced/streaming.md

Lines changed: 69 additions & 1 deletion
@@ -71,7 +71,75 @@ result = await app.generate_async(

Removed:

> For the complete working example, check out this [demo script](https://github.com/NVIDIA/NeMo-Guardrails/tree/develop/examples/scripts/demo_streaming.py).

Added:

### Using External Token Generators

You can also provide your own async generator that yields tokens, which is useful when:

- You want to use a different LLM provider that has its own streaming API
- You have pre-generated responses that you want to stream through guardrails
- You want to implement custom token generation logic
- You want to test your output rails configuration in streaming mode against predefined responses, without relying on actual LLM generation

To use an external generator, pass it to the `generator` parameter of `stream_async`:
```python
from typing import AsyncIterator

from nemoguardrails import LLMRails

app = LLMRails(config)

async def my_token_generator() -> AsyncIterator[str]:
    # This could be the OpenAI API, the Anthropic API, or any other LLM API
    # that already exposes a streaming token generator. The stream is mocked
    # here for a simple example.
    tokens = ["Hello", " ", "world", "!"]
    for token in tokens:
        yield token

messages = [{"role": "user", "content": "The most famous program ever written is"}]

# Use the external generator with guardrails.
async for chunk in app.stream_async(
    messages=messages,
    generator=my_token_generator()
):
    print(f"CHUNK: {chunk}")
```
When using an external generator:

- The internal LLM generation is completely bypassed
- Output rails, if configured, are still applied to the responses returned by the external generator
- The generator should yield string tokens

Example with a real LLM API:
```python
from typing import AsyncIterator

from openai import AsyncOpenAI

from nemoguardrails import LLMRails, RailsConfig

async def openai_streaming_generator(messages) -> AsyncIterator[str]:
    """Example using OpenAI's streaming API (openai>=1.0 async client)."""
    client = AsyncOpenAI()
    stream = await client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        stream=True,
    )

    # Yield tokens as they arrive.
    async for chunk in stream:
        if chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

config = RailsConfig.from_path("config/with_output_rails")
app = LLMRails(config)

messages = [{"role": "user", "content": "Tell me a story"}]

async for chunk in app.stream_async(
    messages=messages,
    generator=openai_streaming_generator(messages)
):
    # Output rails will be applied to these chunks.
    print(chunk, end="", flush=True)
```
This feature enables seamless integration of NeMo Guardrails with any streaming LLM or token source while maintaining all the safety features of output rails.
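
Because `stream_async` only needs an async iterator of plain strings, adapting a provider SDK whose stream yields richer chunk objects is a thin wrapper. Below is a minimal sketch, not part of this commit; the `provider_stream` object and its `text` attribute are hypothetical placeholders for whatever your SDK returns:

```python
from typing import Any, AsyncIterator

async def as_token_stream(provider_stream: Any) -> AsyncIterator[str]:
    """Adapt an arbitrary SDK stream into the plain string tokens stream_async expects."""
    async for chunk in provider_stream:
        # Replace `text` with wherever your SDK puts the text delta.
        text = getattr(chunk, "text", None)
        if text:
            yield text
```

The adapted stream is then passed as `generator=as_token_stream(stream)`, exactly like the examples above.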
### Server API

nemoguardrails/rails/llm/llmrails.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -1063,8 +1063,21 @@ def stream_async(
         options: Optional[Union[dict, GenerationOptions]] = None,
         state: Optional[Union[dict, State]] = None,
         include_generation_metadata: Optional[bool] = False,
+        generator: Optional[AsyncIterator[str]] = None,
     ) -> AsyncIterator[str]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""
+
+        # if an external generator is provided, use it directly
+        if generator:
+            if self.config.rails.output.streaming.enabled:
+                return self._run_output_rails_in_streaming(
+                    streaming_handler=generator,
+                    messages=messages,
+                    prompt=prompt,
+                )
+            else:
+                return generator
+
         self.explain_info = self._ensure_explain_info()
 
         streaming_handler = StreamingHandler(
```
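
For orientation, here is a minimal sketch of the `else` branch above: when no output-rails streaming is configured, `stream_async` hands the external generator straight back to the caller, so the tokens arrive exactly as yielded. The config dict mirrors the one used in the new tests; the `tokens()` helper is illustrative only.

```python
import asyncio
from typing import AsyncIterator

from nemoguardrails import LLMRails, RailsConfig


async def tokens() -> AsyncIterator[str]:
    for t in ["Hello", " ", "world", "!"]:
        yield t


async def main() -> None:
    # No output rails configured, so stream_async returns the external
    # generator unchanged.
    config = RailsConfig.from_content(
        config={"models": [], "rails": {}, "streaming": True}
    )
    app = LLMRails(config)
    chunks = [chunk async for chunk in app.stream_async(generator=tokens())]
    assert "".join(chunks) == "Hello world!"


asyncio.run(main())
```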

tests/test_streaming_output_rails.py

Lines changed: 254 additions & 1 deletion
```diff
@@ -17,8 +17,8 @@
 
 import asyncio
 import json
-import math
 from json.decoder import JSONDecodeError
+from typing import AsyncIterator
 
 import pytest
```
@@ -250,3 +250,256 @@ async def test_streaming_output_rails_default_config_not_blocked_at_start(

The new generators and tests below are appended after the existing `test_streaming_output_rails_default_config_not_blocked_at_start` test, whose final lines are shown as context:

```python
    # (end of the existing test_streaming_output_rails_default_config_not_blocked_at_start)
    json.loads(chunks[0])

    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})


async def simple_token_generator() -> AsyncIterator[str]:
    """Simple generator that yields tokens."""
    tokens = ["Hello", " ", "world", "!"]
    for token in tokens:
        yield token


async def offensive_token_generator() -> AsyncIterator[str]:
    """Generator that yields potentially offensive content."""
    tokens = ["This", " ", "is", " ", "offensive", " ", "content", " ", "idiot", "!"]
    for token in tokens:
        yield token


@pytest.mark.asyncio
async def test_external_generator_without_output_rails():
    """Test that external generator works without output rails."""
    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {},
            "streaming": True,
        }
    )

    rails = LLMRails(config)

    tokens = []
    async for token in rails.stream_async(generator=simple_token_generator()):
        tokens.append(token)

    assert tokens == ["Hello", " ", "world", "!"]
    assert "".join(tokens) == "Hello world!"


@pytest.mark.asyncio
async def test_external_generator_with_output_rails_allowed():
    """Test that external generator works with output rails that allow content."""
    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {
                "output": {
                    "flows": ["self check output"],
                    "streaming": {
                        "enabled": True,
                        "chunk_size": 4,
                        "context_size": 2,
                        "stream_first": False,
                    },
                }
            },
            "streaming": True,
            "prompts": [
                {"task": "self_check_output", "content": "Check: {{ bot_response }}"}
            ],
        },
        colang_content="""
            define flow self check output
              execute self_check_output
        """,
    )

    rails = LLMRails(config)

    @action(name="self_check_output")
    async def self_check_output(**kwargs):
        return True

    rails.register_action(self_check_output, "self_check_output")

    tokens = []
    async for token in rails.stream_async(
        generator=simple_token_generator(),
        messages=[{"role": "user", "content": "Hello"}],
    ):
        tokens.append(token)

    assert tokens == ["Hello", " ", "world", "!"]


@pytest.mark.asyncio
async def test_external_generator_with_output_rails_blocked():
    """Test that external generator content can be blocked by output rails."""
    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {
                "output": {
                    "flows": ["self check output"],
                    "streaming": {
                        "enabled": True,
                        "chunk_size": 6,
                        "context_size": 2,
                        "stream_first": False,
                    },
                }
            },
            "streaming": True,
            "prompts": [
                {"task": "self_check_output", "content": "Check: {{ bot_response }}"}
            ],
        },
        colang_content="""
            define flow self check output
              execute self_check_output
        """,
    )

    rails = LLMRails(config)

    @action(name="self_check_output")
    async def self_check_output(**kwargs):
        bot_message = kwargs.get(
            "bot_message", kwargs.get("context", {}).get("bot_message", "")
        )
        # Block if the message contains "offensive" or "idiot".
        if "offensive" in bot_message.lower() or "idiot" in bot_message.lower():
            return False
        return True

    rails.register_action(self_check_output, "self_check_output")

    tokens = []
    error_received = False

    async for token in rails.stream_async(
        generator=offensive_token_generator(),
        messages=[{"role": "user", "content": "Generate something"}],
    ):
        if isinstance(token, str) and token.startswith('{"error"'):
            error_received = True
            break
        tokens.append(token)

    assert error_received, "Expected to receive an error JSON when content is blocked"
    assert len(tokens) == 0


@pytest.mark.asyncio
async def test_external_generator_with_custom_llm():
    """Test using external generator as a custom LLM replacement."""

    async def custom_llm_generator(messages):
        """Simulate a custom LLM that generates based on input."""

        user_message = messages[-1]["content"] if messages else ""

        if "weather" in user_message.lower():
            response = "The weather is sunny today!"
        elif "name" in user_message.lower():
            response = "I am an AI assistant."
        else:
            response = "I can help you with that."

        for token in response.split(" "):
            yield token + " "

    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {},
            "streaming": True,
        }
    )

    rails = LLMRails(config)

    messages = [{"role": "user", "content": "What's the weather?"}]
    tokens = []

    async for token in rails.stream_async(
        generator=custom_llm_generator(messages), messages=messages
    ):
        tokens.append(token)

    result = "".join(tokens).strip()
    assert result == "The weather is sunny today!"


@pytest.mark.asyncio
async def test_external_generator_empty_stream():
    """Test that empty generator streams work correctly."""

    async def empty_generator():
        if False:
            yield

    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {},
            "streaming": True,
        }
    )

    rails = LLMRails(config)

    tokens = []
    async for token in rails.stream_async(generator=empty_generator()):
        tokens.append(token)

    assert tokens == []


@pytest.mark.asyncio
async def test_external_generator_single_chunk():
    """Test generator that yields a single large chunk."""

    async def single_chunk_generator():
        yield "This is a complete response in a single chunk."

    config = RailsConfig.from_content(
        config={
            "models": [],
            "rails": {
                "output": {
                    "flows": ["self check output"],
                    "streaming": {
                        "enabled": True,
                        "chunk_size": 10,
                        "context_size": 5,
                        "stream_first": True,
                    },
                }
            },
            "streaming": True,
            "prompts": [
                {"task": "self_check_output", "content": "Check: {{ bot_response }}"}
            ],
        },
        colang_content="""
            define flow self check output
              execute self_check_output
        """,
    )

    rails = LLMRails(config)

    @action(name="self_check_output")
    async def self_check_output(**kwargs):
        return True

    rails.register_action(self_check_output, "self_check_output")

    tokens = []
    async for token in rails.stream_async(generator=single_chunk_generator()):
        tokens.append(token)

    assert "".join(tokens) == "This is a complete response in a single chunk."
```
