Skip to content

Commit daa4062

Browse files
authored
Review 1383 (#1473)
* first round of review
* redundant pyright checks
* fix cohere type errors
* remove redundant model validation
1 parent b450ef7 commit daa4062

File tree

4 files changed: +35 additions, −52 deletions

nemoguardrails/embeddings/basic.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -49,41 +49,36 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4949

5050
def __init__(
5151
self,
52-
embedding_model: Optional[str] = None,
53-
embedding_engine: Optional[str] = None,
52+
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
53+
embedding_engine: str = "SentenceTransformers",
5454
embedding_params: Optional[Dict[str, Any]] = None,
5555
index: Optional[AnnoyIndex] = None,
5656
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
57-
search_threshold: Optional[float] = None,
57+
search_threshold: float = float("inf"),
5858
use_batching: bool = False,
5959
max_batch_size: int = 10,
6060
max_batch_hold: float = 0.01,
6161
):
6262
"""Initialize the BasicEmbeddingsIndex.
6363
6464
Args:
65-
embedding_model (str, optional): The model for computing embeddings. Defaults to None.
66-
embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
67-
index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
68-
cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
65+
embedding_model: The model for computing embeddings.
66+
embedding_engine: The engine for computing embeddings.
67+
index: The pre-existing index.
68+
cache_config: The cache configuration.
69+
search_threshold: The threshold for filtering search results.
6970
use_batching: Whether to batch requests when computing the embeddings.
7071
max_batch_size: The maximum size of a batch.
7172
max_batch_hold: The maximum time a batch is held before being processed
7273
"""
7374
self._model: Optional[EmbeddingModel] = None
7475
self._items: List[IndexItem] = []
7576
self._embeddings: List[List[float]] = []
76-
self.embedding_model: str = (
77-
embedding_model
78-
if embedding_model
79-
else "sentence-transformers/all-MiniLM-L6-v2"
80-
)
81-
self.embedding_engine: str = (
82-
embedding_engine if embedding_engine else "SentenceTransformers"
83-
)
77+
self.embedding_model = embedding_model
78+
self.embedding_engine = embedding_engine
8479
self.embedding_params = embedding_params or {}
8580
self._embedding_size = 0
86-
self.search_threshold = search_threshold or float("inf")
81+
self.search_threshold = search_threshold
8782
if isinstance(cache_config, Dict):
8883
self._cache_config = EmbeddingsCacheConfig(**cache_config)
8984
else:
@@ -139,11 +134,6 @@ def _init_model(self):
139134
embedding_params=self.embedding_params,
140135
)
141136

