|
| 1 | +from mypy.mro import calculate_mro, MroError |
1 | 2 | from mypy.plugin import Plugin, FunctionContext, ClassDefContext |
2 | 3 | from mypy.plugins.common import add_method |
3 | | -from mypy.nodes import( |
| 4 | +from mypy.nodes import ( |
4 | 5 | NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, |
5 | | - Argument, Var, ARG_STAR2, MDEF |
| 6 | + Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr |
6 | 7 | ) |
7 | 8 | from mypy.types import ( |
8 | 9 | UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType |
9 | 10 | ) |
10 | 11 | from mypy.typevars import fill_typevars_with_any |
11 | 12 |
|
12 | | -from typing import Optional, Callable, Dict, TYPE_CHECKING |
| 13 | +from typing import Optional, Callable, Dict, TYPE_CHECKING, List |
13 | 14 | if TYPE_CHECKING: |
14 | 15 | from typing_extensions import Final |
15 | 16 |
|
@@ -141,14 +142,37 @@ def decl_info_hook(ctx): |
141 | 142 |
|
142 | 143 | Base = declarative_base() |
143 | 144 | """ |
| 145 | + cls_bases = [] # type: List[Instance] |
| 146 | + |
| 147 | + # Passing base classes as positional arguments is currently not handled. |
| 148 | + if 'cls' in ctx.call.arg_names: |
| 149 | + declarative_base_cls_arg = ctx.call.args[ctx.call.arg_names.index("cls")] |
| 150 | + if isinstance(declarative_base_cls_arg, TupleExpr): |
| 151 | + items = [item for item in declarative_base_cls_arg.items] |
| 152 | + else: |
| 153 | + items = [declarative_base_cls_arg] |
| 154 | + |
| 155 | + for item in items: |
| 156 | + if isinstance(item, RefExpr) and isinstance(item.node, TypeInfo): |
| 157 | + base = fill_typevars_with_any(item.node) |
| 158 | + # TODO: Support tuple types? |
| 159 | + if isinstance(base, Instance): |
| 160 | + cls_bases.append(base) |
| 161 | + |
144 | 162 | class_def = ClassDef(ctx.name, Block([])) |
145 | 163 | class_def.fullname = ctx.api.qualified_name(ctx.name) |
146 | 164 |
|
147 | 165 | info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) |
148 | 166 | class_def.info = info |
149 | 167 | obj = ctx.api.builtin_type('builtins.object') |
150 | | - info.mro = [info, obj.type] |
151 | | - info.bases = [obj] |
| 168 | + info.bases = cls_bases or [obj] |
| 169 | + try: |
| 170 | + calculate_mro(info) |
| 171 | + except MroError: |
| 172 | + ctx.api.fail("Not able to calculate MRO for declarative base", ctx.call) |
| 173 | + info.bases = [obj] |
| 174 | + info.fallback_to_any = True |
| 175 | + |
152 | 176 | ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) |
153 | 177 | set_declarative(info) |
154 | 178 |
|
|
0 commit comments