From d8aa0643064dc133a373143be28dac6fa4357b0f Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:39:57 -0500 Subject: [PATCH 1/9] Cleaned embeddings/ --- nemoguardrails/embeddings/basic.py | 85 +++++++++++-------- nemoguardrails/embeddings/cache.py | 57 ++++++++----- .../embeddings/providers/fastembed.py | 2 +- nemoguardrails/embeddings/providers/openai.py | 6 +- .../providers/sentence_transformers.py | 4 +- 5 files changed, 93 insertions(+), 61 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index cbd48ec62..af2635fc7 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -17,7 +17,7 @@ import logging from typing import Any, Dict, List, Optional, Union -from annoy import AnnoyIndex +from annoy import AnnoyIndex # type: ignore from nemoguardrails.embeddings.cache import cache_embeddings from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem @@ -45,26 +45,16 @@ class BasicEmbeddingsIndex(EmbeddingsIndex): max_batch_hold: The maximum time a batch is held before being processed """ - embedding_model: str - embedding_engine: str - embedding_params: Dict[str, Any] - index: AnnoyIndex - embedding_size: int - cache_config: EmbeddingsCacheConfig - embeddings: List[List[float]] - search_threshold: float - use_batching: bool - max_batch_size: int - max_batch_hold: float + # Instance attributes are defined in __init__ and accessed via properties def __init__( self, - embedding_model=None, - embedding_engine=None, - embedding_params=None, - index=None, - cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None, - search_threshold: float = None, + embedding_model: Optional[str] = None, + embedding_engine: Optional[str] = None, + embedding_params: Optional[Dict[str, Any]] = None, + index: Optional[AnnoyIndex] = None, + cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None, + search_threshold: Optional[float] = None, use_batching: bool = False, max_batch_size: int = 10, max_batch_hold: float = 0.01, @@ -81,10 +71,10 @@ def __init__( max_batch_hold: The maximum time a batch is held before being processed """ self._model: Optional[EmbeddingModel] = None - self._items = [] - self._embeddings = [] - self.embedding_model = embedding_model - self.embedding_engine = embedding_engine + self._items: List[IndexItem] = [] + self._embeddings: List[List[float]] = [] + self.embedding_model: Optional[str] = embedding_model + self.embedding_engine: Optional[str] = embedding_engine self.embedding_params = embedding_params or {} self._embedding_size = 0 self.search_threshold = search_threshold or float("inf") @@ -95,12 +85,12 @@ def __init__( self._index = index # Data structures for batching embedding requests - self._req_queue = {} - self._req_results = {} - self._req_idx = 0 - self._current_batch_finished_event = None - self._current_batch_full_event = None - self._current_batch_submitted = asyncio.Event() + self._req_queue: Dict[int, str] = {} + self._req_results: Dict[int, List[float]] = {} + self._req_idx: int = 0 + self._current_batch_finished_event: Optional[asyncio.Event] = None + self._current_batch_full_event: Optional[asyncio.Event] = None + self._current_batch_submitted: asyncio.Event = asyncio.Event() # Initialize the batching configuration self.use_batching = use_batching @@ -112,6 +102,11 @@ def embeddings_index(self): """Get the current embedding index""" return self._index + @embeddings_index.setter + def 
embeddings_index(self, index): + """Setter to allow replacing the index dynamically.""" + self._index = index + @property def cache_config(self): """Get the cache configuration.""" @@ -127,19 +122,23 @@ def embeddings(self): """Get the computed embeddings.""" return self._embeddings - @embeddings_index.setter - def embeddings_index(self, index): - """Setter to allow replacing the index dynamically.""" - self._index = index - def _init_model(self): """Initialize the model used for computing the embeddings.""" + # Provide defaults if not specified + model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2" + engine = self.embedding_engine or "SentenceTransformers" + self._model = init_embedding_model( - embedding_model=self.embedding_model, - embedding_engine=self.embedding_engine, + embedding_model=model, + embedding_engine=engine, embedding_params=self.embedding_params, ) + if not self._model: + raise ValueError( + f"Couldn't create embedding model with model {model} and engine {engine}" + ) + @cache_embeddings async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: """Compute embeddings for a list of texts. @@ -153,6 +152,8 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: if self._model is None: self._init_model() + if not self._model: + raise Exception("Couldn't initialize embedding model") embeddings = await self._model.encode_async(texts) return embeddings @@ -199,6 +200,10 @@ async def _run_batch(self): """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. + if not self._current_batch_full_event: + raise Exception("self._current_batch_full_event not initialized") + + assert self._current_batch_full_event is not None done, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), @@ -210,6 +215,10 @@ async def _run_batch(self): task.cancel() # Reset the batch event + if not self._current_batch_finished_event: + raise Exception("self._current_batch_finished_event not initialized") + + assert self._current_batch_finished_event is not None batch_event: asyncio.Event = self._current_batch_finished_event self._current_batch_finished_event = None @@ -252,9 +261,13 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: # We check if we reached the max batch size if len(self._req_queue) >= self.max_batch_size: + if not self._current_batch_full_event: + raise Exception("self._current_batch_full_event not initialized") self._current_batch_full_event.set() - # Wait for the batch to finish + # Wait for the batch to finish + if not self._current_batch_finished_event: + raise Exception("self._current_batch_finished_event not initialized") await self._current_batch_finished_event.wait() # Remove the result and return it diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index 9abeb1de2..1f49c37da 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -20,7 +20,12 @@ from abc import ABC, abstractmethod from functools import singledispatchmethod from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional + +try: + import redis # type: ignore +except ImportError: + redis = None # type: ignore from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig @@ -30,6 +35,8 @@ class KeyGenerator(ABC): """Abstract class for key generators.""" + name: str # Class attribute that should be defined in subclasses + @abstractmethod def 
generate_key(self, text: str) -> str: pass @@ -37,11 +44,11 @@ def generate_key(self, text: str) -> str: @classmethod def from_name(cls, name): for subclass in cls.__subclasses__(): - if subclass.name == name: + if hasattr(subclass, "name") and subclass.name == name: return subclass raise ValueError( f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: " - f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}" + f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}" ". Make sure to import the derived class before using it." ) @@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str: class CacheStore(ABC): """Abstract class for cache stores.""" + name: str # Class attribute that should be defined in subclasses + @abstractmethod def get(self, key): """Get a value from the cache.""" @@ -94,11 +103,11 @@ def clear(self): @classmethod def from_name(cls, name): for subclass in cls.__subclasses__(): - if subclass.name == name: + if hasattr(subclass, "name") and subclass.name == name: return subclass raise ValueError( f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: " - f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}" + f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}" ". Make sure to import the derived class before using it." ) @@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore): name = "filesystem" - def __init__(self, cache_dir: str = None): + def __init__(self, cache_dir: Optional[str] = None): self._cache_dir = Path(cache_dir or ".cache/embeddings") self._cache_dir.mkdir(parents=True, exist_ok=True) @@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore): name = "redis" def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0): - import redis - + if redis is None: + raise ImportError( + "Could not import redis, please install it with `pip install redis`." 
+ ) self._redis = redis.Redis(host=host, port=port, db=db) def get(self, key): @@ -207,9 +218,9 @@ def clear(self): class EmbeddingsCache: def __init__( self, - key_generator: KeyGenerator = None, - cache_store: CacheStore = None, - store_config: dict = None, + key_generator: Optional[KeyGenerator] = None, + cache_store: Optional[CacheStore] = None, + store_config: Optional[dict] = None, ): self._key_generator = key_generator self._cache_store = cache_store @@ -218,7 +229,10 @@ def __init__( @classmethod def from_dict(cls, d: Dict[str, str]): key_generator = KeyGenerator.from_name(d.get("key_generator"))() - store_config = d.get("store_config") + store_config_raw = d.get("store_config") + store_config: dict = ( + store_config_raw if isinstance(store_config_raw, dict) else {} + ) cache_store = CacheStore.from_name(d.get("store"))(**store_config) return cls(key_generator=key_generator, cache_store=cache_store) @@ -230,8 +244,8 @@ def from_config(cls, config: EmbeddingsCacheConfig): def get_config(self): return EmbeddingsCacheConfig( - key_generator=self._key_generator.name, - store=self._cache_store.name, + key_generator=self._key_generator.name if self._key_generator else "sha256", + store=self._cache_store.name if self._cache_store else "filesystem", store_config=self._store_config, ) @@ -239,8 +253,10 @@ def get_config(self): def get(self, texts): raise NotImplementedError - @get.register + @get.register(str) def _(self, text: str): + if self._key_generator is None or self._cache_store is None: + return None key = self._key_generator.generate_key(text) log.info(f"Fetching key {key} for text '{text[:20]}...' from cache") @@ -248,7 +264,7 @@ def _(self, text: str): return result - @get.register + @get.register(list) def _(self, texts: list): cached = {} @@ -266,19 +282,22 @@ def _(self, texts: list): def set(self, texts): raise NotImplementedError - @set.register + @set.register(str) def _(self, text: str, value: List[float]): + if self._key_generator is None or self._cache_store is None: + return key = self._key_generator.generate_key(text) log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.") self._cache_store.set(key, value) - @set.register + @set.register(list) def _(self, texts: list, values: List[List[float]]): for text, value in zip(texts, values): self.set(text, value) def clear(self): - self._cache_store.clear() + if self._cache_store is not None: + self._cache_store.clear() def cache_embeddings(func): diff --git a/nemoguardrails/embeddings/providers/fastembed.py b/nemoguardrails/embeddings/providers/fastembed.py index 1062e566f..1359f7ab5 100644 --- a/nemoguardrails/embeddings/providers/fastembed.py +++ b/nemoguardrails/embeddings/providers/fastembed.py @@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel): engine_name = "FastEmbed" def __init__(self, embedding_model: str, **kwargs): - from fastembed import TextEmbedding as Embedding + from fastembed import TextEmbedding as Embedding # type: ignore # Enabling a short form model name for all-MiniLM-L6-v2. 
if embedding_model == "all-MiniLM-L6-v2": diff --git a/nemoguardrails/embeddings/providers/openai.py b/nemoguardrails/embeddings/providers/openai.py index 83f83f8c2..bd12f2333 100644 --- a/nemoguardrails/embeddings/providers/openai.py +++ b/nemoguardrails/embeddings/providers/openai.py @@ -46,14 +46,14 @@ def __init__( **kwargs, ): try: - import openai - from openai import AsyncOpenAI, OpenAI + import openai # type: ignore + from openai import AsyncOpenAI, OpenAI # type: ignore except ImportError: raise ImportError( "Could not import openai, please install it with " "`pip install openai`." ) - if openai.__version__ < "1.0.0": + if openai.__version__ < "1.0.0": # type: ignore raise RuntimeError( "`openai<1.0.0` is no longer supported. " "Please upgrade using `pip install openai>=1.0.0`." diff --git a/nemoguardrails/embeddings/providers/sentence_transformers.py b/nemoguardrails/embeddings/providers/sentence_transformers.py index 7ffcec712..cc7ce7be8 100644 --- a/nemoguardrails/embeddings/providers/sentence_transformers.py +++ b/nemoguardrails/embeddings/providers/sentence_transformers.py @@ -43,7 +43,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from sentence_transformers import SentenceTransformer + from sentence_transformers import SentenceTransformer # type: ignore except ImportError: raise ImportError( "Could not import sentence-transformers, please install it with " @@ -51,7 +51,7 @@ def __init__(self, embedding_model: str, **kwargs): ) try: - from torch import cuda + from torch import cuda # type: ignore except ImportError: raise ImportError( "Could not import torch, please install it with `pip install torch`." From d3af88868c549b7de4e8505576fe9e01b0003048 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 14 Oct 2025 14:28:48 -0500 Subject: [PATCH 2/9] Add nemoguardrails/embeddings to pre-commit checking with pyright --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6be833997..b31308b9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,7 @@ pyright = "^1.1.405" include = [ "nemoguardrails/rails/**", "nemoguardrails/actions/**", + "nemoguardrails/embeddings/**", "nemoguardrails/cli/**", "nemoguardrails/kb/**", "nemoguardrails/logging/**", From f0ae7be64b5fdf45c3877b5a1f823223fff28d8f Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 14 Oct 2025 14:56:25 -0500 Subject: [PATCH 3/9] Address Traian and Pouyan's feedback on redundant None-checks and defaults in EmbeddingsCacheConfig --- nemoguardrails/embeddings/basic.py | 24 ++++++++++++++---------- nemoguardrails/embeddings/cache.py | 6 +++--- nemoguardrails/rails/llm/config.py | 6 +++--- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index af2635fc7..25ada8da6 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -15,7 +15,7 @@ import asyncio import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast from annoy import AnnoyIndex # type: ignore @@ -73,8 +73,14 @@ def __init__( self._model: Optional[EmbeddingModel] = None self._items: List[IndexItem] = [] self._embeddings: List[List[float]] = [] - self.embedding_model: Optional[str] = embedding_model - self.embedding_engine: Optional[str] = 
embedding_engine + self.embedding_model: str = ( + embedding_model + if embedding_model + else "sentence-transformers/all-MiniLM-L6-v2" + ) + self.embedding_engine: str = ( + embedding_engine if embedding_engine else "SentenceTransformers" + ) self.embedding_params = embedding_params or {} self._embedding_size = 0 self.search_threshold = search_threshold or float("inf") @@ -124,9 +130,8 @@ def embeddings(self): def _init_model(self): """Initialize the model used for computing the embeddings.""" - # Provide defaults if not specified - model = self.embedding_model or "sentence-transformers/all-MiniLM-L6-v2" - engine = self.embedding_engine or "SentenceTransformers" + model = self.embedding_model + engine = self.embedding_engine self._model = init_embedding_model( embedding_model=model, @@ -152,9 +157,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: if self._model is None: self._init_model() - if not self._model: - raise Exception("Couldn't initialize embedding model") - embeddings = await self._model.encode_async(texts) + # self._model can't be None here, or self._init_model() would throw a ValueError + model: EmbeddingModel = cast(EmbeddingModel, self._model) + embeddings = await model.encode_async(texts) return embeddings async def add_item(self, item: IndexItem): @@ -218,7 +223,6 @@ async def _run_batch(self): if not self._current_batch_finished_event: raise Exception("self._current_batch_finished_event not initialized") - assert self._current_batch_finished_event is not None batch_event: asyncio.Event = self._current_batch_finished_event self._current_batch_finished_event = None diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index 1f49c37da..f3824150f 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -244,9 +244,9 @@ def from_config(cls, config: EmbeddingsCacheConfig): def get_config(self): return EmbeddingsCacheConfig( - key_generator=self._key_generator.name if self._key_generator else "sha256", - store=self._cache_store.name if self._cache_store else "filesystem", - store_config=self._store_config, + key_generator=self._key_generator.name if self._key_generator else None, + store=self._cache_store.name if self._cache_store else None, + store_config=self._store_config if self._store_config else None, ) @singledispatchmethod diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 90d24bdc7..bedf3cef8 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -394,15 +394,15 @@ class EmbeddingsCacheConfig(BaseModel): default=False, description="Whether caching of the embeddings should be enabled or not.", ) - key_generator: str = Field( + key_generator: Optional[str] = Field( default="sha256", description="The method to use for generating the cache keys.", ) - store: str = Field( + store: Optional[str] = Field( default="filesystem", description="What type of store to use for the cached embeddings.", ) - store_config: Dict[str, Any] = Field( + store_config: Optional[Dict[str, Any]] = Field( default_factory=dict, description="Any additional configuration options required for the store. 
" "For example, path for `filesystem` or `host`/`port`/`db` for redis.", From 994cc141af4f6edcf009110f6bff818d80522728 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:03:04 -0500 Subject: [PATCH 4/9] Add type ignore to langchain_nvidia_ai_endpoints import --- nemoguardrails/embeddings/providers/nim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/embeddings/providers/nim.py b/nemoguardrails/embeddings/providers/nim.py index dd5690a4d..8ea9c1d0f 100644 --- a/nemoguardrails/embeddings/providers/nim.py +++ b/nemoguardrails/embeddings/providers/nim.py @@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings + from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # type: ignore self.model = embedding_model self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs) From c03350c7a60a45ad3d22f029b37b38c5a614cdf0 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:37:11 -0500 Subject: [PATCH 5/9] Type-clean new community-contributed embedding APIs --- nemoguardrails/embeddings/providers/azureopenai.py | 2 +- nemoguardrails/embeddings/providers/cohere.py | 8 ++++++-- nemoguardrails/embeddings/providers/google.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 5c5906d5d..842bff24d 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -56,7 +56,7 @@ def __init__(self, embedding_model: str): self.client = AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore[arg-type] (comes from `$AZURE_OPENAI_ENDPOINT`) ) self.embedding_model = embedding_model diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py index 34cee4156..b6bb60e52 100644 --- a/nemoguardrails/embeddings/providers/cohere.py +++ b/nemoguardrails/embeddings/providers/cohere.py @@ -14,7 +14,7 @@ # limitations under the License. import asyncio from contextvars import ContextVar -from typing import List +from typing import TYPE_CHECKING, List from .base import EmbeddingModel @@ -23,6 +23,10 @@ # is changed, it will fail. 
async_client_var: ContextVar = ContextVar("async_client", default=None) +if TYPE_CHECKING: + import cohere + from cohere import AsyncClient, Client + class CohereEmbeddingModel(EmbeddingModel): """ @@ -64,7 +68,7 @@ def __init__( self.model = embedding_model self.input_type = input_type - self.client = cohere.Client(**kwargs) + self.client = cohere.Client(**kwargs) # type: ignore[reportCallIssue] self.embedding_size_dict = { "embed-v4.0": 1536, diff --git a/nemoguardrails/embeddings/providers/google.py b/nemoguardrails/embeddings/providers/google.py index cf55399af..1f78974e6 100644 --- a/nemoguardrails/embeddings/providers/google.py +++ b/nemoguardrails/embeddings/providers/google.py @@ -46,7 +46,7 @@ class GoogleEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str, **kwargs): try: - from google import genai + from google import genai # type: ignore[import] except ImportError: raise ImportError( From b450ef7cf4343ac574c4f0da16e0d037d02d304f Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Fri, 24 Oct 2025 17:02:16 -0500 Subject: [PATCH 6/9] Remove optional dependencies and clean the type errors --- nemoguardrails/embeddings/basic.py | 5 ++++- nemoguardrails/embeddings/providers/azureopenai.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 25ada8da6..706e4bca5 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -257,7 +257,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: self._req_idx += 1 self._req_queue[req_id] = text - if self._current_batch_finished_event is None: + if ( + self._current_batch_finished_event is None + or self._current_batch_full_event is None + ): self._current_batch_finished_event = asyncio.Event() self._current_batch_full_event = asyncio.Event() self._current_batch_submitted.clear() diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 842bff24d..920cbd9d0 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -46,7 +46,9 @@ class AzureEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str): try: - from openai import AzureOpenAI + from openai import ( + AzureOpenAI, # type: ignore[attr-defined] (Assume this is installed) + ) except ImportError: raise ImportError( "Could not import openai, please install it with " From daa4062bc17c6e132e2f9f86fdc41236d48a46e4 Mon Sep 17 00:00:00 2001 From: Pouyan <13303554+Pouyanpi@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:16:02 +0100 Subject: [PATCH 7/9] Review 1383 (#1473) * first round of review * redundant pyright checks * fix cohere type errors * remove redundant model validation --- nemoguardrails/embeddings/basic.py | 49 +++++++------------ nemoguardrails/embeddings/cache.py | 25 ++++------ nemoguardrails/embeddings/providers/cohere.py | 7 ++- nemoguardrails/rails/llm/config.py | 6 +-- 4 files changed, 35 insertions(+), 52 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 706e4bca5..878a7677e 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -49,12 +49,12 @@ class BasicEmbeddingsIndex(EmbeddingsIndex): def __init__( self, - embedding_model: Optional[str] = None, - embedding_engine: Optional[str] = None, + embedding_model: str = 
"sentence-transformers/all-MiniLM-L6-v2", + embedding_engine: str = "SentenceTransformers", embedding_params: Optional[Dict[str, Any]] = None, index: Optional[AnnoyIndex] = None, cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None, - search_threshold: Optional[float] = None, + search_threshold: float = float("inf"), use_batching: bool = False, max_batch_size: int = 10, max_batch_hold: float = 0.01, @@ -62,10 +62,11 @@ def __init__( """Initialize the BasicEmbeddingsIndex. Args: - embedding_model (str, optional): The model for computing embeddings. Defaults to None. - embedding_engine (str, optional): The engine for computing embeddings. Defaults to None. - index (AnnoyIndex, optional): The pre-existing index. Defaults to None. - cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None. + embedding_model: The model for computing embeddings. + embedding_engine: The engine for computing embeddings. + index: The pre-existing index. + cache_config: The cache configuration. + search_threshold: The threshold for filtering search results. use_batching: Whether to batch requests when computing the embeddings. max_batch_size: The maximum size of a batch. max_batch_hold: The maximum time a batch is held before being processed @@ -73,17 +74,11 @@ def __init__( self._model: Optional[EmbeddingModel] = None self._items: List[IndexItem] = [] self._embeddings: List[List[float]] = [] - self.embedding_model: str = ( - embedding_model - if embedding_model - else "sentence-transformers/all-MiniLM-L6-v2" - ) - self.embedding_engine: str = ( - embedding_engine if embedding_engine else "SentenceTransformers" - ) + self.embedding_model = embedding_model + self.embedding_engine = embedding_engine self.embedding_params = embedding_params or {} self._embedding_size = 0 - self.search_threshold = search_threshold or float("inf") + self.search_threshold = search_threshold if isinstance(cache_config, Dict): self._cache_config = EmbeddingsCacheConfig(**cache_config) else: @@ -139,11 +134,6 @@ def _init_model(self): embedding_params=self.embedding_params, ) - if not self._model: - raise ValueError( - f"Couldn't create embedding model with model {model} and engine {engine}" - ) - @cache_embeddings async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: """Compute embeddings for a list of texts. @@ -205,10 +195,12 @@ async def _run_batch(self): """Runs the current batch of embeddings.""" # Wait up to `max_batch_hold` time or until `max_batch_size` is reached. - if not self._current_batch_full_event: - raise Exception("self._current_batch_full_event not initialized") + if ( + self._current_batch_full_event is None + or self._current_batch_finished_event is None + ): + raise RuntimeError("Batch events not initialized. 
This should not happen.") - assert self._current_batch_full_event is not None done, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), @@ -220,9 +212,6 @@ async def _run_batch(self): task.cancel() # Reset the batch event - if not self._current_batch_finished_event: - raise Exception("self._current_batch_finished_event not initialized") - batch_event: asyncio.Event = self._current_batch_finished_event self._current_batch_finished_event = None @@ -268,13 +257,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: # We check if we reached the max batch size if len(self._req_queue) >= self.max_batch_size: - if not self._current_batch_full_event: - raise Exception("self._current_batch_full_event not initialized") self._current_batch_full_event.set() - # Wait for the batch to finish - if not self._current_batch_finished_event: - raise Exception("self._current_batch_finished_event not initialized") + # Wait for the batch to finish await self._current_batch_finished_event.wait() # Remove the result and return it diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index f3824150f..8551f3fa9 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -44,11 +44,11 @@ def generate_key(self, text: str) -> str: @classmethod def from_name(cls, name): for subclass in cls.__subclasses__(): - if hasattr(subclass, "name") and subclass.name == name: + if subclass.name == name: return subclass raise ValueError( f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: " - f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}" + f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}" ". Make sure to import the derived class before using it." ) @@ -103,11 +103,11 @@ def clear(self): @classmethod def from_name(cls, name): for subclass in cls.__subclasses__(): - if hasattr(subclass, "name") and subclass.name == name: + if subclass.name == name: return subclass raise ValueError( f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: " - f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}" + f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}" ". Make sure to import the derived class before using it." ) @@ -218,8 +218,8 @@ def clear(self): class EmbeddingsCache: def __init__( self, - key_generator: Optional[KeyGenerator] = None, - cache_store: Optional[CacheStore] = None, + key_generator: KeyGenerator, + cache_store: CacheStore, store_config: Optional[dict] = None, ): self._key_generator = key_generator @@ -244,9 +244,9 @@ def from_config(cls, config: EmbeddingsCacheConfig): def get_config(self): return EmbeddingsCacheConfig( - key_generator=self._key_generator.name if self._key_generator else None, - store=self._cache_store.name if self._cache_store else None, - store_config=self._store_config if self._store_config else None, + key_generator=self._key_generator.name, + store=self._cache_store.name, + store_config=self._store_config, ) @singledispatchmethod @@ -255,8 +255,6 @@ def get(self, texts): @get.register(str) def _(self, text: str): - if self._key_generator is None or self._cache_store is None: - return None key = self._key_generator.generate_key(text) log.info(f"Fetching key {key} for text '{text[:20]}...' 
from cache") @@ -284,8 +282,6 @@ def set(self, texts): @set.register(str) def _(self, text: str, value: List[float]): - if self._key_generator is None or self._cache_store is None: - return key = self._key_generator.generate_key(text) log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.") self._cache_store.set(key, value) @@ -296,8 +292,7 @@ def _(self, texts: list, values: List[List[float]]): self.set(text, value) def clear(self): - if self._cache_store is not None: - self._cache_store.clear() + self._cache_store.clear() def cache_embeddings(func): diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py index b6bb60e52..704e0bcd7 100644 --- a/nemoguardrails/embeddings/providers/cohere.py +++ b/nemoguardrails/embeddings/providers/cohere.py @@ -124,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]: """ # Make embedding request to Cohere API - return self.client.embed( + # Since we don't pass embedding_types parameter, the response should be + # EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]] + response = self.client.embed( texts=documents, model=self.model, input_type=self.input_type - ).embeddings + ) + return response.embeddings # type: ignore[return-value] diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index bedf3cef8..90d24bdc7 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -394,15 +394,15 @@ class EmbeddingsCacheConfig(BaseModel): default=False, description="Whether caching of the embeddings should be enabled or not.", ) - key_generator: Optional[str] = Field( + key_generator: str = Field( default="sha256", description="The method to use for generating the cache keys.", ) - store: Optional[str] = Field( + store: str = Field( default="filesystem", description="What type of store to use for the cached embeddings.", ) - store_config: Optional[Dict[str, Any]] = Field( + store_config: Dict[str, Any] = Field( default_factory=dict, description="Any additional configuration options required for the store. 
" "For example, path for `filesystem` or `host`/`port`/`db` for redis.", From 1313ec08bfef30859b3aec0a7970aed77b381876 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 27 Oct 2025 16:55:06 -0500 Subject: [PATCH 8/9] Address last couple of comments --- nemoguardrails/embeddings/cache.py | 2 +- nemoguardrails/embeddings/providers/azureopenai.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index 8551f3fa9..cdef48c27 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -83,7 +83,7 @@ def generate_key(self, text: str) -> str: class CacheStore(ABC): """Abstract class for cache stores.""" - name: str # Class attribute that should be defined in subclasses + name: str @abstractmethod def get(self, key): diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 920cbd9d0..4eb7accfc 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -46,9 +46,7 @@ class AzureEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str): try: - from openai import ( - AzureOpenAI, # type: ignore[attr-defined] (Assume this is installed) - ) + from openai import AzureOpenAI # type: ignore except ImportError: raise ImportError( "Could not import openai, please install it with " @@ -58,7 +56,7 @@ def __init__(self, embedding_model: str): self.client = AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore[arg-type] (comes from `$AZURE_OPENAI_ENDPOINT`) + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), ) self.embedding_model = embedding_model From 61ac1c9e7400c8543764186d0c596e990cfe7353 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 28 Oct 2025 11:49:15 +0100 Subject: [PATCH 9/9] minor cleanup and type ignore for Azure --- nemoguardrails/embeddings/basic.py | 2 -- nemoguardrails/embeddings/providers/azureopenai.py | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py index 878a7677e..a4e497762 100644 --- a/nemoguardrails/embeddings/basic.py +++ b/nemoguardrails/embeddings/basic.py @@ -45,8 +45,6 @@ class BasicEmbeddingsIndex(EmbeddingsIndex): max_batch_hold: The maximum time a batch is held before being processed """ - # Instance attributes are defined in __init__ and accessed via properties - def __init__( self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 4eb7accfc..e77ab481a 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -49,14 +49,13 @@ def __init__(self, embedding_model: str): from openai import AzureOpenAI # type: ignore except ImportError: raise ImportError( - "Could not import openai, please install it with " - "`pip install openai`." + "Could not import openai, please install it with `pip install openai`." 
) # Set Azure OpenAI API credentials self.client = AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore ) self.embedding_model = embedding_model
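
---

Taken together, the series makes `EmbeddingsCache` require its `key_generator` and `cache_store` collaborators up front, while `from_dict` keeps resolving them by their `name` class attributes. A minimal usage sketch of the cleaned-up cache API, assuming the final state of this series; the texts and embedding values below are illustrative, and the default `filesystem` store persists under `.cache/embeddings`:

    # Sketch only, not part of the series; exercises the post-cleanup API.
    from nemoguardrails.embeddings.cache import EmbeddingsCache

    # from_dict resolves the generator/store by their `name` class attributes
    # ("sha256" hashing keys, "filesystem" persisting values as files).
    cache = EmbeddingsCache.from_dict(
        {"key_generator": "sha256", "store": "filesystem", "store_config": {}}
    )

    # The str overloads of the singledispatchmethod get/set handle one text.
    cache.set("hello world", [0.1, 0.2, 0.3])
    print(cache.get("hello world"))  # -> [0.1, 0.2, 0.3]

    # The list overload returns only the cache hits, keyed by text.
    hits = cache.get(["hello world", "not cached yet"])
    print(hits)  # {'hello world': [0.1, 0.2, 0.3]}

    cache.clear()

Because `__init__` no longer accepts `None` for either collaborator, the former None-guards in `get`/`set`/`clear` (removed in PATCH 7/9) are no longer reachable code paths rather than silently skipped branches.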