Skip to content

Commit 9efe7c4

Browse files
mehdigmirailevkivskyi
authored andcommitted
Plugin: handle declarative_base cls argument (#57)
SQLAlchemy's `declarative_base` has a `cls` argument that is a class/a tuple of classes that the generated class inherits from.
1 parent 66c0312 commit 9efe7c4

File tree

2 files changed

+83
-5
lines changed

2 files changed

+83
-5
lines changed

sqlmypy.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
from mypy.mro import calculate_mro, MroError
12
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
23
from mypy.plugins.common import add_method
3-
from mypy.nodes import(
4+
from mypy.nodes import (
45
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF,
5-
Argument, Var, ARG_STAR2, MDEF
6+
Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr
67
)
78
from mypy.types import (
89
UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType
910
)
1011
from mypy.typevars import fill_typevars_with_any
1112

12-
from typing import Optional, Callable, Dict, TYPE_CHECKING
13+
from typing import Optional, Callable, Dict, TYPE_CHECKING, List
1314
if TYPE_CHECKING:
1415
from typing_extensions import Final
1516

@@ -141,14 +142,37 @@ def decl_info_hook(ctx):
141142
142143
Base = declarative_base()
143144
"""
145+
cls_bases = [] # type: List[Instance]
146+
147+
# Passing base classes as positional arguments is currently not handled.
148+
if 'cls' in ctx.call.arg_names:
149+
declarative_base_cls_arg = ctx.call.args[ctx.call.arg_names.index("cls")]
150+
if isinstance(declarative_base_cls_arg, TupleExpr):
151+
items = [item for item in declarative_base_cls_arg.items]
152+
else:
153+
items = [declarative_base_cls_arg]
154+
155+
for item in items:
156+
if isinstance(item, RefExpr) and isinstance(item.node, TypeInfo):
157+
base = fill_typevars_with_any(item.node)
158+
# TODO: Support tuple types?
159+
if isinstance(base, Instance):
160+
cls_bases.append(base)
161+
144162
class_def = ClassDef(ctx.name, Block([]))
145163
class_def.fullname = ctx.api.qualified_name(ctx.name)
146164

147165
info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id)
148166
class_def.info = info
149167
obj = ctx.api.builtin_type('builtins.object')
150-
info.mro = [info, obj.type]
151-
info.bases = [obj]
168+
info.bases = cls_bases or [obj]
169+
try:
170+
calculate_mro(info)
171+
except MroError:
172+
ctx.api.fail("Not able to calculate MRO for declarative base", ctx.call)
173+
info.bases = [obj]
174+
info.fallback_to_any = True
175+
152176
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
153177
set_declarative(info)
154178

test/test-data/sqlalchemy-plugin-features.test

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,57 @@ class User(Base):
219219
record = {'name': 'John Doe'}
220220
User(**record) # OK
221221
[out]
222+
223+
[case testDeclarativeBaseWithBaseClass]
224+
from sqlalchemy import Column, Integer, String
225+
from base import Base
226+
227+
class User(Base):
228+
__tablename__ = 'users'
229+
id = Column(Integer(), primary_key=True)
230+
name = Column(String())
231+
232+
user: User
233+
reveal_type(user.f()) # E: Revealed type is 'builtins.str'
234+
235+
[file base.py]
236+
from sqlalchemy.ext.declarative import declarative_base
237+
238+
class Model:
239+
def f(self) -> str: ...
240+
Base = declarative_base(cls=Model)
241+
[out]
242+
243+
[case testDeclarativeBaseWithMultipleBaseClasses]
244+
from sqlalchemy import Column, Integer, String
245+
from base import Base
246+
247+
class User(Base):
248+
__tablename__ = 'users'
249+
id = Column(Integer(), primary_key=True)
250+
name = Column(String())
251+
252+
user: User
253+
reveal_type(user.f()) # E: Revealed type is 'builtins.str'
254+
reveal_type(user.g()) # E: Revealed type is 'builtins.int'
255+
256+
[file base.py]
257+
from sqlalchemy.ext.declarative import declarative_base
258+
259+
class Model:
260+
def f(self) -> str: ...
261+
class Model2:
262+
def g(self) -> int: ...
263+
Base = declarative_base(cls=(Model, Model2))
264+
[out]
265+
266+
[case testDeclarativeBaseWithBaseClassWrongMRO]
267+
from sqlalchemy.ext.declarative import declarative_base
268+
269+
class M1:
270+
...
271+
class M2(M1):
272+
...
273+
Base = declarative_base(cls=(M1, M2)) # E: Not able to calculate MRO for declarative base
274+
reveal_type(Base) # E: Revealed type is 'Any'
275+
[out]

0 commit comments

Comments
 (0)