|
| 1 | +import os |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +from redisvl.extensions.cache.embeddings import EmbeddingsCache |
1 | 5 | from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers |
2 | 6 | from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer |
3 | 7 | from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer |
|
23 | 27 | ] |
24 | 28 |
|
25 | 29 |
|
def vectorizer_from_dict(
    vectorizer: dict,
    cache: Optional[dict] = None,
    cache_folder: Optional[str] = None,
) -> BaseVectorizer:
    """Build a text vectorizer from a configuration dictionary.

    Args:
        vectorizer: Configuration with a ``"type"`` entry (a value accepted
            by ``Vectorizers``) and a ``"model"`` entry, which is passed to
            the selected vectorizer as ``model``.
        cache: Optional keyword arguments for ``EmbeddingsCache``. When
            non-empty, an ``EmbeddingsCache`` is created from them and handed
            to the vectorizer as ``cache``. ``None`` or an empty dict means
            no cache argument is passed.
        cache_folder: Accepted for interface compatibility; this function
            does not currently use it.

    Returns:
        An instance of the vectorizer class matching ``vectorizer["type"]``.

    Raises:
        KeyError: If ``vectorizer`` lacks the ``"type"`` or ``"model"`` key.
        ValueError: If the type is a ``Vectorizers`` member this factory does
            not handle.
    """
    # NOTE(review): `cache_folder` previously defaulted to
    # os.getenv("SENTENCE_TRANSFORMERS_HOME"), evaluated once at import time,
    # and was never forwarded to any vectorizer. It was presumably meant for
    # HFTextVectorizer -- confirm that constructor accepts it before wiring
    # it through.
    vectorizer_type = Vectorizers(vectorizer["type"])
    model = vectorizer["model"]

    factories = {
        Vectorizers.cohere: CohereTextVectorizer,
        Vectorizers.openai: OpenAITextVectorizer,
        Vectorizers.azure_openai: AzureOpenAITextVectorizer,
        Vectorizers.hf: HFTextVectorizer,
        Vectorizers.mistral: MistralAITextVectorizer,
        Vectorizers.vertexai: VertexAITextVectorizer,
        Vectorizers.voyageai: VoyageAITextVectorizer,
    }
    factory = factories.get(vectorizer_type)
    if factory is None:
        # Reject unsupported types before building a cache that would
        # otherwise be created and immediately discarded.
        raise ValueError(f"Unsupported vectorizer type: {vectorizer_type}")

    args = {"model": model}
    if cache:
        args["cache"] = EmbeddingsCache(**cache)

    return factory(**args)
0 commit comments