|
1 | 1 | import builtins |
| 2 | +import logging |
2 | 3 | from collections.abc import Iterable |
3 | 4 | from typing import Any, Generic, TypeVar |
4 | 5 |
|
|
8 | 9 | from django.contrib.messages.views import SuccessMessageMixin |
9 | 10 | from django.contrib.sitemaps import Sitemap |
10 | 11 | from django.contrib.syndication.views import Feed |
| 12 | +from django.core.exceptions import AppRegistryNotReady, ImproperlyConfigured |
11 | 13 | from django.core.files.utils import FileProxyMixin |
12 | 14 | from django.core.paginator import Paginator |
13 | 15 | from django.db.models.expressions import ExpressionWrapper |
14 | 16 | from django.db.models.fields import Field |
15 | 17 | from django.db.models.fields.related import ForeignKey |
16 | | -from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor |
| 18 | +from django.db.models.fields.related_descriptors import ( |
| 19 | + ForwardManyToOneDescriptor, |
| 20 | + ReverseManyToOneDescriptor, |
| 21 | + ReverseOneToOneDescriptor, |
| 22 | +) |
17 | 23 | from django.db.models.lookups import Lookup |
18 | 24 | from django.db.models.manager import BaseManager |
19 | | -from django.db.models.query import ModelIterable, QuerySet, RawQuerySet |
| 25 | +from django.db.models.options import Options |
| 26 | +from django.db.models.query import BaseIterable, ModelIterable, QuerySet, RawQuerySet |
20 | 27 | from django.forms.formsets import BaseFormSet |
21 | | -from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField |
22 | | -from django.utils.connection import BaseConnectionHandler |
| 28 | +from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField, ModelFormOptions |
| 29 | +from django.utils.connection import BaseConnectionHandler, ConnectionProxy |
| 30 | +from django.utils.functional import classproperty |
23 | 31 | from django.views.generic.detail import SingleObjectMixin |
24 | 32 | from django.views.generic.edit import DeletionMixin, FormMixin |
25 | 33 | from django.views.generic.list import MultipleObjectMixin |
26 | 34 |
|
27 | 35 | __all__ = ["monkeypatch"] |
28 | 36 |
|
| 37 | +logger = logging.getLogger(__name__) |
| 38 | + |
29 | 39 | _T = TypeVar("_T") |
30 | 40 | _VersionSpec = tuple[int, int] |
31 | 41 |
|
@@ -81,16 +91,41 @@ def __repr__(self) -> str: |
81 | 91 | # These types do have native `__class_getitem__` method since django 4.1: |
82 | 92 | MPGeneric(ForeignKey, (4, 1)), |
83 | 93 | MPGeneric(RawQuerySet), |
| 94 | + MPGeneric(classproperty), |
| 95 | + MPGeneric(ConnectionProxy), |
| 96 | + MPGeneric(ModelFormOptions), |
| 97 | + MPGeneric(Options), |
| 98 | + MPGeneric(BaseIterable), |
| 99 | + MPGeneric(ForwardManyToOneDescriptor), |
| 100 | + MPGeneric(ReverseOneToOneDescriptor), |
84 | 101 | ] |
85 | 102 |
|
86 | 103 |
|
| 104 | +def _get_need_generic() -> list[MPGeneric[Any]]: |
| 105 | + try: |
| 106 | + if VERSION >= (5, 1): |
| 107 | + from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin |
| 108 | + |
| 109 | + return [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin), *_need_generic] |
| 110 | + else: |
| 111 | + from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm |
| 112 | + |
| 113 | + return [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm), *_need_generic] |
| 114 | + |
| 115 | + except (ImproperlyConfigured, AppRegistryNotReady): |
| 116 | + # We cannot patch symbols in `django.contrib.auth.forms` if the `monkeypatch()` call |
| 117 | + # is in the settings file because django is not initialized yet. |
| 118 | + # To solve this, you'll have to call `monkeypatch()` again later, in an `AppConfig.ready` for ex. |
| 119 | + # See https://docs.djangoproject.com/en/5.2/ref/applications/#django.apps.AppConfig.ready |
| 120 | + return _need_generic |
| 121 | + |
| 122 | + |
87 | 123 | def monkeypatch(extra_classes: Iterable[type] | None = None, include_builtins: bool = True) -> None: |
88 | 124 | """Monkey patch django as necessary to work properly with mypy.""" |
89 | | - |
90 | 125 | # Add the __class_getitem__ dunder. |
91 | 126 | suited_for_this_version = filter( |
92 | 127 | lambda spec: spec.version is None or VERSION[:2] <= spec.version, |
93 | | - _need_generic, |
| 128 | + _get_need_generic(), |
94 | 129 | ) |
95 | 130 | for el in suited_for_this_version: |
96 | 131 | el.cls.__class_getitem__ = classmethod(lambda cls, *args, **kwargs: cls) |
|
0 commit comments