diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index c6f8439c5..c5b46e2b7 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -36,6 +36,8 @@ from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo
 
+log = logging.getLogger(__name__)
+
 
 class LLMCallException(Exception):
     """A wrapper around the LLM call invocation exception.
@@ -113,7 +115,7 @@ def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
     return _infer_provider_from_module(llm)
 
 
-def _infer_model_name(llm: BaseLanguageModel):
+def _infer_model_name(llm: Union[BaseLanguageModel, Runnable]) -> str:
     """Helper to infer the model name based on an LLM instance.
 
     Because not all models correctly implement _identifying_params from LangChain, we have to
diff --git a/nemoguardrails/logging/callbacks.py b/nemoguardrails/logging/callbacks.py
index e40bd974e..fa4bdaf79 100644
--- a/nemoguardrails/logging/callbacks.py
+++ b/nemoguardrails/logging/callbacks.py
@@ -32,6 +32,7 @@ from nemoguardrails.logging.explain import LLMCallInfo
 from nemoguardrails.logging.processing_log import processing_log_var
 from nemoguardrails.logging.stats import LLMStats
+from nemoguardrails.logging.utils import extract_model_name_and_base_url
 from nemoguardrails.utils import new_uuid
 
 log = logging.getLogger(__name__)
@@ -64,6 +65,15 @@ async def on_llm_start(
         if explain_info:
             explain_info.llm_calls.append(llm_call_info)
 
+        # Log model name and base URL
+        model_name, base_url = extract_model_name_and_base_url(serialized)
+        if base_url:
+            log.info(f"Invoking LLM: model={model_name}, url={base_url}")
+        elif model_name:
+            log.info(f"Invoking LLM: model={model_name}")
+        else:
+            log.info("Invoking LLM")
+
         log.info("Invocation Params :: %s", kwargs.get("invocation_params", {}))
         log.info(
             "Prompt :: %s",
@@ -105,6 +115,15 @@ async def on_chat_model_start(
         if explain_info:
             explain_info.llm_calls.append(llm_call_info)
 
+        # Log model name and base URL
+        model_name, base_url = extract_model_name_and_base_url(serialized)
+        if base_url:
+            log.info(f"Invoking LLM: model={model_name}, url={base_url}")
+        elif model_name:
+            log.info(f"Invoking LLM: model={model_name}")
+        else:
+            log.info("Invoking LLM")
+
         type_map = {
             "human": "User",
             "ai": "Bot",
diff --git a/nemoguardrails/logging/utils.py b/nemoguardrails/logging/utils.py
new file mode 100644
index 000000000..344cb357b
--- /dev/null
+++ b/nemoguardrails/logging/utils.py
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+from typing import Any, Dict, Optional
+
+log = logging.getLogger(__name__)
+
+
+def extract_model_name_and_base_url(
+    serialized: Dict[str, Any]
+) -> tuple[Optional[str], Optional[str]]:
+    """Extract model name and base URL from serialized LLM parameters.
+
+    Args:
+        serialized: The serialized LLM configuration.
+
+    Returns:
+        A tuple of (model_name, base_url). Either value can be None if not found.
+    """
+    model_name = None
+    base_url = None
+
+    # Case 1: Try to extract from kwargs (we expect kwargs to be populated for the `ChatOpenAI` class).
+    if "kwargs" in serialized:
+        kwargs = serialized["kwargs"]
+
+        # Check for model_name in kwargs (ChatOpenAI attribute)
+        if "model_name" in kwargs and kwargs["model_name"]:
+            model_name = str(kwargs["model_name"])
+
+        # Check for openai_api_base in kwargs (ChatOpenAI attribute)
+        if "openai_api_base" in kwargs and kwargs["openai_api_base"]:
+            base_url = str(kwargs["openai_api_base"])
+
+    # Case 2: For other providers, parse `repr`, a string representation of the provider class. Since we don't
+    # have a reference to the actual class, we need to parse the string representation.
+    if "repr" in serialized and isinstance(serialized["repr"], str):
+        repr_str = serialized["repr"]
+
+        # Extract model name. We expect the property to be formatted like model='...' or model_name='...',
+        # and check for single and double quotes.
+        if not model_name:
+            match = re.search(r"model(?:_name)?=['\"]([^'\"]+)['\"]", repr_str)
+            if match:
+                model_name = match.group(1)
+
+        # Extract base URL. The property name may vary between providers, so try common names.
+        # We expect the property to be formatted like property_name='...', and check for single and double quotes.
+        if not base_url:
+            url_attrs = [
+                "api_base",
+                "api_host",
+                "azure_endpoint",
+                "base_url",
+                "endpoint",
+                "endpoint_url",
+                "openai_api_base",
+            ]
+            for attr in url_attrs:
+                match = re.search(rf"{attr}=['\"]([^'\"]+)['\"]", repr_str)
+                if match:
+                    base_url = match.group(1)
+                    break
+
+    return model_name, base_url
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 4c47afbfb..cea02b68a 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -31,6 +31,7 @@ from nemoguardrails.logging.callbacks import LoggingCallbackHandler
 from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
 from nemoguardrails.logging.stats import LLMStats
+from nemoguardrails.logging.utils import extract_model_name_and_base_url
 
 
 @pytest.mark.asyncio
@@ -261,3 +262,122 @@ def __init__(self, content, msg_type):
     assert logged_prompt is not None
     assert "[cyan]Custom[/]" in logged_prompt
     assert "[cyan]Function[/]" in logged_prompt
+
+
+def test_extract_model_and_url_from_kwargs():
+    """Test extracting model_name and openai_api_base from kwargs (ChatOpenAI case)."""
+    serialized = {
+        "kwargs": {
+            "model_name": "gpt-4",
+            "openai_api_base": "https://api.openai.com/v1",
+            "temperature": 0.7,
+        }
+    }
+
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+
+    assert model_name == "gpt-4"
+    assert base_url == "https://api.openai.com/v1"
+
+
+def test_extract_model_and_url_from_repr():
+    """Test extracting from repr string (ChatNIM case)."""
+    # Property values in single-quotes
+    serialized = {
+        "kwargs": {"temperature": 0.1},
+        "repr": "ChatNIM(model='meta/llama-3.3-70b-instruct', client=<client object>, endpoint_url='https://nim.int.aire.nvidia.com/v1')",
+    }
+
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+
+    assert model_name == "meta/llama-3.3-70b-instruct"
+    assert base_url == "https://nim.int.aire.nvidia.com/v1"
+
+    # Property values in double-quotes
+    serialized = {
+        "repr": 'ChatOpenAI(model="gpt-3.5-turbo", base_url="https://custom.api.com/v1")'
+    }
+
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+
+    assert model_name == "gpt-3.5-turbo"
+    assert base_url == "https://custom.api.com/v1"
+
+    # Model is stored in the `model_name` property
+    serialized = {
+        "repr": "SomeProvider(model_name='custom-model-v2', api_base='https://example.com')"
+    }
+
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+
+    assert model_name == "custom-model-v2"
+    assert base_url == "https://example.com"
+
+
+def test_extract_model_and_url_from_various_url_properties():
+    """Test extracting various URL property names."""
+    test_cases = [
+        ("api_base='https://api1.com'", "https://api1.com"),
+        ("api_host='https://api2.com'", "https://api2.com"),
+        ("azure_endpoint='https://azure.com'", "https://azure.com"),
+        ("endpoint='https://endpoint.com'", "https://endpoint.com"),
+        ("openai_api_base='https://openai.com'", "https://openai.com"),
+    ]
+
+    for url_pattern, expected_url in test_cases:
+        serialized = {"repr": f"Provider(model='test-model', {url_pattern})"}
+        model_name, base_url = extract_model_name_and_base_url(serialized)
+        assert base_url == expected_url, f"Failed for pattern: {url_pattern}"
+
+
+def test_extract_model_and_url_kwargs_priority_over_repr():
+    """Test that kwargs values, if present, take priority over repr values."""
+    serialized = {
+        "kwargs": {
+            "model_name": "gpt-4-from-kwargs",
+            "openai_api_base": "https://kwargs.api.com",
+        },
+        "repr": "ChatOpenAI(model='gpt-3.5-from-repr', base_url='https://repr.api.com')",
+    }
+
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+
+    assert model_name == "gpt-4-from-kwargs"
+    assert base_url == "https://kwargs.api.com"
+
+
+def test_extract_model_and_url_with_missing_values():
+    """Test extraction when values are missing."""
+    # No model or URL
+    serialized = {"kwargs": {"temperature": 0.7}}
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+    assert model_name is None
+    assert base_url is None
+
+    # Only model, no URL
+    serialized = {"kwargs": {"model_name": "gpt-4"}}
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+    assert model_name == "gpt-4"
+    assert base_url is None
+
+    # Only URL, no model
+    serialized = {"repr": "Provider(endpoint_url='https://example.com')"}
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+    assert model_name is None
+    assert base_url == "https://example.com"
+
+
+def test_extract_model_and_url_with_empty_values():
+    """Test extraction when values are empty strings."""
+    serialized = {"kwargs": {"model_name": "", "openai_api_base": ""}}
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+    assert model_name is None
+    assert base_url is None
+
+
+def test_extract_model_and_url_with_empty_serialized_data():
+    """Test extraction with empty or minimal serialized dict."""
+    serialized = {}
+    model_name, base_url = extract_model_name_and_base_url(serialized)
+    assert model_name is None
+    assert base_url is None
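
Reviewer note: a minimal usage sketch of the new helper on the two payload shapes exercised by the tests above. The payloads, the `Provider(...)` repr, and the URLs are illustrative placeholders modeled on the test fixtures, not captured from a live LangChain callback.

    from nemoguardrails.logging.utils import extract_model_name_and_base_url

    # Path 1: structured kwargs, as ChatOpenAI serializes its config.
    serialized = {
        "kwargs": {"model_name": "gpt-4", "openai_api_base": "https://api.openai.com/v1"}
    }
    print(extract_model_name_and_base_url(serialized))
    # ('gpt-4', 'https://api.openai.com/v1')

    # Path 2: repr fallback, for providers that only expose a string representation.
    serialized = {"repr": "Provider(model='test-model', endpoint_url='https://example.com/v1')"}
    print(extract_model_name_and_base_url(serialized))
    # ('test-model', 'https://example.com/v1')

In on_llm_start / on_chat_model_start, these tuples surface as "Invoking LLM: model=gpt-4, url=https://api.openai.com/v1" log lines, degrading to "Invoking LLM: model=..." or a bare "Invoking LLM" when values are missing.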