Skip to content

Commit 5c4b5d8

Browse files
saeedseyfimdrxy
andauthored
feat(vertexai): add default dimensions parameter to VertexAIEmbeddings (#1242)
## Summary Add optional `dimensions` constructor argument to `VertexAIEmbeddings` class. This allows users to set a default output dimensionality that will be used for all embedding requests unless explicitly overridden. ## Changes - Add `dimensions` field to `VertexAIEmbeddings` with default value `None` - Update `embed()` method to use constructor default when dimensions not specified in method call - Explicit dimensions in method calls override the constructor default - Full backward compatibility maintained - existing code continues to work unchanged ## Usage Example ```python # Set default dimensions in constructor embeddings = VertexAIEmbeddings(model="text-embedding-004", dimensions=128) # Uses 128 dimensions (from constructor) embeddings.embed(["hello", "world"]) # Overrides with 256 dimensions embeddings.embed(["hello", "world"], dimensions=256) # Backward compatible - no dimensions specified embeddings = VertexAIEmbeddings(model="text-embedding-004") embeddings.embed(["hello", "world"]) # Uses model's default ``` ## Test Plan - [x] Added 5 new unit tests covering all scenarios - [x] Test that constructor dimensions are used when not specified in embed() - [x] Test that explicit dimensions override constructor default - [x] Test backward compatibility when no default dimensions specified - [x] Test default dimensions work with embed_documents() and embed_query() - [x] All 10 unit tests pass - [x] Code quality checks pass (lint, format, mypy) --------- Co-authored-by: Mason Daugherty <mason@langchain.dev>
1 parent cce3e34 commit 5c4b5d8

File tree

2 files changed

+109
-4
lines changed

2 files changed

+109
-4
lines changed

libs/vertexai/langchain_google_vertexai/embeddings.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ class VertexAIEmbeddings(BaseModel, Embeddings):
5454

5555
max_retries: int = 6
5656
"""The maximum number of retries to make when generating."""
57+
dimensions: int | None = None
58+
"""Default output dimensionality for embeddings. If not specified, uses the
59+
model's default. Can be overridden per request in embed() method."""
5760

5861
@model_validator(mode="before")
5962
@classmethod
@@ -130,10 +133,11 @@ def embed(
130133
131134
The following are only supported on preview models:
132135
`QUESTION_ANSWERING`, `FACT_VERIFICATION`.
133-
dimensions: Optional output embeddings dimensions.
136+
dimensions: Output embeddings dimensions.
134137
135-
Only supported on preview models.
136-
title: Optional title for the text.
138+
Only supported on preview models. If not provided, uses the
139+
default dimensions specified in the constructor.
140+
title: Title for the text.
137141
138142
Only applicable when `TaskType` is `RETRIEVAL_DOCUMENT`.
139143
@@ -142,10 +146,11 @@ def embed(
142146
"""
143147
if len(texts) == 0:
144148
return []
149+
effective_dimensions = dimensions if dimensions is not None else self.dimensions
145150
embeddings = self._get_embeddings_with_retry(
146151
texts=texts,
147152
embeddings_type=embeddings_task_type,
148-
dimensions=dimensions,
153+
dimensions=effective_dimensions,
149154
title=title,
150155
)
151156
return embeddings

libs/vertexai/tests/unit_tests/test_embeddings.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,103 @@ def test_embed_parameters(mock_get_embeddings, mock_client):
9494
dimensions=128,
9595
title="test-title",
9696
)
97+
98+
99+
@patch("langchain_google_vertexai.embeddings.genai.Client")
100+
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
101+
def test_default_dimensions_used_when_not_specified(mock_get_embeddings, mock_client):
102+
"""Test that constructor dimensions are used when not specified in embed()."""
103+
mock_client.return_value = MagicMock()
104+
embeddings = VertexAIEmbeddings(model="text-embedding-004", dimensions=256)
105+
texts = ["hello", "world"]
106+
107+
mock_get_embeddings.return_value = [[0.001] * 256 for _ in texts]
108+
109+
embeddings.embed(texts)
110+
111+
mock_get_embeddings.assert_called_once_with(
112+
texts=texts,
113+
embeddings_type=None,
114+
dimensions=256,
115+
title=None,
116+
)
117+
118+
119+
@patch("langchain_google_vertexai.embeddings.genai.Client")
120+
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
121+
def test_explicit_dimensions_override_default(mock_get_embeddings, mock_client):
122+
"""Test that explicit dimensions in embed() override constructor default."""
123+
mock_client.return_value = MagicMock()
124+
embeddings = VertexAIEmbeddings(model="text-embedding-004", dimensions=256)
125+
texts = ["hello", "world"]
126+
127+
mock_get_embeddings.return_value = [[0.001] * 512 for _ in texts]
128+
129+
embeddings.embed(texts, dimensions=512)
130+
131+
mock_get_embeddings.assert_called_once_with(
132+
texts=texts,
133+
embeddings_type=None,
134+
dimensions=512,
135+
title=None,
136+
)
137+
138+
139+
@patch("langchain_google_vertexai.embeddings.genai.Client")
140+
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
141+
def test_no_default_dimensions_works_as_before(mock_get_embeddings, mock_client):
142+
"""Test backward compatibility when no default dimensions specified."""
143+
mock_client.return_value = MagicMock()
144+
embeddings = VertexAIEmbeddings(model="text-embedding-004")
145+
texts = ["hello", "world"]
146+
147+
mock_get_embeddings.return_value = [[0.001] * 768 for _ in texts]
148+
149+
embeddings.embed(texts)
150+
151+
mock_get_embeddings.assert_called_once_with(
152+
texts=texts,
153+
embeddings_type=None,
154+
dimensions=None,
155+
title=None,
156+
)
157+
158+
159+
@patch("langchain_google_vertexai.embeddings.genai.Client")
160+
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
161+
def test_default_dimensions_used_in_embed_documents(mock_get_embeddings, mock_client):
162+
"""Test that constructor dimensions are used in embed_documents()."""
163+
mock_client.return_value = MagicMock()
164+
embeddings = VertexAIEmbeddings(model="text-embedding-004", dimensions=128)
165+
texts = ["hello", "world"]
166+
167+
mock_get_embeddings.return_value = [[0.001] * 128 for _ in texts]
168+
169+
embeddings.embed_documents(texts)
170+
171+
mock_get_embeddings.assert_called_once_with(
172+
texts=texts,
173+
embeddings_type="RETRIEVAL_DOCUMENT",
174+
dimensions=128,
175+
title=None,
176+
)
177+
178+
179+
@patch("langchain_google_vertexai.embeddings.genai.Client")
180+
@patch.object(VertexAIEmbeddings, "_get_embeddings_with_retry")
181+
def test_default_dimensions_used_in_embed_query(mock_get_embeddings, mock_client):
182+
"""Test that constructor dimensions are used in embed_query()."""
183+
mock_client.return_value = MagicMock()
184+
embeddings = VertexAIEmbeddings(model="text-embedding-004", dimensions=128)
185+
text = "hello"
186+
187+
mock_get_embeddings.return_value = [[0.001] * 128]
188+
189+
embeddings.embed_query(text)
190+
191+
mock_get_embeddings.assert_called_once_with(
192+
texts=[text],
193+
embeddings_type="RETRIEVAL_QUERY",
194+
dimensions=128,
195+
title=None,
196+
)

0 commit comments

Comments
 (0)