Commit d80ef5e
fix(streaming): raise error when stream_async used with disabled output rails streaming
When output rails are configured but output.streaming.enabled is False (or not set), calling stream_async() would result in undefined behavior or hangs, because streaming expectations conflict with blocking output-rail processing. This change adds explicit validation in stream_async() that detects the misconfiguration and raises a clear ValueError with actionable guidance:

- Set rails.output.streaming.enabled = True to use streaming with output rails.
- Use generate_async() instead for non-streaming generation with output rails.

The affected tests were updated to expect and validate the new error behavior instead of relying on the previous buggy behavior.
1 parent 912cdea commit d80ef5e
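
For context, a minimal sketch of the two remedies the new error suggests. The config shape mirrors the test added in tests/test_streaming.py below; the "self check output" flow and the test prompt are illustrative, not prescribed by this commit.

from nemoguardrails import LLMRails, RailsConfig

# Remedy 1: enable output-rail streaming so stream_async() is allowed.
config = RailsConfig.from_content(
    config={
        "models": [],
        "rails": {
            "output": {
                "flows": {"self check output"},
                "streaming": {"enabled": True},  # the setting the error asks for
            }
        },
        "streaming": True,
        "prompts": [{"task": "self_check_output", "content": "a test template"}],
    },
)
rails = LLMRails(config)

# Remedy 2: keep output-rail streaming disabled and call the non-streaming
# API instead of stream_async():
#     response = await rails.generate_async(
#         messages=[{"role": "user", "content": "Hi!"}]
#     )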

File tree

4 files changed: +93 −50 lines changed


nemoguardrails/rails/llm/llmrails.py

Lines changed: 10 additions & 0 deletions

@@ -1251,6 +1251,16 @@ def stream_async(
     ) -> AsyncIterator[str]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""
 
+        if len(self.config.rails.output.flows) > 0 and (
+            not self.config.rails.output.streaming
+            or not self.config.rails.output.streaming.enabled
+        ):
+            raise ValueError(
+                "stream_async() cannot be used when output rails are configured but "
+                "output.streaming.enabled is False. Either set "
+                "rails.output.streaming.enabled to True in your configuration, or use "
+                "generate_async() instead of stream_async()."
+            )
         # if an external generator is provided, use it directly
         if generator:
             if (
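
Note that the check treats a missing rails.output.streaming section the same as enabled: False — the first branch of the or is truthy when the streaming config is unset, so both cases raise. A standalone sketch of the equivalent predicate (the helper name is hypothetical, not part of the codebase):

def output_rails_require_blocking(config) -> bool:
    # Hypothetical helper mirroring the new check: output rails are
    # configured, and streaming for them is unset or explicitly disabled.
    has_output_rails = len(config.rails.output.flows) > 0
    streaming = config.rails.output.streaming  # falsy when not set
    return has_output_rails and (not streaming or not streaming.enabled)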

tests/test_parallel_streaming_output_rails.py

Lines changed: 12 additions & 12 deletions

@@ -605,21 +605,21 @@ async def test_parallel_streaming_output_rails_performance_benefits():
 async def test_parallel_streaming_output_rails_default_config_behavior(
     parallel_output_rails_default_config,
 ):
-    """Tests parallel output rails with default streaming configuration"""
+    """Tests that stream_async raises an error with default config (no explicit streaming config)"""
 
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "This is a test message with default streaming config."',
-    ]
+    from nemoguardrails import LLMRails
 
-    chunks = await run_parallel_self_check_test(
-        parallel_output_rails_default_config, llm_completions
-    )
+    llmrails = LLMRails(parallel_output_rails_default_config)
 
-    response = "".join(chunks)
-    assert len(response) > 0
-    assert len(chunks) > 0
-    assert "test message" in response
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
+
+    assert "stream_async() cannot be used when output rails are configured" in str(
+        exc_info.value
+    )
 
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})

tests/test_streaming.py

Lines changed: 44 additions & 0 deletions

@@ -474,6 +474,50 @@ def _calculate_number_of_actions(input_length, chunk_size, context_size):
     return math.ceil((input_length - context_size) / (chunk_size - context_size))
 
 
+@pytest.mark.asyncio
+async def test_streaming_with_output_rails_disabled_raises_error():
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": {"self check output"},
+                    "streaming": {
+                        "enabled": False,
+                    },
+                }
+            },
+            "streaming": True,
+            "prompts": [{"task": "self_check_output", "content": "a test template"}],
+        },
+        colang_content="""
+        define user express greeting
+          "hi"
+
+        define flow
+          user express greeting
+          bot tell joke
+        """,
+    )
+
+    chat = TestChat(
+        config,
+        llm_completions=[],
+        streaming=True,
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in chat.app.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}],
+        ):
+            pass
+
+    assert "stream_async() cannot be used when output rails are configured" in str(
+        exc_info.value
+    )
+    assert "output.streaming.enabled is False" in str(exc_info.value)
+
+
 @pytest.mark.asyncio
 async def test_streaming_error_handling():
     """Test that errors during streaming are properly formatted and returned."""

tests/test_streaming_output_rails.py

Lines changed: 27 additions & 38 deletions

@@ -89,14 +89,16 @@ def output_rails_streaming_config_default():
 
 @pytest.mark.asyncio
 async def test_stream_async_streaming_disabled(output_rails_streaming_config_default):
-    """Tests if stream_async returns a StreamingHandler instance when streaming is disabled"""
+    """Tests that stream_async raises an error when output rails are configured but streaming is disabled"""
 
     llmrails = LLMRails(output_rails_streaming_config_default)
 
-    result = llmrails.stream_async(prompt="test")
-    assert isinstance(
-        result, StreamingHandler
-    ), "Expected StreamingHandler instance when streaming is disabled"
+    with pytest.raises(ValueError) as exc_info:
+        llmrails.stream_async(prompt="test")
+
+    assert "stream_async() cannot be used when output rails are configured" in str(
+        exc_info.value
+    )
 
 
 @pytest.mark.asyncio

@@ -175,32 +177,19 @@ async def test_streaming_output_rails_blocked_explicit(output_rails_streaming_co
 async def test_streaming_output_rails_blocked_default_config(
     output_rails_streaming_config_default,
 ):
-    """Tests if output rails streaming default config do not block content with BLOCK keyword"""
+    """Tests that stream_async raises an error with default config (output rails without explicit streaming config)"""
 
-    # text with a BLOCK keyword
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "This is a [BLOCK] joke that should be blocked."',
-    ]
+    llmrails = LLMRails(output_rails_streaming_config_default)
 
-    chunks = await run_self_check_test(
-        output_rails_streaming_config_default, llm_completions
-    )
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
 
-    expected_error = {
-        "error": {
-            "message": "Blocked by self check output rails.",
-            "type": "guardrails_violation",
-            "param": "self check output",
-            "code": "content_blocked",
-        }
-    }
-
-    error_chunks = [
-        json.loads(chunk) for chunk in chunks if chunk.startswith('{"error":')
-    ]
-    assert len(error_chunks) == 0
-    assert expected_error not in error_chunks
+    assert "stream_async() cannot be used when output rails are configured" in str(
+        exc_info.value
+    )
 
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})

@@ -235,19 +224,19 @@ async def test_streaming_output_rails_blocked_at_start(output_rails_streaming_co
 async def test_streaming_output_rails_default_config_not_blocked_at_start(
     output_rails_streaming_config_default,
 ):
-    """Tests blocking with BLOCK at the very beginning of the response does not return abort sse"""
+    """Tests that stream_async raises an error with default config (output rails without explicit streaming config)"""
 
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "[BLOCK] This should be blocked immediately at the start."',
-    ]
+    llmrails = LLMRails(output_rails_streaming_config_default)
 
-    chunks = await run_self_check_test(
-        output_rails_streaming_config_default, llm_completions
-    )
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
 
-    with pytest.raises(JSONDecodeError):
-        json.loads(chunks[0])
+    assert "stream_async() cannot be used when output rails are configured" in str(
+        exc_info.value
+    )
 
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})