@@ -49,41 +49,36 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4949
5050 def __init__ (
5151 self ,
52- embedding_model : Optional [ str ] = None ,
53- embedding_engine : Optional [ str ] = None ,
52+ embedding_model : str = "sentence-transformers/all-MiniLM-L6-v2" ,
53+ embedding_engine : str = "SentenceTransformers" ,
5454 embedding_params : Optional [Dict [str , Any ]] = None ,
5555 index : Optional [AnnoyIndex ] = None ,
5656 cache_config : Optional [Union [EmbeddingsCacheConfig , Dict [str , Any ]]] = None ,
57- search_threshold : Optional [ float ] = None ,
57+ search_threshold : float = float ( "inf" ) ,
5858 use_batching : bool = False ,
5959 max_batch_size : int = 10 ,
6060 max_batch_hold : float = 0.01 ,
6161 ):
6262 """Initialize the BasicEmbeddingsIndex.
6363
6464 Args:
65- embedding_model (str, optional): The model for computing embeddings. Defaults to None.
66- embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
67- index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
68- cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
65+ embedding_model: The model for computing embeddings.
66+ embedding_engine: The engine for computing embeddings.
67+ index: The pre-existing index.
68+ cache_config: The cache configuration.
69+ search_threshold: The threshold for filtering search results.
6970 use_batching: Whether to batch requests when computing the embeddings.
7071 max_batch_size: The maximum size of a batch.
7172 max_batch_hold: The maximum time a batch is held before being processed
7273 """
7374 self ._model : Optional [EmbeddingModel ] = None
7475 self ._items : List [IndexItem ] = []
7576 self ._embeddings : List [List [float ]] = []
76- self .embedding_model : str = (
77- embedding_model
78- if embedding_model
79- else "sentence-transformers/all-MiniLM-L6-v2"
80- )
81- self .embedding_engine : str = (
82- embedding_engine if embedding_engine else "SentenceTransformers"
83- )
77+ self .embedding_model = embedding_model
78+ self .embedding_engine = embedding_engine
8479 self .embedding_params = embedding_params or {}
8580 self ._embedding_size = 0
86- self .search_threshold = search_threshold or float ( "inf" )
81+ self .search_threshold = search_threshold
8782 if isinstance (cache_config , Dict ):
8883 self ._cache_config = EmbeddingsCacheConfig (** cache_config )
8984 else :
@@ -139,11 +134,6 @@ def _init_model(self):
139134 embedding_params = self .embedding_params ,
140135 )
141136
142- if not self ._model :
143- raise ValueError (
144- f"Couldn't create embedding model with model { model } and engine { engine } "
145- )
146-
147137 @cache_embeddings
148138 async def _get_embeddings (self , texts : List [str ]) -> List [List [float ]]:
149139 """Compute embeddings for a list of texts.
@@ -205,10 +195,12 @@ async def _run_batch(self):
205195 """Runs the current batch of embeddings."""
206196
207197 # Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
208- if not self ._current_batch_full_event :
209- raise Exception ("self._current_batch_full_event not initialized" )
198+ if (
199+ self ._current_batch_full_event is None
200+ or self ._current_batch_finished_event is None
201+ ):
202+ raise RuntimeError ("Batch events not initialized. This should not happen." )
210203
211- assert self ._current_batch_full_event is not None
212204 done , pending = await asyncio .wait (
213205 [
214206 asyncio .create_task (asyncio .sleep (self .max_batch_hold )),
@@ -220,9 +212,6 @@ async def _run_batch(self):
220212 task .cancel ()
221213
222214 # Reset the batch event
223- if not self ._current_batch_finished_event :
224- raise Exception ("self._current_batch_finished_event not initialized" )
225-
226215 batch_event : asyncio .Event = self ._current_batch_finished_event
227216 self ._current_batch_finished_event = None
228217
@@ -268,13 +257,9 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
268257
269258 # We check if we reached the max batch size
270259 if len (self ._req_queue ) >= self .max_batch_size :
271- if not self ._current_batch_full_event :
272- raise Exception ("self._current_batch_full_event not initialized" )
273260 self ._current_batch_full_event .set ()
274261
275- # Wait for the batch to finish
276- if not self ._current_batch_finished_event :
277- raise Exception ("self._current_batch_finished_event not initialized" )
262+ # Wait for the batch to finish
278263 await self ._current_batch_finished_event .wait ()
279264
280265 # Remove the result and return it