|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import json |
16 | 17 | from unittest.mock import MagicMock |
17 | 18 |
|
18 | 19 | import pytest |
19 | 20 | from langchain.llms.base import BaseLLM |
20 | 21 | from pydantic import ValidationError |
21 | 22 |
|
22 | | -from nemoguardrails.rails.llm.config import ( |
23 | | - Document, |
24 | | - Instruction, |
25 | | - Model, |
26 | | - RailsConfig, |
27 | | - TaskPrompt, |
28 | | -) |
| 23 | +from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt |
29 | 24 | from nemoguardrails.rails.llm.llmrails import LLMRails |
30 | 25 |
|
31 | 26 |
|
@@ -342,3 +337,45 @@ def test_llm_rails_configure_streaming_without_attr(caplog): |
342 | 337 | rails._configure_main_llm_streaming(mock_llm) |
343 | 338 |
|
344 | 339 | assert caplog.messages[-1] == "Provided main LLM does not support streaming." |
| 340 | + |
| 341 | + |
| 342 | +def test_rails_config_streaming_supported_no_output_flows(): |
| 343 | + """Check `streaming_supported` property doesn't depend on RailsConfig.streaming with no output flows""" |
| 344 | + |
| 345 | + config = RailsConfig( |
| 346 | + models=[], |
| 347 | + streaming=False, |
| 348 | + ) |
| 349 | + assert config.streaming_supported |
| 350 | + |
| 351 | + |
| 352 | +def test_rails_config_flows_streaming_supported_true(): |
| 353 | + """Create RailsConfig and check the `streaming_supported Check LLM has the streaming attribute set if RailsConfig has it""" |
| 354 | + |
| 355 | + rails = { |
| 356 | + "output": { |
| 357 | + "flows": ["content_safety_check_output"], |
| 358 | + "streaming": {"enabled": True}, |
| 359 | + } |
| 360 | + } |
| 361 | + prompts = [{"task": "content safety check output", "content": "..."}] |
| 362 | + rails_config = RailsConfig.model_validate( |
| 363 | + {"models": [], "rails": rails, "prompts": prompts} |
| 364 | + ) |
| 365 | + assert rails_config.streaming_supported |
| 366 | + |
| 367 | + |
| 368 | +def test_rails_config_flows_streaming_supported_false(): |
| 369 | + """Create RailsConfig and check the `streaming_supported Check LLM has the streaming attribute set if RailsConfig has it""" |
| 370 | + |
| 371 | + rails = { |
| 372 | + "output": { |
| 373 | + "flows": ["content_safety_check_output"], |
| 374 | + "streaming": {"enabled": False}, |
| 375 | + } |
| 376 | + } |
| 377 | + prompts = [{"task": "content safety check output", "content": "..."}] |
| 378 | + rails_config = RailsConfig.model_validate( |
| 379 | + {"models": [], "rails": rails, "prompts": prompts} |
| 380 | + ) |
| 381 | + assert not rails_config.streaming_supported |
0 commit comments