diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py
index 706e4bca5..878a7677e 100644
--- a/nemoguardrails/embeddings/basic.py
+++ b/nemoguardrails/embeddings/basic.py
@@ -49,12 +49,12 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
 
     def __init__(
         self,
-        embedding_model: Optional[str] = None,
-        embedding_engine: Optional[str] = None,
+        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedding_engine: str = "SentenceTransformers",
         embedding_params: Optional[Dict[str, Any]] = None,
         index: Optional[AnnoyIndex] = None,
         cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
-        search_threshold: Optional[float] = None,
+        search_threshold: float = float("inf"),
         use_batching: bool = False,
         max_batch_size: int = 10,
         max_batch_hold: float = 0.01,
@@ -62,10 +62,11 @@ def __init__(
         """Initialize the BasicEmbeddingsIndex.
 
         Args:
-            embedding_model (str, optional): The model for computing embeddings. Defaults to None.
-            embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
-            index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
-            cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
+            embedding_model: The model for computing embeddings.
+            embedding_engine: The engine for computing embeddings.
+            index: The pre-existing index.
+            cache_config: The cache configuration.
+            search_threshold: The threshold for filtering search results.
             use_batching: Whether to batch requests when computing the embeddings.
             max_batch_size: The maximum size of a batch.
             max_batch_hold: The maximum time a batch is held before being processed
@@ -73,17 +74,11 @@ def __init__(
         self._model: Optional[EmbeddingModel] = None
         self._items: List[IndexItem] = []
         self._embeddings: List[List[float]] = []
-        self.embedding_model: str = (
-            embedding_model
-            if embedding_model
-            else "sentence-transformers/all-MiniLM-L6-v2"
-        )
-        self.embedding_engine: str = (
-            embedding_engine if embedding_engine else "SentenceTransformers"
-        )
+        self.embedding_model = embedding_model
+        self.embedding_engine = embedding_engine
         self.embedding_params = embedding_params or {}
         self._embedding_size = 0
-        self.search_threshold = search_threshold or float("inf")
+        self.search_threshold = search_threshold
         if isinstance(cache_config, Dict):
             self._cache_config = EmbeddingsCacheConfig(**cache_config)
         else:
@@ -139,11 +134,6 @@ def _init_model(self):
             embedding_params=self.embedding_params,
         )
 
-        if not self._model:
-            raise ValueError(
-                f"Couldn't create embedding model with model {model} and engine {engine}"
-            )
-
     @cache_embeddings
     async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Compute embeddings for a list of texts.
@@ -205,10 +195,12 @@ async def _run_batch(self):
         """Runs the current batch of embeddings."""
 
         # Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
-        if not self._current_batch_full_event:
-            raise Exception("self._current_batch_full_event not initialized")
+        if (
+            self._current_batch_full_event is None
+            or self._current_batch_finished_event is None
+        ):
+            raise RuntimeError("Batch events not initialized. This should not happen.")
 
-        assert self._current_batch_full_event is not None
         done, pending = await asyncio.wait(
             [
                 asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -220,9 +212,6 @@ async def _run_batch(self):
                 task.cancel()
 
         # Reset the batch event
-        if not self._current_batch_finished_event:
-            raise Exception("self._current_batch_finished_event not initialized")
-
         batch_event: asyncio.Event = self._current_batch_finished_event
         self._current_batch_finished_event = None
 
@@ -268,13 +257,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
 
         # We check if we reached the max batch size
         if len(self._req_queue) >= self.max_batch_size:
-            if not self._current_batch_full_event:
-                raise Exception("self._current_batch_full_event not initialized")
             self._current_batch_full_event.set()
 
-        # Wait for the batch to finish
-        if not self._current_batch_finished_event:
-            raise Exception("self._current_batch_finished_event not initialized")
+        # Wait for the batch to finish
         await self._current_batch_finished_event.wait()
 
         # Remove the result and return it
diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py
index f3824150f..8551f3fa9 100644
--- a/nemoguardrails/embeddings/cache.py
+++ b/nemoguardrails/embeddings/cache.py
@@ -44,11 +44,11 @@ def generate_key(self, text: str) -> str:
 
     @classmethod
     def from_name(cls, name):
         for subclass in cls.__subclasses__():
-            if hasattr(subclass, "name") and subclass.name == name:
+            if subclass.name == name:
                 return subclass
         raise ValueError(
             f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
-            f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
+            f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
             ". Make sure to import the derived class before using it."
         )
@@ -103,11 +103,11 @@ def clear(self):
 
     @classmethod
     def from_name(cls, name):
         for subclass in cls.__subclasses__():
-            if hasattr(subclass, "name") and subclass.name == name:
+            if subclass.name == name:
                 return subclass
         raise ValueError(
             f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
-            f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
+            f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
             ". Make sure to import the derived class before using it."
         )
@@ -218,8 +218,8 @@ class EmbeddingsCache:
 
     def __init__(
         self,
-        key_generator: Optional[KeyGenerator] = None,
-        cache_store: Optional[CacheStore] = None,
+        key_generator: KeyGenerator,
+        cache_store: CacheStore,
         store_config: Optional[dict] = None,
     ):
         self._key_generator = key_generator
@@ -244,9 +244,9 @@ def from_config(cls, config: EmbeddingsCacheConfig):
 
     def get_config(self):
         return EmbeddingsCacheConfig(
-            key_generator=self._key_generator.name if self._key_generator else None,
-            store=self._cache_store.name if self._cache_store else None,
-            store_config=self._store_config if self._store_config else None,
+            key_generator=self._key_generator.name,
+            store=self._cache_store.name,
+            store_config=self._store_config,
         )
 
     @singledispatchmethod
@@ -255,8 +255,6 @@ def get(self, texts):
 
     @get.register(str)
     def _(self, text: str):
-        if self._key_generator is None or self._cache_store is None:
-            return None
         key = self._key_generator.generate_key(text)
         log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
 
@@ -284,8 +282,6 @@ def set(self, texts):
 
     @set.register(str)
     def _(self, text: str, value: List[float]):
-        if self._key_generator is None or self._cache_store is None:
-            return
         key = self._key_generator.generate_key(text)
         log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
         self._cache_store.set(key, value)
@@ -296,8 +292,7 @@ def _(self, texts: list, values: List[List[float]]):
             self.set(text, value)
 
     def clear(self):
-        if self._cache_store is not None:
-            self._cache_store.clear()
+        self._cache_store.clear()
 
 
 def cache_embeddings(func):
diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py
index b6bb60e52..704e0bcd7 100644
--- a/nemoguardrails/embeddings/providers/cohere.py
+++ b/nemoguardrails/embeddings/providers/cohere.py
@@ -124,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]:
         """
 
         # Make embedding request to Cohere API
-        return self.client.embed(
+        # Since we don't pass the embedding_types parameter, the response should be
+        # EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]].
+        response = self.client.embed(
             texts=documents, model=self.model, input_type=self.input_type
-        ).embeddings
+        )
+        return response.embeddings  # type: ignore[return-value]
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index bedf3cef8..90d24bdc7 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -394,15 +394,15 @@ class EmbeddingsCacheConfig(BaseModel):
         default=False,
         description="Whether caching of the embeddings should be enabled or not.",
     )
-    key_generator: Optional[str] = Field(
+    key_generator: str = Field(
         default="sha256",
         description="The method to use for generating the cache keys.",
     )
-    store: Optional[str] = Field(
+    store: str = Field(
         default="filesystem",
         description="What type of store to use for the cached embeddings.",
    )
-    store_config: Optional[Dict[str, Any]] = Field(
+    store_config: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional configuration options required for the store. "
         "For example, path for `filesystem` or `host`/`port`/`db` for redis.",