|
4 | 4 | datetime, |
5 | 5 | timedelta, |
6 | 6 | ) |
7 | | -from typing import TYPE_CHECKING |
| 7 | +from typing import ( |
| 8 | + TYPE_CHECKING, |
| 9 | + Literal, |
| 10 | + overload, |
| 11 | +) |
8 | 12 | import warnings |
9 | 13 |
|
10 | 14 | from dateutil.relativedelta import ( |
@@ -281,6 +285,17 @@ def __repr__(self) -> str: |
281 | 285 | repr = f"Holiday: {self.name} ({info})" |
282 | 286 | return repr |
283 | 287 |
|
| 288 | + @overload |
| 289 | + def dates(self, start_date, end_date, return_name: Literal[True]) -> Series: ... |
| 290 | + |
| 291 | + @overload |
| 292 | + def dates( |
| 293 | + self, start_date, end_date, return_name: Literal[False] |
| 294 | + ) -> DatetimeIndex: ... |
| 295 | + |
| 296 | + @overload |
| 297 | + def dates(self, start_date, end_date) -> DatetimeIndex: ... |
| 298 | + |
284 | 299 | def dates( |
285 | 300 | self, start_date, end_date, return_name: bool = False |
286 | 301 | ) -> Series | DatetimeIndex: |
@@ -411,7 +426,7 @@ def _apply_rule(self, dates: DatetimeIndex) -> DatetimeIndex: |
411 | 426 | return dates |
412 | 427 |
|
413 | 428 |
|
414 | | -holiday_calendars = {} |
| 429 | +holiday_calendars: dict[str, type[AbstractHolidayCalendar]] = {} |
415 | 430 |
|
416 | 431 |
|
417 | 432 | def register(cls) -> None: |
@@ -449,7 +464,7 @@ class AbstractHolidayCalendar(metaclass=HolidayCalendarMetaClass): |
449 | 464 | rules: list[Holiday] = [] |
450 | 465 | start_date = Timestamp(datetime(1970, 1, 1)) |
451 | 466 | end_date = Timestamp(datetime(2200, 12, 31)) |
452 | | - _cache = None |
| 467 | + _cache: tuple[Timestamp, Timestamp, Series] | None = None |
453 | 468 |
|
454 | 469 | def __init__(self, name: str = "", rules=None) -> None: |
455 | 470 | """ |
@@ -478,7 +493,9 @@ def rule_from_name(self, name: str) -> Holiday | None: |
478 | 493 |
|
479 | 494 | return None |
480 | 495 |
|
481 | | - def holidays(self, start=None, end=None, return_name: bool = False): |
| 496 | + def holidays( |
| 497 | + self, start=None, end=None, return_name: bool = False |
| 498 | + ) -> DatetimeIndex | Series: |
482 | 499 | """ |
483 | 500 | Returns a curve with holidays between start_date and end_date |
484 | 501 |
|
@@ -515,14 +532,9 @@ def holidays(self, start=None, end=None, return_name: bool = False): |
515 | 532 | rule.dates(start, end, return_name=True) for rule in self.rules |
516 | 533 | ] |
517 | 534 | if pre_holidays: |
518 | | - # error: Argument 1 to "concat" has incompatible type |
519 | | - # "List[Union[Series, DatetimeIndex]]"; expected |
520 | | - # "Union[Iterable[DataFrame], Mapping[<nothing>, DataFrame]]" |
521 | | - holidays = concat(pre_holidays) # type: ignore[arg-type] |
| 535 | + holidays = concat(pre_holidays) |
522 | 536 | else: |
523 | | - # error: Incompatible types in assignment (expression has type |
524 | | - # "Series", variable has type "DataFrame") |
525 | | - holidays = Series(index=DatetimeIndex([]), dtype=object) # type: ignore[assignment] |
| 537 | + holidays = Series(index=DatetimeIndex([]), dtype=object) |
526 | 538 |
|
527 | 539 | self._cache = (start, end, holidays.sort_index()) |
528 | 540 |
|
|
0 commit comments