Skip to content

Commit e41ffed

Browse files
authored
Support inheriting ManyToManyField from an abstract model (#2260)
1 parent a16b9ba commit e41ffed

File tree

5 files changed

+157
-65
lines changed

5 files changed

+157
-65
lines changed

mypy_django_plugin/lib/helpers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
MemberExpr,
1919
MypyFile,
2020
NameExpr,
21+
RefExpr,
2122
StrExpr,
2223
SymbolNode,
2324
SymbolTable,
@@ -497,3 +498,38 @@ def resolve_lazy_reference(
497498

498499
def is_model_type(info: TypeInfo) -> bool:
499500
return info.metaclass_type is not None and info.metaclass_type.type.has_base(fullnames.MODEL_METACLASS_FULLNAME)
501+
502+
503+
def get_model_from_expression(
504+
expr: Expression,
505+
*,
506+
self_model: TypeInfo,
507+
api: Union[TypeChecker, SemanticAnalyzer],
508+
django_context: "DjangoContext",
509+
) -> Optional[Instance]:
510+
"""
511+
Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference
512+
argument(e.g. "<app_label>.<object_name>") to a Django model is also attempted.
513+
"""
514+
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
515+
if is_model_type(expr.node):
516+
return Instance(expr.node, [])
517+
518+
if isinstance(expr, StrExpr) and expr.value == "self":
519+
return Instance(self_model, [])
520+
521+
lazy_reference = None
522+
if isinstance(expr, StrExpr):
523+
lazy_reference = expr.value
524+
elif (
525+
isinstance(expr, MemberExpr)
526+
and isinstance(expr.expr, NameExpr)
527+
and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME
528+
):
529+
lazy_reference = django_context.settings.AUTH_USER_MODEL
530+
531+
if lazy_reference is not None:
532+
model_info = resolve_lazy_reference(lazy_reference, api=api, django_context=django_context, ctx=expr)
533+
if model_info is not None:
534+
return Instance(model_info, [])
535+
return None

mypy_django_plugin/transformers/manytomany.py

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
from typing import NamedTuple, Optional, Tuple, Union
1+
from typing import NamedTuple, Optional, Tuple
22

3-
from mypy.checker import TypeChecker
4-
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo
3+
from mypy.nodes import AssignmentStmt, NameExpr, Node, TypeInfo
54
from mypy.plugin import FunctionContext, MethodContext
6-
from mypy.semanal import SemanticAnalyzer
75
from mypy.types import Instance, ProperType, UninhabitedType
86
from mypy.types import Type as MypyType
97

@@ -72,24 +70,21 @@ def get_m2m_arguments(
7270
) -> Optional[M2MArguments]:
7371
checker = helpers.get_typechecker_api(ctx)
7472
to_arg = ctx.args[0][0]
75-
to_model: Optional[ProperType]
76-
if isinstance(to_arg, StrExpr) and to_arg.value == "self":
77-
to_model = Instance(model_info, [])
78-
to_self = True
79-
else:
80-
to_model = get_model_from_expression(to_arg, api=checker, django_context=django_context)
81-
to_self = False
82-
73+
to_model = helpers.get_model_from_expression(
74+
to_arg, self_model=model_info, api=checker, django_context=django_context
75+
)
8376
if to_model is None:
8477
# 'ManyToManyField()' requires the 'to' argument
8578
return None
86-
to = M2MTo(arg=to_arg, model=to_model, self=to_self)
79+
to = M2MTo(arg=to_arg, model=to_model, self=to_model.type == model_info)
8780

8881
through = None
8982
if len(ctx.args) > 5 and ctx.args[5]:
9083
# 'ManyToManyField(..., through=)' was called
9184
through_arg = ctx.args[5][0]
92-
through_model = get_model_from_expression(through_arg, api=checker, django_context=django_context)
85+
through_model = helpers.get_model_from_expression(
86+
through_arg, self_model=model_info, api=checker, django_context=django_context
87+
)
9388
if through_model is not None:
9489
through = M2MThrough(arg=through_arg, model=through_model)
9590
elif not helpers.is_abstract_model(model_info):
@@ -119,37 +114,6 @@ def get_m2m_arguments(
119114
return M2MArguments(to=to, through=through)
120115

121116

122-
def get_model_from_expression(
123-
expr: Expression,
124-
*,
125-
api: Union[TypeChecker, SemanticAnalyzer],
126-
django_context: DjangoContext,
127-
) -> Optional[ProperType]:
128-
"""
129-
Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference
130-
argument(e.g. "<app_label>.<object_name>") to a Django model is also attempted.
131-
"""
132-
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
133-
if helpers.is_model_type(expr.node):
134-
return Instance(expr.node, [])
135-
136-
lazy_reference = None
137-
if isinstance(expr, StrExpr):
138-
lazy_reference = expr.value
139-
elif (
140-
isinstance(expr, MemberExpr)
141-
and isinstance(expr.expr, NameExpr)
142-
and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME
143-
):
144-
lazy_reference = django_context.settings.AUTH_USER_MODEL
145-
146-
if lazy_reference is not None:
147-
model_info = helpers.resolve_lazy_reference(lazy_reference, api=api, django_context=django_context, ctx=expr)
148-
if model_info is not None:
149-
return Instance(model_info, [])
150-
return None
151-
152-
153117
def get_related_manager_and_model(ctx: MethodContext) -> Optional[Tuple[Instance, Instance, Instance]]:
154118
"""
155119
Returns a 3-tuple consisting of:

mypy_django_plugin/transformers/models.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
NameExpr,
1919
RefExpr,
2020
Statement,
21-
StrExpr,
2221
SymbolTableNode,
2322
TypeInfo,
2423
Var,
@@ -41,7 +40,7 @@
4140
MANAGER_METHODS_RETURNING_QUERYSET,
4241
create_manager_info_from_from_queryset_call,
4342
)
44-
from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo, get_model_from_expression
43+
from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo
4544

4645

4746
class ModelClassInitializer:
@@ -677,12 +676,27 @@ def run(self) -> None:
677676
continue
678677
# Get the names of the implicit through model that will be generated
679678
through_model_name = f"{self.model_classdef.name}_{m2m_field_name}"
680-
self.create_through_table_class(
679+
through_model = self.create_through_table_class(
681680
field_name=m2m_field_name,
682681
model_name=through_model_name,
683682
model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}",
684683
m2m_args=args,
685684
)
685+
container = self.model_classdef.info.get_containing_type_info(m2m_field_name)
686+
if (
687+
through_model is not None
688+
and container is not None
689+
and container.fullname != self.model_classdef.info.fullname
690+
and helpers.is_abstract_model(container)
691+
):
692+
# ManyToManyField is inherited from an abstract parent class, so in
693+
# order to get the to and the through model argument right we
694+
# override the ManyToManyField attribute on the current class
695+
helpers.add_new_sym_for_info(
696+
self.model_classdef.info,
697+
name=m2m_field_name,
698+
sym_type=Instance(self.m2m_field, [args.to.model, Instance(through_model, [])]),
699+
)
686700
# Create a 'ManyRelatedManager' class for the processed model
687701
self.create_many_related_manager(Instance(self.model_classdef.info, []))
688702
if isinstance(args.to.model, Instance):
@@ -717,6 +731,13 @@ def fk_field(self) -> TypeInfo:
717731
raise helpers.IncompleteDefnException()
718732
return info
719733

734+
@cached_property
735+
def m2m_field(self) -> TypeInfo:
736+
info = self.lookup_typeinfo(fullnames.MANYTOMANY_FIELD_FULLNAME)
737+
if info is None:
738+
raise helpers.IncompleteDefnException()
739+
return info
740+
720741
@cached_property
721742
def manager_info(self) -> TypeInfo:
722743
info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
@@ -746,18 +767,17 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance:
746767

747768
def create_through_table_class(
748769
self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments
749-
) -> None:
750-
if (
751-
not isinstance(m2m_args.to.model, Instance)
770+
) -> Optional[TypeInfo]:
771+
if not isinstance(m2m_args.to.model, Instance):
772+
return None
773+
elif m2m_args.through is not None:
752774
# Call has explicit 'through=', no need to create any implicit through table
753-
or m2m_args.through is not None
754-
):
755-
return
775+
return m2m_args.through.model.type if isinstance(m2m_args.through.model, Instance) else None
756776

757777
# If through model is already declared there's nothing more we should do
758778
through_model = self.lookup_typeinfo(model_fullname)
759779
if through_model is not None:
760-
return
780+
return through_model
761781
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
762782
through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])])
763783
# We attempt to be a bit clever here and store the generated through model's fullname in
@@ -823,6 +843,7 @@ def create_through_table_class(
823843
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
824844
is_classvar=True,
825845
)
846+
return through_model
826847

827848
def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]:
828849
"""
@@ -848,22 +869,24 @@ def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) ->
848869
return None
849870

850871
# Resolve the type of the 'to' argument expression
851-
to_model: Optional[ProperType]
852-
if isinstance(to_arg, StrExpr) and to_arg.value == "self":
853-
to_model = Instance(self.model_classdef.info, [])
854-
to_self = True
855-
else:
856-
to_model = get_model_from_expression(to_arg, api=self.api, django_context=self.django_context)
857-
to_self = False
872+
to_model = helpers.get_model_from_expression(
873+
to_arg, self_model=self.model_classdef.info, api=self.api, django_context=self.django_context
874+
)
858875
if to_model is None:
859876
return None
860-
to = M2MTo(arg=to_arg, model=to_model, self=to_self)
877+
to = M2MTo(
878+
arg=to_arg,
879+
model=to_model,
880+
self=to_model.type == self.model_classdef.info,
881+
)
861882

862883
# Resolve the type of the 'through' argument expression
863884
through_arg = look_for["through"]
864885
through = None
865886
if through_arg is not None:
866-
through_model = get_model_from_expression(through_arg, api=self.api, django_context=self.django_context)
887+
through_model = helpers.get_model_from_expression(
888+
through_arg, self_model=self.model_classdef.info, api=self.api, django_context=self.django_context
889+
)
867890
if through_model is not None:
868891
through = M2MThrough(arg=through_arg, model=through_model)
869892

tests/typecheck/fields/test_related.yml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,3 +1403,72 @@
14031403
class MyModel(models.Model):
14041404
m2m_1 = models.ManyToManyField(other_models.Other, related_name="auto_through")
14051405
m2m_2 = models.ManyToManyField(other_models.Other, related_name="custom_through", through=Through)
1406+
1407+
- case: test_m2m_from_abstract_model
1408+
main: |
1409+
from myapp.models import First, Second
1410+
reveal_type(First().others) # N: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.First_others]"
1411+
reveal_type(First().others.get()) # N: Revealed type is "myapp.models.Other"
1412+
reveal_type(First.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Other, myapp.models.First_others]"
1413+
reveal_type(First.others.through) # N: Revealed type is "Type[myapp.models.First_others]"
1414+
reveal_type(First.others.through.objects.get()) # N: Revealed type is "myapp.models.First_others"
1415+
1416+
reveal_type(Second().others) # N: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.Second_others]"
1417+
reveal_type(Second().others.get()) # N: Revealed type is "myapp.models.Other"
1418+
reveal_type(Second.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Other, myapp.models.Second_others]"
1419+
reveal_type(Second.others.through) # N: Revealed type is "Type[myapp.models.Second_others]"
1420+
reveal_type(Second.others.through.objects.get()) # N: Revealed type is "myapp.models.Second_others"
1421+
installed_apps:
1422+
- myapp
1423+
files:
1424+
- path: myapp/__init__.py
1425+
- path: myapp/models.py
1426+
content: |
1427+
from django.db import models
1428+
class Other(models.Model):
1429+
...
1430+
1431+
class Parent(models.Model):
1432+
others = models.ManyToManyField(Other)
1433+
1434+
class Meta:
1435+
abstract = True
1436+
1437+
class First(Parent):
1438+
...
1439+
1440+
class Second(Parent):
1441+
...
1442+
1443+
- case: test_m2m_self_on_abstract_model
1444+
main: |
1445+
from myapp.models import First, Second
1446+
reveal_type(First().others) # N: Revealed type is "myapp.models.First_ManyRelatedManager[myapp.models.First_others]"
1447+
reveal_type(First().others.get()) # N: Revealed type is "myapp.models.First"
1448+
reveal_type(First.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.First, myapp.models.First_others]"
1449+
reveal_type(First.others.through) # N: Revealed type is "Type[myapp.models.First_others]"
1450+
reveal_type(First.others.through.objects.get()) # N: Revealed type is "myapp.models.First_others"
1451+
1452+
reveal_type(Second().others) # N: Revealed type is "myapp.models.Second_ManyRelatedManager[myapp.models.Second_others]"
1453+
reveal_type(Second().others.get()) # N: Revealed type is "myapp.models.Second"
1454+
reveal_type(Second.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Second, myapp.models.Second_others]"
1455+
reveal_type(Second.others.through) # N: Revealed type is "Type[myapp.models.Second_others]"
1456+
reveal_type(Second.others.through.objects.get()) # N: Revealed type is "myapp.models.Second_others"
1457+
installed_apps:
1458+
- myapp
1459+
files:
1460+
- path: myapp/__init__.py
1461+
- path: myapp/models.py
1462+
content: |
1463+
from django.db import models
1464+
class Parent(models.Model):
1465+
others = models.ManyToManyField("self")
1466+
1467+
class Meta:
1468+
abstract = True
1469+
1470+
class First(Parent):
1471+
...
1472+
1473+
class Second(Parent):
1474+
...

tests/typecheck/models/test_contrib_models.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
reveal_type(User().is_anonymous) # N: Revealed type is "Literal[False]"
1616
reveal_type(User().groups.get()) # N: Revealed type is "django.contrib.auth.models.Group"
1717
reveal_type(User().user_permissions.get()) # N: Revealed type is "django.contrib.auth.models.Permission"
18-
reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.db.models.base.Model]"
18+
reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.contrib.auth.models.User_groups]"
1919
reveal_type(User.user_permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Permission, django.db.models.base.Model]"
2020
2121
from django.contrib.auth.models import AnonymousUser

0 commit comments

Comments
 (0)