|
1 | | -from typing import List, Optional, Union |
| 1 | +from typing import List, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | from redis.commands.search.dialect import DEFAULT_DIALECT |
4 | 4 |
|
@@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None: |
31 | 31 | self._with_scores: bool = False |
32 | 32 | self._scorer: Optional[str] = None |
33 | 33 | self._filters: List = list() |
34 | | - self._ids: Optional[List[str]] = None |
| 34 | + self._ids: Optional[Tuple[str]] = None |
35 | 35 | self._slop: int = -1 |
36 | 36 | self._timeout: Optional[float] = None |
37 | 37 | self._in_order: bool = False |
@@ -81,7 +81,7 @@ def return_field( |
81 | 81 | self._return_fields += ("AS", as_field) |
82 | 82 | return self |
83 | 83 |
|
84 | | - def _mk_field_list(self, fields: List[str]) -> List: |
| 84 | + def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List: |
85 | 85 | if not fields: |
86 | 86 | return [] |
87 | 87 | return [fields] if isinstance(fields, str) else list(fields) |
@@ -126,7 +126,7 @@ def summarize( |
126 | 126 |
|
127 | 127 | def highlight( |
128 | 128 | self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None |
129 | | - ) -> None: |
| 129 | + ) -> "Query": |
130 | 130 | """ |
131 | 131 | Apply specified markup to matched term(s) within the returned field(s). |
132 | 132 |
|
@@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query": |
187 | 187 | self._scorer = scorer |
188 | 188 | return self |
189 | 189 |
|
190 | | - def get_args(self) -> List[str]: |
| 190 | + def get_args(self) -> List[Union[str, int, float]]: |
191 | 191 | """Format the redis arguments for this query and return them.""" |
192 | | - args = [self._query_string] |
| 192 | + args: List[Union[str, int, float]] = [self._query_string] |
193 | 193 | args += self._get_args_tags() |
194 | 194 | args += self._summarize_fields + self._highlight_fields |
195 | 195 | args += ["LIMIT", self._offset, self._num] |
196 | 196 | return args |
197 | 197 |
|
198 | | - def _get_args_tags(self) -> List[str]: |
199 | | - args = [] |
| 198 | + def _get_args_tags(self) -> List[Union[str, int, float]]: |
| 199 | + args: List[Union[str, int, float]] = [] |
200 | 200 | if self._no_content: |
201 | 201 | args.append("NOCONTENT") |
202 | 202 | if self._fields: |
@@ -288,14 +288,14 @@ def with_scores(self) -> "Query": |
288 | 288 | self._with_scores = True |
289 | 289 | return self |
290 | 290 |
|
291 | | - def limit_fields(self, *fields: List[str]) -> "Query": |
| 291 | + def limit_fields(self, *fields: str) -> "Query": |
292 | 292 | """ |
293 | 293 | Limit the search to specific TEXT fields only. |
294 | 294 |
|
295 | | - - **fields**: A list of strings, case sensitive field names |
| 295 | + - **fields**: Each element should be a string, case sensitive field name |
296 | 296 | from the defined schema. |
297 | 297 | """ |
298 | | - self._fields = fields |
| 298 | + self._fields = list(fields) |
299 | 299 | return self |
300 | 300 |
|
301 | 301 | def add_filter(self, flt: "Filter") -> "Query": |
@@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query": |
340 | 340 |
|
341 | 341 |
|
342 | 342 | class Filter: |
343 | | - def __init__(self, keyword: str, field: str, *args: List[str]) -> None: |
| 343 | + def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None: |
344 | 344 | self.args = [keyword, field] + list(args) |
345 | 345 |
|
346 | 346 |
|
|
0 commit comments