Skip to content

Commit 3673cc1

Browse files
committed
Add coverage to rails/llm/options.py
1 parent bcb522d commit 3673cc1

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

tests/test_llm_options.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for LLM isolation with models that don't have model_kwargs field."""
17+
18+
from typing import Any, Dict, List, Optional
19+
from unittest.mock import Mock
20+
21+
import pytest
22+
from langchain_core.language_models import BaseChatModel
23+
from langchain_core.messages import BaseMessage
24+
from langchain_core.outputs import ChatGeneration, ChatResult
25+
from pydantic import BaseModel, Field
26+
27+
from nemoguardrails.rails.llm.config import RailsConfig
28+
from nemoguardrails.rails.llm.llmrails import LLMRails
29+
from nemoguardrails.rails.llm.options import GenerationLog, GenerationStats
30+
31+
32+
def test_generation_log_print_summary(capsys):
33+
"""Test printing rais stats with dummy data"""
34+
35+
stats = GenerationStats(
36+
input_rails_duration=1.0,
37+
dialog_rails_duration=2.0,
38+
generation_rails_duration=3.0,
39+
output_rails_duration=4.0,
40+
total_duration=10.0, # Sum of all previous rail durations
41+
llm_calls_duration=8.0, # Less than total duration
42+
llm_calls_count=4, # Input, dialog, generation and output calls
43+
llm_calls_total_prompt_tokens=1000,
44+
llm_calls_total_completion_tokens=2000,
45+
llm_calls_total_tokens=3000, # Sum of prompt and completion tokens
46+
)
47+
48+
generation_log = GenerationLog(activated_rails=[], stats=stats)
49+
50+
generation_log.print_summary()
51+
capture = capsys.readouterr()
52+
capture_lines = capture.out.splitlines()
53+
54+
# Check the correct times were printed
55+
assert capture_lines[1] == "# General stats"
56+
assert capture_lines[3] == "- Total time: 10.00s"
57+
assert capture_lines[4] == " - [1.00s][10.0%]: INPUT Rails"
58+
assert capture_lines[5] == " - [2.00s][20.0%]: DIALOG Rails"
59+
assert capture_lines[6] == " - [3.00s][30.0%]: GENERATION Rails"
60+
assert capture_lines[7] == " - [4.00s][40.0%]: OUTPUT Rails"
61+
assert (
62+
capture_lines[8]
63+
== "- 4 LLM calls, 8.00s total duration, 1000 total prompt tokens, 2000 total completion tokens, 3000 total tokens."
64+
)

0 commit comments

Comments
 (0)