|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +from unittest.mock import MagicMock |
| 17 | + |
16 | 18 | import pytest |
| 19 | +from langchain.llms.base import BaseLLM |
17 | 20 | from pydantic import ValidationError |
18 | 21 |
|
19 | 22 | from nemoguardrails.rails.llm.config import ( |
|
23 | 26 | RailsConfig, |
24 | 27 | TaskPrompt, |
25 | 28 | ) |
| 29 | +from nemoguardrails.rails.llm.llmrails import LLMRails |
26 | 30 |
|
27 | 31 |
|
28 | 32 | def test_task_prompt_valid_content(): |
@@ -307,3 +311,34 @@ def test_rails_config_none_config_path(): |
307 | 311 |
|
308 | 312 | result2 = config3 + config4 |
309 | 313 | assert result2.config_path == "" |
| 314 | + |
| 315 | + |
| 316 | +def test_llm_rails_configure_streaming_with_attr(): |
| 317 | + """Check LLM has the streaming attribute set if RailsConfig has it""" |
| 318 | + |
| 319 | + mock_llm = MagicMock(spec=BaseLLM) |
| 320 | + config = RailsConfig( |
| 321 | + models=[], |
| 322 | + streaming=True, |
| 323 | + ) |
| 324 | + |
| 325 | + rails = LLMRails(config, llm=mock_llm) |
| 326 | + setattr(mock_llm, "streaming", None) |
| 327 | + rails._configure_main_llm_streaming(llm=mock_llm) |
| 328 | + |
| 329 | + assert mock_llm.streaming |
| 330 | + |
| 331 | + |
| 332 | +def test_llm_rails_configure_streaming_without_attr(caplog): |
| 333 | + """Check LLM has the streaming attribute set if RailsConfig has it""" |
| 334 | + |
| 335 | + mock_llm = MagicMock(spec=BaseLLM) |
| 336 | + config = RailsConfig( |
| 337 | + models=[], |
| 338 | + streaming=True, |
| 339 | + ) |
| 340 | + |
| 341 | + rails = LLMRails(config, llm=mock_llm) |
| 342 | + rails._configure_main_llm_streaming(mock_llm) |
| 343 | + |
| 344 | + assert caplog.messages[-1] == "Provided main LLM does not support streaming." |
0 commit comments