Commit 790882b

Add rate limit handling to embedders (#425)
* Improve error handling for embedders including rate limit
* Update unit tests
* Update changelog and docs
* Ruff
* Fix mypy
* Fix more mypy issues
* Improve error handling for Mistral and Sentence Transformers
* Improve unit tests
* Move rate limit handler decorator to base class
* Move rate limit module to utils and generate deprecation warnings
* Refactor modules using rate limit handling
* Update docs and examples
* Fix refactoring
* Ruff
1 parent 4541de2 commit 790882b

31 files changed (+880, -309 lines)

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -2,6 +2,10 @@

 ## Next

+### Added
+
+- Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely.
+
 ## 1.10.0

 ### Added

docs/source/api.rst

Lines changed: 3 additions & 3 deletions

@@ -359,19 +359,19 @@ Rate Limiting
 RateLimitHandler
 ----------------

-.. autoclass:: neo4j_graphrag.llm.rate_limit.RateLimitHandler
+.. autoclass:: neo4j_graphrag.utils.rate_limit.RateLimitHandler
     :members:

 RetryRateLimitHandler
 ---------------------

-.. autoclass:: neo4j_graphrag.llm.rate_limit.RetryRateLimitHandler
+.. autoclass:: neo4j_graphrag.utils.rate_limit.RetryRateLimitHandler
     :members:

 NoOpRateLimitHandler
 --------------------

-.. autoclass:: neo4j_graphrag.llm.rate_limit.NoOpRateLimitHandler
+.. autoclass:: neo4j_graphrag.utils.rate_limit.NoOpRateLimitHandler
     :members:


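The three autoclass targets above only change module path: the rate limit handlers now live in neo4j_graphrag.utils.rate_limit. The commit message ("Move rate limit module to utils and generate deprecation warnings") suggests the old neo4j_graphrag.llm.rate_limit path remains importable during a transition period; a minimal sketch of what imports are expected to look like (the legacy-path behavior is inferred from the commit message, not shown in this diff):

    # Preferred import path after this commit
    from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler

    # Legacy path; expected to keep working but emit a deprecation warning
    from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler as LegacyRetryRateLimitHandler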
docs/source/user_guide_rag.rst

Lines changed: 33 additions & 2 deletions

@@ -327,7 +327,7 @@ Rate limiting is enabled by default for all LLM instances with the following con
 .. code:: python

     from neo4j_graphrag.llm import OpenAILLM
-    from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler
+    from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler

     # Customize rate limiting parameters
     llm = OpenAILLM(
@@ -348,7 +348,7 @@ You can customize the rate limiting behavior by creating your own rate limit han
 .. code:: python

     from neo4j_graphrag.llm import AnthropicLLM
-    from neo4j_graphrag.llm.rate_limit import RateLimitHandler
+    from neo4j_graphrag.utils.rate_limit import RateLimitHandler

     class CustomRateLimitHandler(RateLimitHandler):
         """Implement your custom rate limiting strategy."""
@@ -528,6 +528,37 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente

 If another embedder is desired, a custom embedder can be created, using the `Embedder` interface.

+Embedder Rate Limiting
+----------------------
+
+All embedder implementations include automatic rate limiting that uses retry logic with exponential backoff by default, similar to LLM implementations. This feature helps handle API rate limits from embedding providers gracefully.
+
+.. code:: python
+
+    from neo4j_graphrag.embeddings import OpenAIEmbeddings
+    from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler, NoOpRateLimitHandler
+
+    # Default rate limiting (automatically enabled)
+    embedder = OpenAIEmbeddings(model="text-embedding-3-large")
+
+    # Custom rate limiting configuration
+    embedder = OpenAIEmbeddings(
+        model="text-embedding-3-large",
+        rate_limit_handler=RetryRateLimitHandler(
+            max_attempts=5,
+            min_wait=2.0,
+            max_wait=120.0
+        )
+    )
+
+    # Disable rate limiting
+    embedder = OpenAIEmbeddings(
+        model="text-embedding-3-large",
+        rate_limit_handler=NoOpRateLimitHandler()
+    )
+
+The rate limiting configuration works the same way as for LLMs. See the :ref:`Rate Limit Handling <Rate Limit Handling>` section above for more details on customization options.
+

 Other Vector Retriever Configuration
 ----------------------------------------
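Because both LLM objects and embedders now accept the same `rate_limit_handler` argument, one handler instance can define a single retry policy for a whole pipeline. A minimal sketch, assuming `OpenAILLM` accepts the same keyword as documented for LLMs above (its `model_name` parameter and the model names are illustrative, not shown in this hunk):

    from neo4j_graphrag.embeddings import OpenAIEmbeddings
    from neo4j_graphrag.llm import OpenAILLM
    from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler

    # One retry policy (5 attempts, exponential backoff between 2s and 120s)
    # shared by the LLM and the embedder.
    shared_handler = RetryRateLimitHandler(max_attempts=5, min_wait=2.0, max_wait=120.0)

    llm = OpenAILLM(model_name="gpt-4o", rate_limit_handler=shared_handler)
    embedder = OpenAIEmbeddings(
        model="text-embedding-3-large",
        rate_limit_handler=shared_handler,
    )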

examples/customize/embeddings/custom_embeddings.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ class CustomEmbeddings(Embedder):
     def __init__(self, dimension: int = 10, **kwargs: Any):
         self.dimension = dimension

-    def embed_query(self, input: str) -> list[float]:
+    def _embed_query(self, input: str) -> list[float]:
         return [random.random() for _ in range(self.dimension)]

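Only the method name changes in this example: the public `embed_query` entry point is now inherited from the `Embedder` base class (see `src/neo4j_graphrag/embeddings/base.py` below), which applies rate limit handling and then delegates to `_embed_query`. A minimal, self-contained sketch of the resulting pattern; the class name and dimension are illustrative, not part of this commit:

    import random
    from typing import Optional

    from neo4j_graphrag.embeddings.base import Embedder
    from neo4j_graphrag.utils.rate_limit import RateLimitHandler


    class RandomEmbeddings(Embedder):
        """Toy embedder returning random vectors; only _embed_query is overridden."""

        def __init__(
            self,
            dimension: int = 10,
            rate_limit_handler: Optional[RateLimitHandler] = None,
        ) -> None:
            super().__init__(rate_limit_handler)
            self.dimension = dimension

        def _embed_query(self, text: str) -> list[float]:
            return [random.random() for _ in range(self.dimension)]


    embedder = RandomEmbeddings(dimension=5)
    # Callers keep using the public method; rate limiting happens in the base class.
    vector = embedder.embed_query("graph retrieval augmented generation")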

examples/customize/llms/custom_llm.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union

 from neo4j_graphrag.llm import LLMInterface, LLMResponse
-from neo4j_graphrag.llm.rate_limit import (
+from neo4j_graphrag.utils.rate_limit import (
     RateLimitHandler,
     # rate_limit_handler,
     # async_rate_limit_handler,

examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@

 # Create Embedder object
 class CustomEmbedder(Embedder):
-    def embed_query(self, text: str) -> list[float]:
+    def _embed_query(self, text: str) -> list[float]:
         return [random() for _ in range(DIMENSION)]


examples/customize/retrievers/hybrid_retrievers/hybrid_search.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@

 # Create Embedder object
 class CustomEmbedder(Embedder):
-    def embed_query(self, text: str) -> list[float]:
+    def _embed_query(self, text: str) -> list[float]:
         return [random() for _ in range(DIMENSION)]


src/neo4j_graphrag/embeddings/base.py

Lines changed: 29 additions & 1 deletion

@@ -15,15 +15,31 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from typing import Optional
+
+from neo4j_graphrag.utils.rate_limit import (
+    DEFAULT_RATE_LIMIT_HANDLER,
+    RateLimitHandler,
+    rate_limit_handler,
+)


 class Embedder(ABC):
     """
     Interface for embedding models.
     An embedder passed into a retriever must implement this interface.
+
+    Args:
+        rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff.
     """

-    @abstractmethod
+    def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None):
+        if rate_limit_handler is not None:
+            self._rate_limit_handler = rate_limit_handler
+        else:
+            self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER
+
+    @rate_limit_handler
     def embed_query(self, text: str) -> list[float]:
         """Embed query text.

@@ -33,3 +49,15 @@ def embed_query(self, text: str) -> list[float]:
         Returns:
             list[float]: A vector embedding.
         """
+        return self._embed_query(text)
+
+    @abstractmethod
+    def _embed_query(self, text: str) -> list[float]:
+        """Embed query text.
+
+        Args:
+            text (str): Text to convert to vector embedding
+
+        Returns:
+            list[float]: A vector embedding.
+        """

src/neo4j_graphrag/embeddings/cohere.py

Lines changed: 22 additions & 9 deletions

@@ -14,9 +14,11 @@
 # limitations under the License.
 from __future__ import annotations

-from typing import Any
+from typing import Any, Optional

 from neo4j_graphrag.embeddings.base import Embedder
+from neo4j_graphrag.exceptions import EmbeddingsGenerationError
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler

 try:
     import cohere
@@ -25,19 +27,30 @@


 class CohereEmbeddings(Embedder):
-    def __init__(self, model: str = "", **kwargs: Any) -> None:
+    def __init__(
+        self,
+        model: str = "",
+        rate_limit_handler: Optional[RateLimitHandler] = None,
+        **kwargs: Any,
+    ) -> None:
         if cohere is None:
             raise ImportError(
                 """Could not import cohere python client.
                 Please install it with `pip install "neo4j-graphrag[cohere]"`."""
             )
+        super().__init__(rate_limit_handler)
         self.model = model
         self.client = cohere.Client(**kwargs)

-    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
-        response = self.client.embed(
-            texts=[text],
-            model=self.model,
-            **kwargs,
-        )
-        return response.embeddings[0]  # type: ignore
+    def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
+        try:
+            response = self.client.embed(
+                texts=[text],
+                model=self.model,
+                **kwargs,
+            )
+            return response.embeddings[0]  # type: ignore
+        except Exception as e:
+            raise EmbeddingsGenerationError(
+                f"Failed to generate embedding with Cohere: {e}"
+            ) from e
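Provider failures are now wrapped in `EmbeddingsGenerationError` (raised `from` the original exception), so callers can handle embedding failures uniformly across providers. A minimal sketch, assuming Cohere credentials are available in the environment; the model name is illustrative:

    from neo4j_graphrag.embeddings.cohere import CohereEmbeddings
    from neo4j_graphrag.exceptions import EmbeddingsGenerationError

    embedder = CohereEmbeddings(model="embed-english-v3.0")

    try:
        vector = embedder.embed_query("What is a knowledge graph?")
    except EmbeddingsGenerationError as e:
        # The original Cohere exception is preserved as e.__cause__.
        print(f"Embedding failed: {e}")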

src/neo4j_graphrag/embeddings/mistral.py

Lines changed: 19 additions & 6 deletions

@@ -16,10 +16,11 @@
 from __future__ import annotations

 import os
-from typing import Any
+from typing import Any, Optional

 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler

 try:
     from mistralai import Mistral
@@ -36,29 +37,41 @@ class MistralAIEmbeddings(Embedder):
         model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed".
     """

-    def __init__(self, model: str = "mistral-embed", **kwargs: Any) -> None:
+    def __init__(
+        self,
+        model: str = "mistral-embed",
+        rate_limit_handler: Optional[RateLimitHandler] = None,
+        **kwargs: Any,
+    ) -> None:
         if Mistral is None:
             raise ImportError(
                 """Could not import mistralai.
                 Please install it with `pip install "neo4j-graphrag[mistralai]"`."""
             )
+        super().__init__(rate_limit_handler)
         api_key = kwargs.pop("api_key", None)
         if api_key is None:
             api_key = os.getenv("MISTRAL_API_KEY", "")
         self.model = model
         self.mistral_client = Mistral(api_key=api_key, **kwargs)

-    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
+    def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """
         Generate embeddings for a given query using a Mistral AI text embedding model.

         Args:
             text (str): The text to generate an embedding for.
             **kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
         """
-        embeddings_batch_response = self.mistral_client.embeddings.create(
-            model=self.model, inputs=[text], **kwargs
-        )
+        try:
+            embeddings_batch_response = self.mistral_client.embeddings.create(
+                model=self.model, inputs=[text], **kwargs
+            )
+        except Exception as e:
+            raise EmbeddingsGenerationError(
+                f"Failed to generate embedding with MistralAI: {e}"
+            ) from e
+
         if embeddings_batch_response is None or not embeddings_batch_response.data:
             raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
6477
