Skip to content

Commit 5c26576

Browse files
committed
Adding more mypy fixes in search.
1 parent dcbf473 commit 5c26576

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

redis/commands/search/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import redis
1+
from redis.client import Pipeline as RedisPipeline
22

33
from ...asyncio.client import Pipeline as AsyncioPipeline
44
from .commands import (
@@ -181,9 +181,17 @@ def pipeline(self, transaction=True, shard_hint=None):
181181
return p
182182

183183

184-
class Pipeline(SearchCommands, redis.client.Pipeline):
184+
class Pipeline(SearchCommands, RedisPipeline):
185185
"""Pipeline for the module."""
186186

187+
def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
188+
super().__init__(connection_pool, response_callbacks, transaction, shard_hint)
189+
self.index_name: str = ""
187190

188-
class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline):
191+
192+
class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline):
189193
"""AsyncPipeline for the module."""
194+
195+
def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
196+
super().__init__(connection_pool, response_callbacks, transaction, shard_hint)
197+
self.index_name: str = ""

redis/commands/search/commands.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,44 @@
6464
class SearchCommands:
6565
"""Search commands."""
6666

67+
@property
68+
def index_name(self) -> str:
69+
"""The name of the search index. Must be implemented by inheriting classes."""
70+
if not hasattr(self, "_index_name"):
71+
raise AttributeError("index_name must be set by the inheriting class")
72+
return self._index_name
73+
74+
@index_name.setter
75+
def index_name(self, value: str) -> None:
76+
"""Set the name of the search index."""
77+
self._index_name = value
78+
79+
@property
80+
def client(self):
81+
"""The Redis client. Must be provided by inheriting classes."""
82+
if not hasattr(self, "_client"):
83+
raise AttributeError("client must be set by the inheriting class")
84+
return self._client
85+
86+
@client.setter
87+
def client(self, value) -> None:
88+
"""Set the Redis client."""
89+
self._client = value
90+
91+
@property
92+
def _RESP2_MODULE_CALLBACKS(self):
93+
"""Response callbacks for RESP2. Must be provided by inheriting classes."""
94+
if not hasattr(self, "_resp2_module_callbacks"):
95+
raise AttributeError(
96+
"_RESP2_MODULE_CALLBACKS must be set by the inheriting class"
97+
)
98+
return self._resp2_module_callbacks
99+
100+
@_RESP2_MODULE_CALLBACKS.setter
101+
def _RESP2_MODULE_CALLBACKS(self, value) -> None:
102+
"""Set the RESP2 module callbacks."""
103+
self._resp2_module_callbacks = value
104+
67105
def _parse_results(self, cmd, res, **kwargs):
68106
if get_protocol_version(self.client) in ["3", 3]:
69107
return ProfileInformation(res) if cmd == "FT.PROFILE" else res
@@ -221,7 +259,7 @@ def create_index(
221259

222260
return self.execute_command(*args)
223261

224-
def alter_schema_add(self, fields: List[str]):
262+
def alter_schema_add(self, fields: Union[Field, List[Field]]):
225263
"""
226264
Alter the existing search index by adding new fields. The index
227265
must already exist.
@@ -336,11 +374,11 @@ def add_document(
336374
doc_id: str,
337375
nosave: bool = False,
338376
score: float = 1.0,
339-
payload: bool = None,
377+
payload: Optional[bool] = None,
340378
replace: bool = False,
341379
partial: bool = False,
342380
language: Optional[str] = None,
343-
no_create: str = False,
381+
no_create: bool = False,
344382
**fields: List[str],
345383
):
346384
"""
@@ -478,7 +516,7 @@ def get_params_args(
478516
return args
479517

480518
def _mk_query_args(
481-
self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
519+
self, query, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
482520
):
483521
args = [self.index_name]
484522

@@ -528,7 +566,7 @@ def search(
528566
def explain(
529567
self,
530568
query: Union[str, Query],
531-
query_params: Dict[str, Union[str, int, float]] = None,
569+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
532570
):
533571
"""Returns the execution plan for a complex query.
534572
@@ -543,7 +581,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
543581
def aggregate(
544582
self,
545583
query: Union[AggregateRequest, Cursor],
546-
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
584+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
547585
):
548586
"""
549587
Issue an aggregation query.
@@ -598,7 +636,7 @@ def profile(
598636
self,
599637
query: Union[Query, AggregateRequest],
600638
limited: bool = False,
601-
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
639+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
602640
):
603641
"""
604642
Performs a search or aggregate command and collects performance
@@ -936,7 +974,7 @@ async def info(self):
936974
async def search(
937975
self,
938976
query: Union[str, Query],
939-
query_params: Dict[str, Union[str, int, float]] = None,
977+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
940978
):
941979
"""
942980
Search the index for a given query, and return a result of documents
@@ -968,7 +1006,7 @@ async def search(
9681006
async def aggregate(
9691007
self,
9701008
query: Union[AggregateResult, Cursor],
971-
query_params: Dict[str, Union[str, int, float]] = None,
1009+
query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None,
9721010
):
9731011
"""
9741012
Issue an aggregation query.

0 commit comments

Comments
 (0)