Commit 0287885

migrates mistral vectorizer to new client (#255)
1 parent 201d676 commit 0287885

File tree

5 files changed (+48, -52 lines)


docs/user_guide/vectorizers_04.ipynb

Lines changed: 30 additions & 35 deletions
@@ -31,7 +31,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -305,33 +305,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/tyler.hutcherson/redis/redis-vl-python/.venv/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
-      " return self.fget.__get__(instance, owner)()\n"
-     ]
-    },
     {
      "data": {
       "text/plain": [
-       "[0.00037810884532518685,\n",
-       " -0.05080341175198555,\n",
-       " -0.03514723479747772,\n",
-       " -0.02325104922056198,\n",
-       " -0.044158220291137695,\n",
-       " 0.020487844944000244,\n",
-       " 0.0014617963461205363,\n",
-       " 0.031261757016181946,\n",
+       "[0.0003780885017476976,\n",
+       " -0.05080340430140495,\n",
+       " -0.035147231072187424,\n",
+       " -0.02325103059411049,\n",
+       " -0.04415831342339516,\n",
+       " 0.02048780582845211,\n",
+       " 0.0014618589775636792,\n",
+       " 0.03126184269785881,\n",
        " 0.05605152249336243,\n",
-       " 0.018815357238054276]"
+       " 0.018815429881215096]"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 3,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -532,14 +524,14 @@
     }
    ],
    "source": [
-    "# from redisvl.utils.vectorize import MistralAITextVectorizer\n",
+    "from redisvl.utils.vectorize import MistralAITextVectorizer\n",
     "\n",
-    "# mistral = MistralAITextVectorizer()\n",
+    "mistral = MistralAITextVectorizer()\n",
     "\n",
-    "# # embed a sentence using their asyncronous method\n",
-    "# test = await mistral.aembed(\"This is a test sentence.\")\n",
-    "# print(\"Vector dimensions: \", len(test))\n",
-    "# print(test[:10])"
+    "# embed a sentence using their asyncronous method\n",
+    "test = await mistral.aembed(\"This is a test sentence.\")\n",
+    "print(\"Vector dimensions: \", len(test))\n",
+    "print(test[:10])"
    ]
   },
   {
@@ -588,9 +580,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Vector dimensions: 1024\n"
+     ]
+    }
+   ],
    "source": [
     "from redisvl.utils.vectorize import BedrockTextVectorizer\n",
     "\n",
@@ -836,7 +836,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.13 ('redisvl2')",
+   "display_name": "redisvl-dev",
    "language": "python",
    "name": "python3"
   },
@@ -852,12 +852,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
- "orig_nbformat": 4,
- "vscode": {
-  "interpreter": {
-   "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316"
-  }
- }
+ "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ sentence-transformers = { version = ">=2.2.2", optional = true }
 google-cloud-aiplatform = { version = ">=1.26", optional = true }
 protobuf = { version = ">=5.29.1,<6.0.0.dev0", optional = true }
 cohere = { version = ">=4.44", optional = true }
-mistralai = { version = ">=0.2.0", optional = true }
+mistralai = { version = ">=1.0.0", optional = true }
 boto3 = { version = ">=1.34.0", optional = true }

 [tool.poetry.extras]

redisvl/utils/vectorize/text/mistral.py

Lines changed: 11 additions & 10 deletions
@@ -44,7 +44,6 @@ class MistralAITextVectorizer(BaseVectorizer):
     """

     _client: Any = PrivateAttr()
-    _aclient: Any = PrivateAttr()

     def __init__(self, model: str = "mistral-embed", api_config: Optional[Dict] = None):
         """Initialize the MistralAI vectorizer.
@@ -69,8 +68,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
         """
         # Dynamic import of the mistralai module
         try:
-            from mistralai.async_client import MistralAsyncClient
-            from mistralai.client import MistralClient
+            from mistralai import Mistral
         except ImportError:
             raise ImportError(
                 "MistralAI vectorizer requires the mistralai library. \
@@ -88,13 +86,12 @@ def _initialize_clients(self, api_config: Optional[Dict]):
                 environment variable."
             )

-        self._client = MistralClient(api_key=api_key)
-        self._aclient = MistralAsyncClient(api_key=api_key)
+        self._client = Mistral(api_key=api_key)

     def _set_model_dims(self, model) -> int:
         try:
             embedding = (
-                self._client.embeddings(model=model, input=["dimension test"])
+                self._client.embeddings.create(model=model, inputs=["dimension test"])
                 .data[0]
                 .embedding
             )
@@ -144,7 +141,7 @@ def embed_many(

         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = self._client.embeddings(model=self.model, input=batch)
+            response = self._client.embeddings.create(model=self.model, inputs=batch)
             embeddings += [
                 self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
@@ -186,7 +183,7 @@ def embed(

         dtype = kwargs.pop("dtype", None)

-        result = self._client.embeddings(model=self.model, input=[text])
+        result = self._client.embeddings.create(model=self.model, inputs=[text])
         return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @retry(
@@ -228,7 +225,9 @@ async def aembed_many(

         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = await self._aclient.embeddings(model=self.model, input=batch)
+            response = await self._client.embeddings.create_async(
+                model=self.model, inputs=batch
+            )
             embeddings += [
                 self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
@@ -270,7 +269,9 @@ async def aembed(

         dtype = kwargs.pop("dtype", None)

-        result = await self._aclient.embeddings(model=self.model, input=[text])
+        result = await self._client.embeddings.create_async(
+            model=self.model, inputs=[text]
+        )
         return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @property
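
This file is the heart of the migration: the pre-1.0 `MistralClient` / `MistralAsyncClient` pair is replaced by the single `Mistral` client, synchronous calls move from `client.embeddings(...)` to `client.embeddings.create(...)`, asynchronous calls move to `client.embeddings.create_async(...)`, and the keyword argument changes from `input=` to `inputs=`. A minimal sketch of the new call shape, mirroring the calls in the diff above and assuming `mistralai>=1.0.0` with `MISTRAL_API_KEY` exported in the environment:

# Minimal sketch of the mistralai >= 1.0.0 call shape used in the diff above.
# Assumes MISTRAL_API_KEY is set in the environment.
import asyncio
import os

from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

# Synchronous embedding call (replaces the old MistralClient.embeddings(...)).
resp = client.embeddings.create(model="mistral-embed", inputs=["dimension test"])
print(len(resp.data[0].embedding))


# Asynchronous embedding call (replaces MistralAsyncClient.embeddings(...));
# note it is a method on the same client rather than a separate async client.
async def main():
    aresp = await client.embeddings.create_async(
        model="mistral-embed", inputs=["dimension test"]
    )
    print(len(aresp.data[0].embedding))


asyncio.run(main())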

tests/integration/test_vectorizers.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def skip_vectorizer() -> bool:
         CohereTextVectorizer,
         AzureOpenAITextVectorizer,
         BedrockTextVectorizer,
-        # MistralAITextVectorizer,
+        MistralAITextVectorizer,
         CustomTextVectorizer,
     ]
 )
@@ -242,7 +242,7 @@ def bad_return_type(text: str) -> str:
     params=[
         OpenAITextVectorizer,
         BedrockTextVectorizer,
-        # MistralAITextVectorizer,
+        MistralAITextVectorizer,
         CustomTextVectorizer,
     ]
 )
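
The test change simply drops `MistralAITextVectorizer` back into the parametrized fixture lists, so the shared vectorizer tests run against Mistral again. A stripped-down sketch of that pattern, with hypothetical names rather than the repository's actual fixtures:

# Stripped-down sketch of a parametrized fixture over vectorizer classes;
# the real test module lists more providers and handles skipping and config.
import pytest

from redisvl.utils.vectorize import MistralAITextVectorizer, OpenAITextVectorizer


@pytest.fixture(
    params=[
        OpenAITextVectorizer,
        MistralAITextVectorizer,  # re-enabled by this commit
    ]
)
def vectorizer_class(request):
    # Each test that requests this fixture runs once per vectorizer class.
    return request.param


def test_has_embed_methods(vectorizer_class):
    assert hasattr(vectorizer_class, "embed")
    assert hasattr(vectorizer_class, "aembed")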
