From f3e0f639490850352f1b83cc3e5786832a07c851 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Sun, 14 Sep 2025 23:38:53 -0500 Subject: [PATCH 1/6] Initial checkin --- nemoguardrails/server/api.py | 105 ++++++++++++------ .../server/datastore/redis_store.py | 10 +- 2 files changed, 79 insertions(+), 36 deletions(-) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index c3c43c3e2..d6e46bc8f 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -22,7 +22,7 @@ import time import warnings from contextlib import asynccontextmanager -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -42,14 +42,32 @@ logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) + +class GuardrailsApp(FastAPI): + """Custom FastAPI subclass with additional attributes for Guardrails server.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Initialize custom attributes + self.default_config_id: Optional[str] = None + self.rails_config_path: str = "" + self.disable_chat_ui: bool = False + self.auto_reload: bool = False + self.stop_signal: bool = False + self.single_config_mode: bool = False + self.single_config_id: Optional[str] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.task: Optional[asyncio.Future] = None + + # The list of registered loggers. Can be used to send logs to various # backends and storage engines. -registered_loggers = [] +registered_loggers: List[Callable] = [] api_description = """Guardrails Sever API.""" # The headers for each request -api_request_headers = contextvars.ContextVar("headers") +api_request_headers: contextvars.ContextVar = contextvars.ContextVar("headers") # The datastore that the Server should use. # This is currently used only for storing threads. @@ -59,7 +77,7 @@ @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: GuardrailsApp): # Startup logic here """Register any additional challenges, if available at startup.""" challenges_files = os.path.join(app.rails_config_path, "challenges.json") @@ -82,8 +100,11 @@ async def lifespan(app: FastAPI): if os.path.exists(filepath): filename = os.path.basename(filepath) spec = importlib.util.spec_from_file_location(filename, filepath) - config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config_module) + if spec is not None and spec.loader is not None: + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + else: + config_module = None # If there is an `init` function, we call it with the reference to the app. 
if config_module is not None and hasattr(config_module, "init"): @@ -110,6 +131,7 @@ async def root_handler(): if app.auto_reload: app.loop = asyncio.get_running_loop() + # Store the future directly as task app.task = app.loop.run_in_executor(None, start_auto_reload_monitoring) yield @@ -117,14 +139,14 @@ async def root_handler(): # Shutdown logic here if app.auto_reload: app.stop_signal = True - if hasattr(app, "task"): + if hasattr(app, "task") and app.task is not None: app.task.cancel() log.info("Shutting down file observer") else: pass -app = FastAPI( +app = GuardrailsApp( title="Guardrails Server API", description=api_description, version="0.1.0", @@ -186,7 +208,7 @@ class RequestBody(BaseModel): max_length=255, description="The id of an existing thread to which the messages should be added.", ) - messages: List[dict] = Field( + messages: Optional[List[dict]] = Field( default=None, description="The list of messages in the current conversation." ) context: Optional[dict] = Field( @@ -232,7 +254,7 @@ def ensure_config_ids(cls, v, values): class ResponseBody(BaseModel): - messages: List[dict] = Field( + messages: Optional[List[dict]] = Field( default=None, description="The new messages in the conversation" ) llm_output: Optional[dict] = Field( @@ -282,8 +304,8 @@ async def get_rails_configs(): # One instance of LLMRails per config id -llm_rails_instances = {} -llm_rails_events_history_cache = {} +llm_rails_instances: dict[str, LLMRails] = {} +llm_rails_events_history_cache: dict[str, dict] = {} def _generate_cache_key(config_ids: List[str]) -> str: @@ -310,7 +332,7 @@ def _get_rails(config_ids: List[str]) -> LLMRails: # get the same thing. config_ids = [""] - full_llm_rails_config = None + full_llm_rails_config: Optional[RailsConfig] = None for config_id in config_ids: base_path = os.path.abspath(app.rails_config_path) @@ -330,6 +352,9 @@ def _get_rails(config_ids: List[str]) -> LLMRails: else: full_llm_rails_config += rails_config + if full_llm_rails_config is None: + raise ValueError("No valid rails configuration found.") + llm_rails = LLMRails(config=full_llm_rails_config, verbose=True) llm_rails_instances[configs_cache_key] = llm_rails @@ -368,22 +393,27 @@ async def chat_completion(body: RequestBody, request: Request): "No 'config_id' provided and no default configuration is set for the server. " "You must set a 'config_id' in your request or set use --default-config-id when . " ) + + # Ensure config_ids is not None before passing to _get_rails + if config_ids is None: + raise GuardrailsConfigurationError("No valid configuration IDs available.") + try: llm_rails = _get_rails(config_ids) except ValueError as ex: log.exception(ex) - return { - "messages": [ + return ResponseBody( + messages=[ { "role": "assistant", "content": f"Could not load the {config_ids} guardrails configuration. " f"An internal error has occurred.", } ] - } + ) try: - messages = body.messages + messages = body.messages or [] if body.context: messages.insert(0, {"role": "context", "content": body.context}) @@ -396,14 +426,14 @@ async def chat_completion(body: RequestBody, request: Request): # We make sure the `thread_id` meets the minimum complexity requirement. if len(body.thread_id) < 16: - return { - "messages": [ + return ResponseBody( + messages=[ { "role": "assistant", "content": "The `thread_id` must have a minimum length of 16 characters.", } ] - } + ) # Fetch the existing thread messages. For easier management, we prepend # the string `thread-` to all thread keys. 
@@ -440,32 +470,37 @@ async def chat_completion(body: RequestBody, request: Request): ) if isinstance(res, GenerationResponse): - bot_message = res.response[0] + bot_message_content = res.response[0] + # Ensure bot_message is always a dict + if isinstance(bot_message_content, str): + bot_message = {"role": "assistant", "content": bot_message_content} + else: + bot_message = bot_message_content else: assert isinstance(res, dict) bot_message = res # If we're using threads, we also need to update the data before returning # the message. - if body.thread_id: + if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = {"messages": [bot_message]} + result = ResponseBody(messages=[bot_message]) # If we have additional GenerationResponse fields, we return as well if isinstance(res, GenerationResponse): - result["llm_output"] = res.llm_output - result["output_data"] = res.output_data - result["log"] = res.log - result["state"] = res.state + result.llm_output = res.llm_output + result.output_data = res.output_data + result.log = res.log + result.state = res.state return result except Exception as ex: log.exception(ex) - return { - "messages": [{"role": "assistant", "content": "Internal server error."}] - } + return ResponseBody( + messages=[{"role": "assistant", "content": "Internal server error."}] + ) # By default, there are no challenges @@ -498,7 +533,7 @@ def register_datastore(datastore_instance: DataStore): datastore = datastore_instance -def register_logger(logger: callable): +def register_logger(logger: Callable): """Register an additional logger""" registered_loggers.append(logger) @@ -510,8 +545,7 @@ def start_auto_reload_monitoring(): from watchdog.observers import Observer class Handler(FileSystemEventHandler): - @staticmethod - def on_any_event(event): + def on_any_event(self, event): if event.is_directory: return None @@ -521,7 +555,8 @@ def on_any_event(event): ) # Compute the relative path - rel_path = os.path.relpath(event.src_path, app.rails_config_path) + src_path_str = str(event.src_path) + rel_path = os.path.relpath(src_path_str, app.rails_config_path) # The config_id is the first component parts = rel_path.split(os.path.sep) @@ -530,7 +565,7 @@ def on_any_event(event): if ( not parts[-1].startswith(".") and ".ipynb_checkpoints" not in parts - and os.path.isfile(event.src_path) + and os.path.isfile(src_path_str) ): # We just remove the config from the cache so that a new one is used next time if config_id in llm_rails_instances: diff --git a/nemoguardrails/server/datastore/redis_store.py b/nemoguardrails/server/datastore/redis_store.py index 4f6437f96..6e436dbff 100644 --- a/nemoguardrails/server/datastore/redis_store.py +++ b/nemoguardrails/server/datastore/redis_store.py @@ -16,7 +16,10 @@ import asyncio from typing import Optional -import aioredis +try: + import aioredis # type: ignore[import] +except ImportError: + aioredis = None # type: ignore[assignment] from nemoguardrails.server.datastore.datastore import DataStore @@ -35,6 +38,11 @@ def __init__( username: [Optional] The username to use for authentication. password: [Optional] The password to use for authentication """ + if aioredis is None: + raise ImportError( + "aioredis is required for RedisStore. 
Install it with: pip install aioredis" + ) + self.url = url self.username = username self.password = password From 452d4e1765445994826d9453fa35151f7e7f4e2c Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 14 Oct 2025 11:26:21 -0500 Subject: [PATCH 2/6] Add nemoguardrails/server to pyright type-checking --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6be833997..9b21954b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,6 +161,7 @@ include = [ "nemoguardrails/kb/**", "nemoguardrails/logging/**", "nemoguardrails/tracing/**", + "nemoguardrails/server/**", "tests/test_callbacks.py", ] From 54f6994f8ecd19d7ece762ec23602253cca88766 Mon Sep 17 00:00:00 2001 From: Tim Gasser <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 28 Oct 2025 05:56:56 -0500 Subject: [PATCH 3/6] chore(types): Type-clean embeddings/ (25 errors) (#1383) --- nemoguardrails/embeddings/basic.py | 87 ++++++++++--------- nemoguardrails/embeddings/cache.py | 38 +++++--- .../embeddings/providers/azureopenai.py | 7 +- nemoguardrails/embeddings/providers/cohere.py | 15 +++- .../embeddings/providers/fastembed.py | 2 +- nemoguardrails/embeddings/providers/google.py | 2 +- nemoguardrails/embeddings/providers/nim.py | 2 +- nemoguardrails/embeddings/providers/openai.py | 6 +- .../providers/sentence_transformers.py | 4 +- pyproject.toml | 1 + 10 files changed, 94 insertions(+), 70 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index cbd48ec62..a4e497762 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -15,9 +15,9 @@ import asyncio import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast -from annoy import AnnoyIndex +from annoy import AnnoyIndex # type: ignore from nemoguardrails.embeddings.cache import cache_embeddings from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem @@ -45,26 +45,14 @@ class BasicEmbeddingsIndex(EmbeddingsIndex): max_batch_hold: The maximum time a batch is held before being processed """ - embedding_model: str - embedding_engine: str - embedding_params: Dict[str, Any] - index: AnnoyIndex - embedding_size: int - cache_config: EmbeddingsCacheConfig - embeddings: List[List[float]] - search_threshold: float - use_batching: bool - max_batch_size: int - max_batch_hold: float - def __init__( self, - embedding_model=None, - embedding_engine=None, - embedding_params=None, - index=None, - cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None, - search_threshold: float = None, + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", + embedding_engine: str = "SentenceTransformers", + embedding_params: Optional[Dict[str, Any]] = None, + index: Optional[AnnoyIndex] = None, + cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None, + search_threshold: float = float("inf"), use_batching: bool = False, max_batch_size: int = 10, max_batch_hold: float = 0.01, @@ -72,22 +60,23 @@ def __init__( """Initialize the BasicEmbeddingsIndex. Args: - embedding_model (str, optional): The model for computing embeddings. Defaults to None. - embedding_engine (str, optional): The engine for computing embeddings. Defaults to None. - index (AnnoyIndex, optional): The pre-existing index. Defaults to None. - cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. 
Defaults to None. + embedding_model: The model for computing embeddings. + embedding_engine: The engine for computing embeddings. + index: The pre-existing index. + cache_config: The cache configuration. + search_threshold: The threshold for filtering search results. use_batching: Whether to batch requests when computing the embeddings. max_batch_size: The maximum size of a batch. max_batch_hold: The maximum time a batch is held before being processed """ self._model: Optional[EmbeddingModel] = None - self._items = [] - self._embeddings = [] + self._items: List[IndexItem] = [] + self._embeddings: List[List[float]] = [] self.embedding_model = embedding_model self.embedding_engine = embedding_engine self.embedding_params = embedding_params or {} self._embedding_size = 0 - self.search_threshold = search_threshold or float("inf") + self.search_threshold = search_threshold if isinstance(cache_config, Dict): self._cache_config = EmbeddingsCacheConfig(**cache_config) else: @@ -95,12 +84,12 @@ def __init__( self._index = index # Data structures for batching embedding requests - self._req_queue = {} - self._req_results = {} - self._req_idx = 0 - self._current_batch_finished_event = None - self._current_batch_full_event = None - self._current_batch_submitted = asyncio.Event() + self._req_queue: Dict[int, str] = {} + self._req_results: Dict[int, List[float]] = {} + self._req_idx: int = 0 + self._current_batch_finished_event: Optional[asyncio.Event] = None + self._current_batch_full_event: Optional[asyncio.Event] = None + self._current_batch_submitted: asyncio.Event = asyncio.Event() # Initialize the batching configuration self.use_batching = use_batching @@ -112,6 +101,11 @@ def embeddings_index(self): """Get the current embedding index""" return self._index + @embeddings_index.setter + def embeddings_index(self, index): + """Setter to allow replacing the index dynamically.""" + self._index = index + @property def cache_config(self): """Get the cache configuration.""" @@ -127,16 +121,14 @@ def embeddings(self): """Get the computed embeddings.""" return self._embeddings - @embeddings_index.setter - def embeddings_index(self, index): - """Setter to allow replacing the index dynamically.""" - self._index = index - def _init_model(self): """Initialize the model used for computing the embeddings.""" + model = self.embedding_model + engine = self.embedding_engine + self._model = init_embedding_model( - embedding_model=self.embedding_model, - embedding_engine=self.embedding_engine, + embedding_model=model, + embedding_engine=engine, embedding_params=self.embedding_params, ) @@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: if self._model is None: self._init_model() - embeddings = await self._model.encode_async(texts) + # self._model can't be None here, or self._init_model() would throw a ValueError + model: EmbeddingModel = cast(EmbeddingModel, self._model) + embeddings = await model.encode_async(texts) return embeddings async def add_item(self, item: IndexItem): @@ -199,6 +193,12 @@ async def _run_batch(self): """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. + if ( + self._current_batch_full_event is None + or self._current_batch_finished_event is None + ): + raise RuntimeError("Batch events not initialized. 
This should not happen.") + done, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), @@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: self._req_idx += 1 self._req_queue[req_id] = text - if self._current_batch_finished_event is None: + if ( + self._current_batch_finished_event is None + or self._current_batch_full_event is None + ): self._current_batch_finished_event = asyncio.Event() self._current_batch_full_event = asyncio.Event() self._current_batch_submitted.clear() diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index 9abeb1de2..cdef48c27 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -20,7 +20,12 @@ from abc import ABC, abstractmethod from functools import singledispatchmethod from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional + +try: + import redis # type: ignore +except ImportError: + redis = None # type: ignore from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig @@ -30,6 +35,8 @@ class KeyGenerator(ABC): """Abstract class for key generators.""" + name: str # Class attribute that should be defined in subclasses + @abstractmethod def generate_key(self, text: str) -> str: pass @@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str: class CacheStore(ABC): """Abstract class for cache stores.""" + name: str + @abstractmethod def get(self, key): """Get a value from the cache.""" @@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore): name = "filesystem" - def __init__(self, cache_dir: str = None): + def __init__(self, cache_dir: Optional[str] = None): self._cache_dir = Path(cache_dir or ".cache/embeddings") self._cache_dir.mkdir(parents=True, exist_ok=True) @@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore): name = "redis" def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0): - import redis - + if redis is None: + raise ImportError( + "Could not import redis, please install it with `pip install redis`." + ) self._redis = redis.Redis(host=host, port=port, db=db) def get(self, key): @@ -207,9 +218,9 @@ def clear(self): class EmbeddingsCache: def __init__( self, - key_generator: KeyGenerator = None, - cache_store: CacheStore = None, - store_config: dict = None, + key_generator: KeyGenerator, + cache_store: CacheStore, + store_config: Optional[dict] = None, ): self._key_generator = key_generator self._cache_store = cache_store @@ -218,7 +229,10 @@ def __init__( @classmethod def from_dict(cls, d: Dict[str, str]): key_generator = KeyGenerator.from_name(d.get("key_generator"))() - store_config = d.get("store_config") + store_config_raw = d.get("store_config") + store_config: dict = ( + store_config_raw if isinstance(store_config_raw, dict) else {} + ) cache_store = CacheStore.from_name(d.get("store"))(**store_config) return cls(key_generator=key_generator, cache_store=cache_store) @@ -239,7 +253,7 @@ def get_config(self): def get(self, texts): raise NotImplementedError - @get.register + @get.register(str) def _(self, text: str): key = self._key_generator.generate_key(text) log.info(f"Fetching key {key} for text '{text[:20]}...' 
from cache") @@ -248,7 +262,7 @@ def _(self, text: str): return result - @get.register + @get.register(list) def _(self, texts: list): cached = {} @@ -266,13 +280,13 @@ def _(self, texts: list): def set(self, texts): raise NotImplementedError - @set.register + @set.register(str) def _(self, text: str, value: List[float]): key = self._key_generator.generate_key(text) log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.") self._cache_store.set(key, value) - @set.register + @set.register(list) def _(self, texts: list, values: List[List[float]]): for text, value in zip(texts, values): self.set(text, value) diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 5c5906d5d..e77ab481a 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -46,17 +46,16 @@ class AzureEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str): try: - from openai import AzureOpenAI + from openai import AzureOpenAI # type: ignore except ImportError: raise ImportError( - "Could not import openai, please install it with " - "`pip install openai`." + "Could not import openai, please install it with `pip install openai`." ) # Set Azure OpenAI API credentials self.client = AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore ) self.embedding_model = embedding_model diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py index 34cee4156..704e0bcd7 100644 --- a/nemoguardrails/embeddings/providers/cohere.py +++ b/nemoguardrails/embeddings/providers/cohere.py @@ -14,7 +14,7 @@ # limitations under the License. import asyncio from contextvars import ContextVar -from typing import List +from typing import TYPE_CHECKING, List from .base import EmbeddingModel @@ -23,6 +23,10 @@ # is changed, it will fail. 
async_client_var: ContextVar = ContextVar("async_client", default=None) +if TYPE_CHECKING: + import cohere + from cohere import AsyncClient, Client + class CohereEmbeddingModel(EmbeddingModel): """ @@ -64,7 +68,7 @@ def __init__( self.model = embedding_model self.input_type = input_type - self.client = cohere.Client(**kwargs) + self.client = cohere.Client(**kwargs) # type: ignore[reportCallIssue] self.embedding_size_dict = { "embed-v4.0": 1536, @@ -120,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]: """ # Make embedding request to Cohere API - return self.client.embed( + # Since we don't pass embedding_types parameter, the response should be + # EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]] + response = self.client.embed( texts=documents, model=self.model, input_type=self.input_type - ).embeddings + ) + return response.embeddings # type: ignore[return-value] diff --git a/nemoguardrails/embeddings/providers/fastembed.py b/nemoguardrails/embeddings/providers/fastembed.py index 1062e566f..1359f7ab5 100644 --- a/nemoguardrails/embeddings/providers/fastembed.py +++ b/nemoguardrails/embeddings/providers/fastembed.py @@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel): engine_name = "FastEmbed" def __init__(self, embedding_model: str, **kwargs): - from fastembed import TextEmbedding as Embedding + from fastembed import TextEmbedding as Embedding # type: ignore # Enabling a short form model name for all-MiniLM-L6-v2. if embedding_model == "all-MiniLM-L6-v2": diff --git a/nemoguardrails/embeddings/providers/google.py b/nemoguardrails/embeddings/providers/google.py index cf55399af..1f78974e6 100644 --- a/nemoguardrails/embeddings/providers/google.py +++ b/nemoguardrails/embeddings/providers/google.py @@ -46,7 +46,7 @@ class GoogleEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from google import genai + from google import genai # type: ignore[import] except ImportError: raise ImportError( diff --git a/nemoguardrails/embeddings/providers/nim.py b/nemoguardrails/embeddings/providers/nim.py index dd5690a4d..8ea9c1d0f 100644 --- a/nemoguardrails/embeddings/providers/nim.py +++ b/nemoguardrails/embeddings/providers/nim.py @@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings + from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # type: ignore self.model = embedding_model self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs) diff --git a/nemoguardrails/embeddings/providers/openai.py b/nemoguardrails/embeddings/providers/openai.py index 83f83f8c2..bd12f2333 100644 --- a/nemoguardrails/embeddings/providers/openai.py +++ b/nemoguardrails/embeddings/providers/openai.py @@ -46,14 +46,14 @@ def __init__( **kwargs, ): try: - import openai - from openai import AsyncOpenAI, OpenAI + import openai # type: ignore + from openai import AsyncOpenAI, OpenAI # type: ignore except ImportError: raise ImportError( "Could not import openai, please install it with " "`pip install openai`." ) - if openai.__version__ < "1.0.0": + if openai.__version__ < "1.0.0": # type: ignore raise RuntimeError( "`openai<1.0.0` is no longer supported. " "Please upgrade using `pip install openai>=1.0.0`." 
diff --git a/nemoguardrails/embeddings/providers/sentence_transformers.py b/nemoguardrails/embeddings/providers/sentence_transformers.py
index 7ffcec712..cc7ce7be8 100644
--- a/nemoguardrails/embeddings/providers/sentence_transformers.py
+++ b/nemoguardrails/embeddings/providers/sentence_transformers.py
@@ -43,7 +43,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
     def __init__(self, embedding_model: str, **kwargs):
         try:
-            from sentence_transformers import SentenceTransformer
+            from sentence_transformers import SentenceTransformer  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import sentence-transformers, please install it with "
@@ -51,7 +51,7 @@ def __init__(self, embedding_model: str, **kwargs):
             )
         try:
-            from torch import cuda
+            from torch import cuda  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import torch, please install it with `pip install torch`."
             )
diff --git a/pyproject.toml b/pyproject.toml
index 9b21954b3..2e79e544d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -157,6 +157,7 @@ pyright = "^1.1.405"
 include = [
     "nemoguardrails/rails/**",
     "nemoguardrails/actions/**",
+    "nemoguardrails/embeddings/**",
     "nemoguardrails/cli/**",
     "nemoguardrails/kb/**",
     "nemoguardrails/logging/**",

From 82903976bde5275874d6b1b105cb57acbb1a40c4 Mon Sep 17 00:00:00 2001
From: Pouyan <13303554+Pouyanpi@users.noreply.github.com>
Date: Tue, 28 Oct 2025 13:59:36 +0100
Subject: [PATCH 4/6] test: restore test that was skipped due to Colang 2.0 serialization issue (#1449)

---
 tests/v2_x/test_passthroug_mode.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/v2_x/test_passthroug_mode.py b/tests/v2_x/test_passthroug_mode.py
index c75112f56..9421a1dd1 100644
--- a/tests/v2_x/test_passthroug_mode.py
+++ b/tests/v2_x/test_passthroug_mode.py
@@ -81,9 +81,6 @@ def test_passthrough_llm_action_not_invoked_via_logs(self):
         self.assertIn("content", response)
         self.assertIsInstance(response["content"], str)
-    @unittest.skip(
-        reason="Github issue https://github.com/NVIDIA/NeMo-Guardrails/issues/1378"
-    )
     def test_passthrough_llm_action_invoked_via_logs(self):
         chat = TestChat(
             config,

From 63f3464c0d8b4ff64213a4cfd8c419c32cb7d263 Mon Sep 17 00:00:00 2001
From: Pouyan <13303554+Pouyanpi@users.noreply.github.com>
Date: Tue, 28 Oct 2025 14:31:48 +0100
Subject: [PATCH 5/6] fix(llm): add fallback extraction for reasoning traces from <think> tags (#1474)

Adds a compatibility layer for LLM providers that don't properly populate
reasoning_content in additional_kwargs. When reasoning_content is missing, the
system now falls back to extracting reasoning traces from <think>...</think>
tags in the response content and removes the tags from the final output.

This fixes compatibility with certain NVIDIA models (e.g.,
nvidia/llama-3.3-nemotron-super-49b-v1.5) in langchain-nvidia-ai-endpoints that
include reasoning traces in <think> tags but fail to populate the
reasoning_content field.
All reasoning models using ChatNVIDIA should expose reasoning content
consistently through the same interface
---
 nemoguardrails/actions/llm/utils.py      | 72 ++++++++++-
 tests/conftest.py                        | 10 ++
 tests/test_actions_llm_utils.py          | 183 ++++++++++++++++++++++-
 tests/test_reasoning_trace_extraction.py | 91 +++++++++++
 4 files changed, 352 insertions(+), 4 deletions(-)

diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py
index a89b0f8af..c6f8439c5 100644
--- a/nemoguardrails/actions/llm/utils.py
+++ b/nemoguardrails/actions/llm/utils.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import re
 from typing import Any, Dict, List, Optional, Sequence, Union
+logger = logging.getLogger(__name__)
+
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
 from langchain_core.runnables import RunnableConfig
@@ -238,15 +241,78 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
 def _store_reasoning_traces(response) -> None:
+    """Store reasoning traces from response in context variable.
+
+    Extracts reasoning content from response.additional_kwargs["reasoning_content"]
+    if available. Otherwise, falls back to extracting from <think> tags in the
+    response content (and removes the tags from content).
+
+    Args:
+        response: The LLM response object
+    """
+
+    reasoning_content = _extract_reasoning_content(response)
+
+    if not reasoning_content:
+        # Some LLM providers (e.g., certain NVIDIA models) embed reasoning in <think> tags
+        # instead of properly populating reasoning_content in additional_kwargs, so we need
+        # both extraction methods to support different provider implementations.
+        reasoning_content = _extract_and_remove_think_tags(response)
+
+    if reasoning_content:
+        reasoning_trace_var.set(reasoning_content)
+
+
+def _extract_reasoning_content(response):
     if hasattr(response, "additional_kwargs"):
         additional_kwargs = response.additional_kwargs
         if (
             isinstance(additional_kwargs, dict)
             and "reasoning_content" in additional_kwargs
         ):
-            reasoning_content = additional_kwargs["reasoning_content"]
-            if reasoning_content:
-                reasoning_trace_var.set(reasoning_content)
+            return additional_kwargs["reasoning_content"]
+    return None
+
+
+def _extract_and_remove_think_tags(response) -> Optional[str]:
+    """Extract reasoning from <think> tags and remove them from `response.content`.
+
+    This function looks for <think>...</think> tags in the response content,
+    and if found, extracts the reasoning content inside the tags. It has a side-effect:
+    it removes the full reasoning trace and tags from response.content.
+
+    Args:
+        response: The LLM response object
+
+    Returns:
+        The extracted reasoning content, or None if no <think> tags found
+    """
+    if not hasattr(response, "content"):
+        return None
+
+    content = response.content
+    has_opening_tag = "<think>" in content
+    has_closing_tag = "</think>" in content
+
+    if not has_opening_tag and not has_closing_tag:
+        return None
+
+    if has_opening_tag != has_closing_tag:
+        logger.warning(
+            "Malformed <think> tags detected: missing %s tag. "
+            "Skipping reasoning extraction to prevent corrupted content.",
+            "closing" if has_opening_tag else "opening",
+        )
+        return None
+
+    match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
+    if match:
+        reasoning_content = match.group(1).strip()
+        response.content = re.sub(
+            r"<think>.*?</think>", "", content, flags=re.DOTALL
+        ).strip()
+        return reasoning_content
+    return None
 
 
 def _store_tool_calls(response) -> None:
diff --git a/tests/conftest.py b/tests/conftest.py
index 1dc00134b..2e3f0c1d5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,5 +22,15 @@
 )
+@pytest.fixture(autouse=True)
+def reset_reasoning_trace_var():
+    """Reset reasoning_trace_var before each test to prevent state leakage."""
+    from nemoguardrails.context import reasoning_trace_var
+
+    reasoning_trace_var.set(None)
+    yield
+    reasoning_trace_var.set(None)
+
+
 def pytest_configure(config):
     patch("prompt_toolkit.PromptSession", autospec=True).start()
diff --git a/tests/test_actions_llm_utils.py b/tests/test_actions_llm_utils.py
index 9c238dda2..8f0accbd2 100644
--- a/tests/test_actions_llm_utils.py
+++ b/tests/test_actions_llm_utils.py
@@ -13,7 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from nemoguardrails.actions.llm.utils import _infer_provider_from_module
+from nemoguardrails.actions.llm.utils import (
+    _extract_and_remove_think_tags,
+    _infer_provider_from_module,
+    _store_reasoning_traces,
+)
+from nemoguardrails.context import reasoning_trace_var
 
 
 class MockOpenAILLM:
@@ -123,3 +128,179 @@ class Wrapper3(Wrapper2):
     llm = Wrapper3()
     provider = _infer_provider_from_module(llm)
     assert provider == "anthropic"
+
+
+class MockResponse:
+    def __init__(self, content="", additional_kwargs=None):
+        self.content = content
+        self.additional_kwargs = additional_kwargs or {}
+
+
+def test_store_reasoning_traces_from_additional_kwargs():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="The answer is 42",
+        additional_kwargs={"reasoning_content": "Let me think about this..."},
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "Let me think about this..."
+
+
+def test_store_reasoning_traces_from_think_tags():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think>Let me think about this...</think>The answer is 42"
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "Let me think about this..."
+    assert response.content == "The answer is 42"
+
+
+def test_store_reasoning_traces_multiline_think_tags():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think>Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution</think>The answer is 42"
+    )
+
+    _store_reasoning_traces(response)
+
+    assert (
+        reasoning_trace_var.get()
+        == "Step 1: Analyze the problem\nStep 2: Consider options\nStep 3: Choose solution"
+    )
+    assert response.content == "The answer is 42"
+
+
+def test_store_reasoning_traces_prefers_additional_kwargs():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think>This should not be used</think>The answer is 42",
+        additional_kwargs={"reasoning_content": "This should be used"},
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "This should be used"
+
+
+def test_store_reasoning_traces_no_reasoning_content():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(content="The answer is 42")
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_empty_reasoning_content():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="The answer is 42", additional_kwargs={"reasoning_content": ""}
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_incomplete_think_tags():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(content="<think>This is incomplete")
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_no_content_attribute():
+    reasoning_trace_var.set(None)
+
+    class ResponseWithoutContent:
+        def __init__(self):
+            self.additional_kwargs = {}
+
+    response = ResponseWithoutContent()
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() is None
+
+
+def test_store_reasoning_traces_removes_think_tags_with_whitespace():
+    reasoning_trace_var.set(None)
+
+    response = MockResponse(
+        content="<think> reasoning here </think>\n\n Final answer "
+    )
+
+    _store_reasoning_traces(response)
+
+    assert reasoning_trace_var.get() == "reasoning here"
+    assert response.content == "Final answer"
+
+
+def test_extract_and_remove_think_tags_basic():
+    response = MockResponse(content="<think>reasoning</think>answer")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result == "reasoning"
+    assert response.content == "answer"
+
+
+def test_extract_and_remove_think_tags_multiline():
+    response = MockResponse(content="<think>line1\nline2\nline3</think>final answer")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result == "line1\nline2\nline3"
+    assert response.content == "final answer"
+
+
+def test_extract_and_remove_think_tags_no_tags():
+    response = MockResponse(content="just a normal response")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+    assert response.content == "just a normal response"
+
+
+def test_extract_and_remove_think_tags_incomplete():
+    response = MockResponse(content="<think>incomplete")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+    assert response.content == "<think>incomplete"
+
+
+def test_extract_and_remove_think_tags_no_content_attribute():
+    class ResponseWithoutContent:
+        pass
+
+    response = ResponseWithoutContent()
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+
+
+def test_extract_and_remove_think_tags_wrong_order():
+    response = MockResponse(content="</think> text here <think>")
+
+    result = _extract_and_remove_think_tags(response)
+
+    assert result is None
+    assert response.content == "</think> text here <think>"
diff --git a/tests/test_reasoning_trace_extraction.py b/tests/test_reasoning_trace_extraction.py
index b794de0e4..a74892679 100644
--- a/tests/test_reasoning_trace_extraction.py
+++ b/tests/test_reasoning_trace_extraction.py
@@ -304,3 +304,94 @@ async def test_reasoning_content_with_other_additional_kwargs(self):
         assert stored_trace == test_reasoning
         reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_extracts_reasoning_from_think_tags(self):
+        test_reasoning = "Let me analyze this step by step"
+
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content=f"<think>{test_reasoning}</think>The answer is 42",
+            additional_kwargs={},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "What is the answer?")
+
+        assert result == "The answer is 42"
+        assert "<think>" not in result
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace == test_reasoning
+
+        reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_prefers_additional_kwargs_over_think_tags(self):
+        reasoning_from_kwargs = "This should be used"
+        reasoning_from_tags = "This should be ignored"
+
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content=f"<think>{reasoning_from_tags}</think>Response",
+            additional_kwargs={"reasoning_content": reasoning_from_kwargs},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "Query")
+
+        assert result == f"<think>{reasoning_from_tags}</think>Response"
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace == reasoning_from_kwargs
+
+        reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_extracts_multiline_reasoning_from_think_tags(self):
+        multiline_reasoning = """Step 1: Understand the question
+Step 2: Break down the problem
+Step 3: Formulate the answer"""
+
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content=f"<think>{multiline_reasoning}</think>Final answer",
+            additional_kwargs={},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "Question")
+
+        assert result == "Final answer"
+        assert "<think>" not in result
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace == multiline_reasoning
+
+        reasoning_trace_var.set(None)
+
+    @pytest.mark.asyncio
+    async def test_llm_call_handles_incomplete_think_tags(self):
+        mock_llm = AsyncMock()
+        mock_response = AIMessage(
+            content="<think>This is incomplete",
+            additional_kwargs={},
+        )
+        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+        from nemoguardrails.actions.llm.utils import llm_call
+
+        reasoning_trace_var.set(None)
+        result = await llm_call(mock_llm, "Query")
+
+        assert result == "<think>This is incomplete"
+        stored_trace = reasoning_trace_var.get()
+        assert stored_trace is None
+
+        reasoning_trace_var.set(None)

From 05f2a409cb274516aa53a2dd8101891b25bd25ae Mon Sep 17 00:00:00 2001
From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com>
Date: Tue, 28 Oct 2025 14:20:22 -0500
Subject: [PATCH 6/6] Clean up the config_id logic based on Traian and Greptile feedback

---
 nemoguardrails/server/api.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11
deletions(-) diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index d6e46bc8f..6769dec1e 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -385,18 +385,16 @@ async def chat_completion(body: RequestBody, request: Request): # Save the request headers in a context variable. api_request_headers.set(request.headers) + # Use Request config_ids if set, otherwise use the FastAPI default config. + # If neither is available we can't generate any completions as we have no config_id config_ids = body.config_ids - if not config_ids and app.default_config_id: - config_ids = [app.default_config_id] - elif not config_ids and not app.default_config_id: - raise GuardrailsConfigurationError( - "No 'config_id' provided and no default configuration is set for the server. " - "You must set a 'config_id' in your request or set use --default-config-id when . " - ) - - # Ensure config_ids is not None before passing to _get_rails - if config_ids is None: - raise GuardrailsConfigurationError("No valid configuration IDs available.") + if not config_ids: + if app.default_config_id: + config_ids = [app.default_config_id] + else: + raise GuardrailsConfigurationError( + "No request config_ids provided and server has no default configuration" + ) try: llm_rails = _get_rails(config_ids)
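
For reference, the config_id resolution that PATCH 6/6 applies inside `chat_completion` behaves like the standalone sketch below. It is illustrative only: `resolve_config_ids` and the bare `GuardrailsConfigurationError` class here are hypothetical stand-ins written for this note, not code taken from the patch.

from typing import List, Optional


class GuardrailsConfigurationError(Exception):
    """Raised when no usable guardrails configuration can be resolved."""


def resolve_config_ids(
    request_config_ids: Optional[List[str]],
    default_config_id: Optional[str],
) -> List[str]:
    """Mirror the fallback rule from PATCH 6/6: request config_ids win,
    then the server-wide default (--default-config-id), otherwise fail."""
    if request_config_ids:
        return request_config_ids
    if default_config_id:
        return [default_config_id]
    raise GuardrailsConfigurationError(
        "No request config_ids provided and server has no default configuration"
    )


if __name__ == "__main__":
    print(resolve_config_ids(["bot_a"], "main"))  # ['bot_a'] - request takes precedence
    print(resolve_config_ids(None, "main"))       # ['main']  - server default is used
    try:
        resolve_config_ids(None, None)
    except GuardrailsConfigurationError as exc:
        print(f"error: {exc}")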