49 changes: 17 additions & 32 deletions nemoguardrails/embeddings/basic.py
@@ -49,41 +49,36 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
 
     def __init__(
         self,
-        embedding_model: Optional[str] = None,
-        embedding_engine: Optional[str] = None,
+        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedding_engine: str = "SentenceTransformers",
         embedding_params: Optional[Dict[str, Any]] = None,
         index: Optional[AnnoyIndex] = None,
         cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
-        search_threshold: Optional[float] = None,
+        search_threshold: float = float("inf"),
         use_batching: bool = False,
         max_batch_size: int = 10,
         max_batch_hold: float = 0.01,
     ):
         """Initialize the BasicEmbeddingsIndex.
 
         Args:
-            embedding_model (str, optional): The model for computing embeddings. Defaults to None.
-            embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
-            index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
-            cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
+            embedding_model: The model for computing embeddings.
+            embedding_engine: The engine for computing embeddings.
+            index: The pre-existing index.
+            cache_config: The cache configuration.
             search_threshold: The threshold for filtering search results.
             use_batching: Whether to batch requests when computing the embeddings.
             max_batch_size: The maximum size of a batch.
             max_batch_hold: The maximum time a batch is held before being processed
         """
         self._model: Optional[EmbeddingModel] = None
         self._items: List[IndexItem] = []
         self._embeddings: List[List[float]] = []
-        self.embedding_model: str = (
-            embedding_model
-            if embedding_model
-            else "sentence-transformers/all-MiniLM-L6-v2"
-        )
-        self.embedding_engine: str = (
-            embedding_engine if embedding_engine else "SentenceTransformers"
-        )
+        self.embedding_model = embedding_model
+        self.embedding_engine = embedding_engine
         self.embedding_params = embedding_params or {}
         self._embedding_size = 0
-        self.search_threshold = search_threshold or float("inf")
+        self.search_threshold = search_threshold
         if isinstance(cache_config, Dict):
             self._cache_config = EmbeddingsCacheConfig(**cache_config)
         else:
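Note on this hunk: moving the fallbacks into the signature is equivalent for the common cases, but it also removes one falsy-value surprise from the old `or`/ternary logic. A minimal sketch of the new behavior (assuming `BasicEmbeddingsIndex` is imported from the module this diff touches):

```python
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

# Defaults now live in the signature, so a bare construction still
# yields the same configuration as before.
index = BasicEmbeddingsIndex()
assert index.embedding_model == "sentence-transformers/all-MiniLM-L6-v2"
assert index.embedding_engine == "SentenceTransformers"
assert index.search_threshold == float("inf")

# Edge case removed by this change: `search_threshold or float("inf")`
# treated an explicit 0.0 as falsy and replaced it; the new code keeps it.
assert BasicEmbeddingsIndex(search_threshold=0.0).search_threshold == 0.0
```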
@@ -139,11 +134,6 @@ def _init_model(self):
             embedding_params=self.embedding_params,
         )
 
-        if not self._model:
-            raise ValueError(
-                f"Couldn't create embedding model with model {model} and engine {engine}"
-            )
-
     @cache_embeddings
     async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Compute embeddings for a list of texts.
Expand Down Expand Up @@ -205,10 +195,12 @@ async def _run_batch(self):
"""Runs the current batch of embeddings."""

# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
if not self._current_batch_full_event:
raise Exception("self._current_batch_full_event not initialized")
if (
self._current_batch_full_event is None
or self._current_batch_finished_event is None
):
raise RuntimeError("Batch events not initialized. This should not happen.")

assert self._current_batch_full_event is not None
done, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -220,9 +212,6 @@
                 task.cancel()
 
         # Reset the batch event
-        if not self._current_batch_finished_event:
-            raise Exception("self._current_batch_finished_event not initialized")
-
         batch_event: asyncio.Event = self._current_batch_finished_event
         self._current_batch_finished_event = None
@@ -268,13 +257,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
 
         # We check if we reached the max batch size
         if len(self._req_queue) >= self.max_batch_size:
-            if not self._current_batch_full_event:
-                raise Exception("self._current_batch_full_event not initialized")
             self._current_batch_full_event.set()
 
-        # Wait for the batch to finish
-        if not self._current_batch_finished_event:
-            raise Exception("self._current_batch_finished_event not initialized")
+        # Wait for the batch to finish
         await self._current_batch_finished_event.wait()
 
         # Remove the result and return it
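The last three hunks in this file collapse scattered `if not … raise Exception` guards (plus a redundant `assert`) into a single up-front `RuntimeError` check in `_run_batch`. For readers unfamiliar with the timeout-or-full race that `_run_batch` runs, here is a standalone sketch of the pattern (illustrative only, not the library's code):

```python
import asyncio

async def wait_for_batch(batch_full: asyncio.Event, max_batch_hold: float) -> None:
    """Return when the batch fills up or the hold time elapses, whichever is first."""
    done, pending = await asyncio.wait(
        [
            asyncio.create_task(asyncio.sleep(max_batch_hold)),
            asyncio.create_task(batch_full.wait()),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )
    # Cancel whichever branch lost the race so no task is left pending.
    for task in pending:
        task.cancel()

async def main() -> None:
    # The event is never set here, so this returns after ~0.01s (the hold time).
    await wait_for_batch(asyncio.Event(), max_batch_hold=0.01)

asyncio.run(main())
```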
25 changes: 10 additions & 15 deletions nemoguardrails/embeddings/cache.py
@@ -44,11 +44,11 @@ def generate_key(self, text: str) -> str:
     @classmethod
     def from_name(cls, name):
         for subclass in cls.__subclasses__():
-            if hasattr(subclass, "name") and subclass.name == name:
+            if subclass.name == name:
                 return subclass
         raise ValueError(
             f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
-            f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
+            f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
             ". Make sure to import the derived class before using it."
         )

@@ -103,11 +103,11 @@ def clear(self):
     @classmethod
     def from_name(cls, name):
         for subclass in cls.__subclasses__():
-            if hasattr(subclass, "name") and subclass.name == name:
+            if subclass.name == name:
                 return subclass
         raise ValueError(
             f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
-            f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
+            f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
             ". Make sure to import the derived class before using it."
         )

@@ -218,8 +218,8 @@ def clear(self):
 class EmbeddingsCache:
     def __init__(
         self,
-        key_generator: Optional[KeyGenerator] = None,
-        cache_store: Optional[CacheStore] = None,
+        key_generator: KeyGenerator,
+        cache_store: CacheStore,
         store_config: Optional[dict] = None,
     ):
         self._key_generator = key_generator
@@ -244,9 +244,9 @@ def from_config(cls, config: EmbeddingsCacheConfig):
 
     def get_config(self):
         return EmbeddingsCacheConfig(
-            key_generator=self._key_generator.name if self._key_generator else None,
-            store=self._cache_store.name if self._cache_store else None,
-            store_config=self._store_config if self._store_config else None,
+            key_generator=self._key_generator.name,
+            store=self._cache_store.name,
+            store_config=self._store_config,
         )
 
     @singledispatchmethod
@@ -255,8 +255,6 @@ def get(self, texts):
 
     @get.register(str)
     def _(self, text: str):
-        if self._key_generator is None or self._cache_store is None:
-            return None
         key = self._key_generator.generate_key(text)
         log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")

@@ -284,8 +282,6 @@ def set(self, texts):
 
     @set.register(str)
     def _(self, text: str, value: List[float]):
-        if self._key_generator is None or self._cache_store is None:
-            return
         key = self._key_generator.generate_key(text)
         log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
         self._cache_store.set(key, value)
@@ -296,8 +292,7 @@ def _(self, texts: list, values: List[List[float]]):
             self.set(text, value)
 
     def clear(self):
-        if self._cache_store is not None:
-            self._cache_store.clear()
+        self._cache_store.clear()
 
 
 def cache_embeddings(func):
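Both `from_name` changes rely on the same tightened contract: every `KeyGenerator` and `CacheStore` subclass must define a `name` class attribute, which makes the `hasattr` guards dead code. A sketch of the lookup with a hypothetical subclass (the class and its `name` value are illustrative, not part of the library):

```python
from nemoguardrails.embeddings.cache import KeyGenerator

class PrefixKeyGenerator(KeyGenerator):
    """Hypothetical generator; `name` is now required on every subclass."""

    name = "prefix"

    def generate_key(self, text: str) -> str:
        return "prefix:" + text

# from_name scans cls.__subclasses__() and matches on the class attribute;
# unknown names still raise ValueError listing the registered subclasses.
assert KeyGenerator.from_name("prefix") is PrefixKeyGenerator
```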
7 changes: 5 additions & 2 deletions nemoguardrails/embeddings/providers/cohere.py
@@ -124,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]:
         """
 
         # Make embedding request to Cohere API
-        return self.client.embed(
+        # Since we don't pass embedding_types parameter, the response should be
+        # EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]]
+        response = self.client.embed(
             texts=documents, model=self.model, input_type=self.input_type
-        ).embeddings
+        )
+        return response.embeddings  # type: ignore[return-value]
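The intermediate `response` variable makes the typing story explicit: per the new comment, `embed()` without `embedding_types` should return an `EmbeddingsFloatsEmbedResponse`, whose `.embeddings` is `List[List[float]]`, but the SDK's declared return type is a union, hence the narrowly scoped `type: ignore`. A usage sketch under stated assumptions (requires the `cohere` package and an API key in the environment; the model name is illustrative):

```python
import cohere

client = cohere.Client()  # assumes the API key is picked up from the environment

response = client.embed(
    texts=["What is NeMo Guardrails?"],
    model="embed-english-v3.0",
    input_type="search_document",
)
# With embedding_types omitted, embeddings is a plain list of float vectors.
vectors = response.embeddings
print(len(vectors), len(vectors[0]))
```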
6 changes: 3 additions & 3 deletions nemoguardrails/rails/llm/config.py
@@ -394,15 +394,15 @@ class EmbeddingsCacheConfig(BaseModel):
         default=False,
         description="Whether caching of the embeddings should be enabled or not.",
     )
-    key_generator: Optional[str] = Field(
+    key_generator: str = Field(
         default="sha256",
         description="The method to use for generating the cache keys.",
     )
-    store: Optional[str] = Field(
+    store: str = Field(
         default="filesystem",
         description="What type of store to use for the cached embeddings.",
     )
-    store_config: Optional[Dict[str, Any]] = Field(
+    store_config: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional configuration options required for the store. "
         "For example, path for `filesystem` or `host`/`port`/`db` for redis.",
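Since all three fields already had non-None defaults, dropping `Optional` only tightens the declared types; a default-constructed config is unchanged. A quick sketch of what consumers can now rely on:

```python
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

cfg = EmbeddingsCacheConfig()
# All fields are guaranteed non-None, so downstream code such as
# EmbeddingsCache.from_config no longer needs None checks.
assert cfg.key_generator == "sha256"
assert cfg.store == "filesystem"
assert cfg.store_config == {}
```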