142-
if not self._model:
143-
raise ValueError(
144-
f"Couldn't create embedding model with model {model} and engine {engine}"
145-
)
146-
147137
@cache_embeddings
148138
async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
149139
"""Compute embeddings for a list of texts.
@@ -205,10 +195,12 @@ async def _run_batch(self):
205195
"""Runs the current batch of embeddings."""
206196

207197
# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
208-
if not self._current_batch_full_event:
209-
raise Exception("self._current_batch_full_event not initialized")
198+
if (
199+
self._current_batch_full_event is None
200+
or self._current_batch_finished_event is None
201+
):
202+
raise RuntimeError("Batch events not initialized. This should not happen.")
210203

211-
assert self._current_batch_full_event is not None
212204
done, pending = await asyncio.wait(
213205
[
214206
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -220,9 +212,6 @@ async def _run_batch(self):
220212
task.cancel()
221213

222214
# Reset the batch event
223-
if not self._current_batch_finished_event:
224-
raise Exception("self._current_batch_finished_event not initialized")
225-
226215
batch_event: asyncio.Event = self._current_batch_finished_event
227216
self._current_batch_finished_event = None
228217

@@ -268,13 +257,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
268257

269258
# We check if we reached the max batch size
270259
if len(self._req_queue) >= self.max_batch_size:
271-
if not self._current_batch_full_event:
272-
raise Exception("self._current_batch_full_event not initialized")
273260
self._current_batch_full_event.set()
274261

275-
# Wait for the batch to finish
276-
if not self._current_batch_finished_event:
277-
raise Exception("self._current_batch_finished_event not initialized")
262+
# Wait for the batch to finish
278263
await self._current_batch_finished_event.wait()
279264

280265
# Remove the result and return it

nemoguardrails/embeddings/cache.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def generate_key(self, text: str) -> str:
4444
@classmethod
4545
def from_name(cls, name):
4646
for subclass in cls.__subclasses__():
47-
if hasattr(subclass, "name") and subclass.name == name:
47+
if subclass.name == name:
4848
return subclass
4949
raise ValueError(
5050
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
51-
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
51+
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
5252
". Make sure to import the derived class before using it."
5353
)
5454

@@ -103,11 +103,11 @@ def clear(self):
103103
@classmethod
104104
def from_name(cls, name):
105105
for subclass in cls.__subclasses__():
106-
if hasattr(subclass, "name") and subclass.name == name:
106+
if subclass.name == name:
107107
return subclass
108108
raise ValueError(
109109
f"Unknown {cls.__name__}: {name}. Available {cls.__name__}s are: "
110-
f"{', '.join([subclass.name for subclass in cls.__subclasses__() if hasattr(subclass, 'name')])}"
110+
f"{', '.join([subclass.name for subclass in cls.__subclasses__()])}"
111111
". Make sure to import the derived class before using it."
112112
)
113113

@@ -218,8 +218,8 @@ def clear(self):
218218
class EmbeddingsCache:
219219
def __init__(
220220
self,
221-
key_generator: Optional[KeyGenerator] = None,
222-
cache_store: Optional[CacheStore] = None,
221+
key_generator: KeyGenerator,
222+
cache_store: CacheStore,
223223
store_config: Optional[dict] = None,
224224
):
225225
self._key_generator = key_generator
@@ -244,9 +244,9 @@ def from_config(cls, config: EmbeddingsCacheConfig):
244244

245245
def get_config(self):
246246
return EmbeddingsCacheConfig(
247-
key_generator=self._key_generator.name if self._key_generator else None,
248-
store=self._cache_store.name if self._cache_store else None,
249-
store_config=self._store_config if self._store_config else None,
247+
key_generator=self._key_generator.name,
248+
store=self._cache_store.name,
249+
store_config=self._store_config,
250250
)
251251

252252
@singledispatchmethod
@@ -255,8 +255,6 @@ def get(self, texts):
255255

256256
@get.register(str)
257257
def _(self, text: str):
258-
if self._key_generator is None or self._cache_store is None:
259-
return None
260258
key = self._key_generator.generate_key(text)
261259
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
262260

@@ -284,8 +282,6 @@ def set(self, texts):
284282

285283
@set.register(str)
286284
def _(self, text: str, value: List[float]):
287-
if self._key_generator is None or self._cache_store is None:
288-
return
289285
key = self._key_generator.generate_key(text)
290286
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
291287
self._cache_store.set(key, value)
@@ -296,8 +292,7 @@ def _(self, texts: list, values: List[List[float]]):
296292
self.set(text, value)
297293

298294
def clear(self):
299-
if self._cache_store is not None:
300-
self._cache_store.clear()
295+
self._cache_store.clear()
301296

302297

303298
def cache_embeddings(func):

nemoguardrails/embeddings/providers/cohere.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]:
124124
"""
125125

126126
# Make embedding request to Cohere API
127-
return self.client.embed(
127+
# Since we don't pass embedding_types parameter, the response should be
128+
# EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]]
129+
response = self.client.embed(
128130
texts=documents, model=self.model, input_type=self.input_type
129-
).embeddings
131+
)
132+
return response.embeddings # type: ignore[return-value]

nemoguardrails/rails/llm/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,15 +394,15 @@ class EmbeddingsCacheConfig(BaseModel):
394394
default=False,
395395
description="Whether caching of the embeddings should be enabled or not.",
396396
)
397-
key_generator: Optional[str] = Field(
397+
key_generator: str = Field(
398398
default="sha256",
399399
description="The method to use for generating the cache keys.",
400400
)
401-
store: Optional[str] = Field(
401+
store: str = Field(
402402
default="filesystem",
403403
description="What type of store to use for the cached embeddings.",
404404
)
405-
store_config: Optional[Dict[str, Any]] = Field(
405+
store_config: Dict[str, Any] = Field(
406406
default_factory=dict,
407407
description="Any additional configuration options required for the store. "
408408
"For example, path for `filesystem` or `host`/`port`/`db` for redis.",

Comments (0)