Commit 2f48cab

Add mypy type checks
1 parent 9b7c35c commit 2f48cab

File tree

5 files changed: +137 -82 lines changed


.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ jobs:
         run: uv run ruff format --check
       - name: Run ruff check
         run: uv run ruff check
-      # - name: Run mypy
-      #   run: uv run mypy .
+      - name: Run mypy
+        run: uv run mypy .
       - name: Minimize uv cache
         run: uv cache prune --ci

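With the mypy step uncommented, type errors now fail the lint workflow alongside the ruff checks. As a rough illustration (a hypothetical module, not part of this repository), `uv run mypy .` in strict mode (see the pyproject.toml change below) would reject code like this:

def add(a: int, b: int) -> int:
    return a + b

add("1", 2)  # mypy reports roughly: Argument 1 to "add" has incompatible type "str"; expected "int"  [arg-type]
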
langchain/langchain_vectorize/retrievers.py

Lines changed: 34 additions & 21 deletions
@@ -4,20 +4,18 @@

 from typing import TYPE_CHECKING, Any, Literal, Optional

-import vectorize_client
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
 from typing_extensions import override
-from vectorize_client import (
-    ApiClient,
-    Configuration,
-    PipelinesApi,
-    RetrieveDocumentsRequest,
-)
+from vectorize_client.api.pipelines_api import PipelinesApi
+from vectorize_client.api_client import ApiClient
+from vectorize_client.configuration import Configuration
+from vectorize_client.models.retrieve_documents_request import RetrieveDocumentsRequest

 if TYPE_CHECKING:
     from langchain_core.callbacks import CallbackManagerForRetrieverRun
     from langchain_core.runnables import RunnableConfig
+    from vectorize_client.models.document import Document as VectorizeDocument

 _METADATA_FIELDS = {
     "relevancy",
@@ -122,7 +120,7 @@ def format_docs(docs):
     metadata_filters: list[dict[str, Any]] = []
     """The metadata filters to apply when retrieving the documents."""

-    _pipelines: PipelinesApi | None = None
+    _pipelines: PipelinesApi = _NOT_SET  # type: ignore[assignment]

     @override
     def model_post_init(self, /, context: Any) -> None:
@@ -146,7 +144,7 @@ def model_post_init(self, /, context: Any) -> None:
         self._pipelines = PipelinesApi(api)

     @staticmethod
-    def _convert_document(document: vectorize_client.models.Document) -> Document:
+    def _convert_document(document: VectorizeDocument) -> Document:
         metadata = {field: getattr(document, field) for field in _METADATA_FIELDS}
         return Document(id=document.id, page_content=document.text, metadata=metadata)
@@ -162,14 +160,29 @@ def _get_relevant_documents(
         rerank: bool | None = None,
         metadata_filters: list[dict[str, Any]] | None = None,
     ) -> list[Document]:
-        request = RetrieveDocumentsRequest(
+        request = RetrieveDocumentsRequest(  # type: ignore[call-arg]
             question=query,
             num_results=num_results or self.num_results,
             rerank=rerank or self.rerank,
             metadata_filters=metadata_filters or self.metadata_filters,
         )
+        organization_ = organization or self.organization
+        if not organization_:
+            msg = (
+                "Organization must be set either at initialization "
+                "or in the invoke method."
+            )
+            raise ValueError(msg)
+        pipeline_id_ = pipeline_id or self.pipeline_id
+        if not pipeline_id_:
+            msg = (
+                "Pipeline ID must be set either at initialization "
+                "or in the invoke method."
+            )
+            raise ValueError(msg)
+
         response = self._pipelines.retrieve_documents(
-            organization or self.organization, pipeline_id or self.pipeline_id, request
+            organization_, pipeline_id_, request
         )
         return [self._convert_document(doc) for doc in response.documents]

@@ -181,9 +194,10 @@ def invoke(
         *,
         organization: str = "",
         pipeline_id: str = "",
-        num_results: int = _NOT_SET,
-        rerank: bool = _NOT_SET,
-        metadata_filters: list[dict[str, Any]] = _NOT_SET,
+        num_results: int = _NOT_SET,  # type: ignore[assignment]
+        rerank: bool = _NOT_SET,  # type: ignore[assignment]
+        metadata_filters: list[dict[str, Any]] = _NOT_SET,  # type: ignore[assignment]
+        **_kwargs: Any,
     ) -> list[Document]:
         """Invoke the retriever to get relevant documents.

@@ -218,16 +232,15 @@ def invoke(
             query = "what year was breath of the wild released?"
             docs = retriever.invoke(query, num_results=2)
         """
-        kwargs = {}
         if organization:
-            kwargs["organization"] = organization
+            _kwargs["organization"] = organization
         if pipeline_id:
-            kwargs["pipeline_id"] = pipeline_id
+            _kwargs["pipeline_id"] = pipeline_id
         if num_results is not _NOT_SET:
-            kwargs["num_results"] = num_results
+            _kwargs["num_results"] = num_results
         if rerank is not _NOT_SET:
-            kwargs["rerank"] = rerank
+            _kwargs["rerank"] = rerank
         if metadata_filters is not _NOT_SET:
-            kwargs["metadata_filters"] = metadata_filters
+            _kwargs["metadata_filters"] = metadata_filters

-        return super().invoke(input, config, **kwargs)
+        return super().invoke(input, config, **_kwargs)
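The `# type: ignore[assignment]` comments are needed because `_NOT_SET` is a sentinel whose type differs from the annotated parameter types; its definition is not part of this diff. A minimal, self-contained sketch of the pattern, assuming a plain object-style sentinel (names here are illustrative, not the module's actual code):

class _NotSet:
    """Marker for 'argument not provided', distinct from None, 0, or False."""

_NOT_SET = _NotSet()

def fetch(num_results: int = _NOT_SET) -> int:  # type: ignore[assignment]
    # Identity check, mirroring the `is not _NOT_SET` tests in invoke() above.
    if num_results is not _NOT_SET:
        return num_results
    return 5  # fall back to some default

print(fetch(), fetch(num_results=2))  # prints: 5 2

The new organization and pipeline_id checks follow the same fallback idea at runtime: an empty argument falls back to the instance attribute, and a ValueError is raised only when neither source provides a value.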

langchain/pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,7 @@ Issues = "https://github.com/vectorize-io/integrations-python/issues"

 [dependency-groups]
 dev = [
-    "mypy>=1.13.0",
+    "mypy>=1.17.1,<1.18",
     "pytest>=8.3.3",
     "ruff>=0.9.0,<0.10",
 ]
@@ -59,6 +59,8 @@ flake8-annotations.mypy-init-return = true

 [tool.mypy]
 strict = true
+strict_bytes = true
+enable_error_code = "deprecated"
 warn_unreachable = true
 pretty = true
 show_error_codes = true
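Besides the tighter mypy pin, the config adds strict_bytes (which stops treating bytearray and memoryview as compatible with bytes) and the "deprecated" error code. A hypothetical snippet of the kind of usage the latter flags once enabled:

from typing_extensions import deprecated

@deprecated("Use new_api() instead")
def old_api() -> None: ...

old_api()  # mypy reports roughly: "old_api" is deprecated: Use new_api() instead  [deprecated]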

langchain/tests/test_retrievers.py

Lines changed: 47 additions & 23 deletions
@@ -8,8 +8,34 @@

 import pytest
 import urllib3
-import vectorize_client as v
-from vectorize_client import ApiClient
+from vectorize_client.api.ai_platform_connectors_api import AIPlatformConnectorsApi
+from vectorize_client.api.destination_connectors_api import DestinationConnectorsApi
+from vectorize_client.api.pipelines_api import PipelinesApi
+from vectorize_client.api.source_connectors_api import SourceConnectorsApi
+from vectorize_client.api.uploads_api import UploadsApi
+from vectorize_client.api_client import ApiClient
+from vectorize_client.configuration import Configuration
+from vectorize_client.models.create_source_connector_request import (
+    CreateSourceConnectorRequest,
+)
+from vectorize_client.models.file_upload import FileUpload
+from vectorize_client.models.pipeline_ai_platform_connector_schema import (
+    PipelineAIPlatformConnectorSchema,
+)
+from vectorize_client.models.pipeline_configuration_schema import (
+    PipelineConfigurationSchema,
+)
+from vectorize_client.models.pipeline_destination_connector_schema import (
+    PipelineDestinationConnectorSchema,
+)
+from vectorize_client.models.pipeline_source_connector_schema import (
+    PipelineSourceConnectorSchema,
+)
+from vectorize_client.models.schedule_schema import ScheduleSchema
+from vectorize_client.models.source_connector_type import SourceConnectorType
+from vectorize_client.models.start_file_upload_to_connector_request import (
+    StartFileUploadToConnectorRequest,
+)

 from langchain_vectorize.retrievers import VectorizeRetriever

@@ -38,7 +64,7 @@ def environment() -> Literal["prod", "dev", "local", "staging"]:
     if env not in ["prod", "dev", "local", "staging"]:
         msg = "Invalid VECTORIZE_ENV environment variable."
         raise ValueError(msg)
-    return env
+    return env  # type: ignore[return-value]


 @pytest.fixture(scope="session")
@@ -56,33 +82,31 @@ def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
     else:
         host = "https://api-staging.vectorize.io/v1"

-    with v.ApiClient(
-        v.Configuration(host=host, access_token=api_token, debug=True),
+    with ApiClient(
+        Configuration(host=host, access_token=api_token, debug=True),
         header_name,
         header_value,
     ) as api:
         yield api


 @pytest.fixture(scope="session")
-def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
-    pipelines = v.PipelinesApi(api_client)
+def pipeline_id(api_client: ApiClient, org_id: str) -> Iterator[str]:
+    pipelines = PipelinesApi(api_client)

-    connectors_api = v.SourceConnectorsApi(api_client)
+    connectors_api = SourceConnectorsApi(api_client)
     response = connectors_api.create_source_connector(
         org_id,
-        v.CreateSourceConnectorRequest(
-            v.FileUpload(name="from api", type="FILE_UPLOAD")
-        ),
+        CreateSourceConnectorRequest(FileUpload(name="from api", type="FILE_UPLOAD")),
     )
     source_connector_id = response.connector.id
     logging.info("Created source connector %s", source_connector_id)

-    uploads_api = v.UploadsApi(api_client)
+    uploads_api = UploadsApi(api_client)
     upload_response = uploads_api.start_file_upload_to_connector(
         org_id,
         source_connector_id,
-        v.StartFileUploadToConnectorRequest(
+        StartFileUploadToConnectorRequest(  # type: ignore[call-arg]
             name="research.pdf",
             content_type="application/pdf",
             metadata=json.dumps({"created-from-api": True}),
@@ -109,44 +133,44 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
     else:
         logging.info("Upload successful")

-    ai_platforms = v.AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
+    ai_platforms = AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
         org_id
     )
     builtin_ai_platform = next(
         c.id for c in ai_platforms.ai_platform_connectors if c.type == "VECTORIZE"
     )
     logging.info("Using AI platform %s", builtin_ai_platform)

-    vector_databases = v.DestinationConnectorsApi(
-        api_client
-    ).get_destination_connectors(org_id)
+    vector_databases = DestinationConnectorsApi(api_client).get_destination_connectors(
+        org_id
+    )
     builtin_vector_db = next(
         c.id for c in vector_databases.destination_connectors if c.type == "VECTORIZE"
     )
     logging.info("Using destination connector %s", builtin_vector_db)

     pipeline_response = pipelines.create_pipeline(
         org_id,
-        v.PipelineConfigurationSchema(
+        PipelineConfigurationSchema(  # type: ignore[call-arg]
             source_connectors=[
-                v.PipelineSourceConnectorSchema(
+                PipelineSourceConnectorSchema(
                     id=source_connector_id,
-                    type=v.SourceConnectorType.FILE_UPLOAD,
+                    type=SourceConnectorType.FILE_UPLOAD,
                     config={},
                 )
             ],
-            destination_connector=v.PipelineDestinationConnectorSchema(
+            destination_connector=PipelineDestinationConnectorSchema(
                 id=builtin_vector_db,
                 type="VECTORIZE",
                 config={},
             ),
-            ai_platform_connector=v.PipelineAIPlatformConnectorSchema(
+            ai_platform_connector=PipelineAIPlatformConnectorSchema(
                 id=builtin_ai_platform,
                 type="VECTORIZE",
                 config={},
             ),
             pipeline_name="Test pipeline",
-            schedule=v.ScheduleSchema(type="manual"),
+            schedule=ScheduleSchema(type="manual"),
         ),
     )
     pipeline_id = pipeline_response.data.id

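The `# type: ignore[return-value]` in the `environment` fixture is needed because mypy does not narrow a plain `str` to a `Literal` type based on an `in` check against a list. A simplified, self-contained sketch of the same situation, without pytest (names here are illustrative):

from typing import Literal

Env = Literal["prod", "dev", "local", "staging"]

def parse_env(env: str) -> Env:
    if env not in ["prod", "dev", "local", "staging"]:
        raise ValueError("Invalid VECTORIZE_ENV environment variable.")
    return env  # type: ignore[return-value]  # mypy still sees plain str here

print(parse_env("dev"))

An explicit cast(Env, env) would be the usual alternative to the ignore comment.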