Skip to content

Commit c0d5592

Browse files
authored
Add basic plugins (#49)
Fixes #23 Helps with #5 Fixes #6 Helps with #8 I would propose to merge this soon, and then iterate on discovered issues. This branch is very old, it is time to move on with some minimal version.
1 parent b772662 commit c0d5592

File tree

12 files changed

+513
-38
lines changed

12 files changed

+513
-38
lines changed

external/mypy

Submodule mypy updated 129 files

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ def find_stub_files():
2222
author='Ivan Levkivskyi',
2323
author_email='levkivskyi@gmail.com',
2424
license='MIT License',
25-
py_modules=[],
25+
py_modules=['sqlmypy', 'sqltyping'],
2626
install_requires=[
2727
'typing-extensions>=3.6.5'
2828
],
2929
packages=['sqlalchemy-stubs'],
30-
package_data={'sqlalchemy-stubs': find_stub_files()}
30+
package_data={'sqlalchemy-stubs': find_stub_files()},
3131
)

sqlalchemy-plugin/sqlmypy.py

Whitespace-only changes.

sqlalchemy-stubs/orm/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ from .strategy_options import Load as Load
4343

4444
def create_session(bind: Optional[Any] = ..., **kwargs): ...
4545

46-
relationship = RelationshipProperty[Any]
46+
relationship = RelationshipProperty
4747

4848
def relation(*arg, **kw): ...
4949
def dynamic_loader(argument, **kw): ...

sqlalchemy-stubs/orm/relationships.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Generic, TypeVar, Union, overload
1+
from typing import Any, Optional, Generic, TypeVar, Union, overload, Type
22
from .interfaces import (
33
MANYTOMANY as MANYTOMANY,
44
MANYTOONE as MANYTOONE,
@@ -11,12 +11,11 @@ def remote(expr): ...
1111
def foreign(expr): ...
1212

1313

14-
_T = TypeVar('_T')
1514
_T_co = TypeVar('_T_co', covariant=True)
1615

1716

1817
# Note: typical use case is where argument is a string, so this will require
19-
# a plugin to infer '_T', otherwise a user will need to write an explicit annotation.
18+
# a plugin to infer '_T_co', otherwise a user will need to write an explicit annotation.
2019
# It is not clear whether RelationshipProperty is covariant at this stage since
2120
# many types are still missing.
2221
class RelationshipProperty(StrategizedProperty, Generic[_T_co]):
@@ -55,7 +54,7 @@ class RelationshipProperty(StrategizedProperty, Generic[_T_co]):
5554
order_by: Any = ...
5655
back_populates: Any = ...
5756
backref: Any = ...
58-
def __init__(self, argument, secondary: Optional[Any] = ...,
57+
def __init__(self, argument: Any, secondary: Optional[Any] = ...,
5958
primaryjoin: Optional[Any] = ..., secondaryjoin: Optional[Any] = ...,
6059
foreign_keys: Optional[Any] = ..., uselist: Optional[Any] = ...,
6160
order_by: Any = ..., backref: Optional[Any] = ...,

sqlmypy.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
2+
from mypy.plugins.common import add_method
3+
from mypy.nodes import(
4+
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF,
5+
Argument, Var, ARG_STAR2
6+
)
7+
from mypy.types import (
8+
UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType
9+
)
10+
from mypy.typevars import fill_typevars_with_any
11+
12+
from typing import Optional, Callable, Dict, TYPE_CHECKING
13+
if TYPE_CHECKING:
14+
from typing_extensions import Final
15+
16+
COLUMN_NAME = 'sqlalchemy.sql.schema.Column' # type: Final
17+
RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final
18+
19+
20+
def is_declarative(info: TypeInfo) -> bool:
21+
"""Check if this is a subclass of a declarative base."""
22+
if info.mro:
23+
for base in info.mro:
24+
metadata = base.metadata.get('sqlalchemy')
25+
if metadata and metadata.get('declarative_base'):
26+
return True
27+
return False
28+
29+
30+
def set_declarative(info: TypeInfo) -> None:
31+
"""Record given class as a declarative base."""
32+
info.metadata.setdefault('sqlalchemy', {})['declarative_base'] = True
33+
34+
35+
class BasicSQLAlchemyPlugin(Plugin):
36+
"""Basic plugin to support simple operations with models.
37+
38+
Currently supported functionality:
39+
* Recognize dynamically defined declarative bases.
40+
* Add an __init__() method to models.
41+
* Provide better types for 'Column's and 'RelationshipProperty's
42+
using flags 'primary_key', 'nullable', 'uselist', etc.
43+
"""
44+
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]:
45+
if fullname == COLUMN_NAME:
46+
return column_hook
47+
if fullname == RELATIONSHIP_NAME:
48+
return relationship_hook
49+
sym = self.lookup_fully_qualified(fullname)
50+
if sym and isinstance(sym.node, TypeInfo):
51+
# May be a model instantiation
52+
if is_declarative(sym.node):
53+
return model_hook
54+
return None
55+
56+
def get_dynamic_class_hook(self, fullname):
57+
if fullname == 'sqlalchemy.ext.declarative.api.declarative_base':
58+
return decl_info_hook
59+
return None
60+
61+
def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
62+
if fullname == 'sqlalchemy.ext.declarative.api.as_declarative':
63+
return decl_deco_hook
64+
return None
65+
66+
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
67+
sym = self.lookup_fully_qualified(fullname)
68+
if sym and isinstance(sym.node, TypeInfo):
69+
if is_declarative(sym.node):
70+
return add_init_hook
71+
return None
72+
73+
74+
def add_init_hook(ctx: ClassDefContext) -> None:
75+
"""Add a dummy __init__() to a model and record it is generated.
76+
77+
Instantiation will be checked more precisely when we inferred types
78+
(using get_function_hook and model_hook).
79+
"""
80+
if '__init__' in ctx.cls.info.names:
81+
# Don't override existing definition.
82+
return
83+
typ = AnyType(TypeOfAny.special_form)
84+
var = Var('kwargs', typ)
85+
kw_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_STAR2)
86+
add_method(ctx, '__init__', [kw_arg], NoneTyp())
87+
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True
88+
89+
90+
def decl_deco_hook(ctx: ClassDefContext) -> None:
91+
"""Support declaring base class as declarative with a decorator.
92+
93+
For example:
94+
from from sqlalchemy.ext.declarative import as_declarative
95+
96+
@as_declarative
97+
class Base:
98+
...
99+
"""
100+
set_declarative(ctx.cls.info)
101+
102+
103+
def decl_info_hook(ctx):
104+
"""Support dynamically defining declarative bases.
105+
106+
For example:
107+
from sqlalchemy.ext.declarative import declarative_base
108+
109+
Base = declarative_base()
110+
"""
111+
class_def = ClassDef(ctx.name, Block([]))
112+
class_def.fullname = ctx.api.qualified_name(ctx.name)
113+
114+
info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id)
115+
class_def.info = info
116+
obj = ctx.api.builtin_type('builtins.object')
117+
info.mro = [info, obj.type]
118+
info.bases = [obj]
119+
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
120+
set_declarative(info)
121+
122+
123+
def model_hook(ctx: FunctionContext) -> Type:
124+
"""More precise model instantiation check.
125+
126+
Note: sub-models are not supported.
127+
Note: this is still not perfect, since the context for inference of
128+
argument types is 'Any'.
129+
"""
130+
assert isinstance(ctx.default_return_type, Instance)
131+
model = ctx.default_return_type.type
132+
metadata = model.metadata.get('sqlalchemy')
133+
if not metadata or not metadata.get('generated_init'):
134+
return ctx.default_return_type
135+
136+
# Collect column names and types defined in the model
137+
# TODO: cache this?
138+
expected_types = {} # type: Dict[str, Type]
139+
for name, sym in model.names.items():
140+
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
141+
tp = sym.node.type
142+
if tp.type.fullname() in (COLUMN_NAME, RELATIONSHIP_NAME):
143+
assert len(tp.args) == 1
144+
expected_types[name] = tp.args[0]
145+
146+
assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__
147+
assert len(ctx.arg_types) == 1
148+
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
149+
if actual_name not in expected_types:
150+
ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()),
151+
ctx.context)
152+
continue
153+
# Using private API to simplify life.
154+
ctx.api.check_subtype(actual_type, expected_types[actual_name],
155+
ctx.context,
156+
'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()),
157+
'got', 'expected')
158+
return ctx.default_return_type
159+
160+
161+
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
162+
"""Return the expression for the specific argument.
163+
164+
This helper should only be used with non-star arguments.
165+
"""
166+
if name not in ctx.callee_arg_names:
167+
return None
168+
idx = ctx.callee_arg_names.index(name)
169+
args = ctx.args[idx]
170+
if len(args) != 1:
171+
# Either an error or no value passed.
172+
return None
173+
return args[0]
174+
175+
176+
def get_argtype_by_name(ctx: FunctionContext, name: str) -> Optional[Type]:
177+
"""Same as above but for argument type."""
178+
if name not in ctx.callee_arg_names:
179+
return None
180+
idx = ctx.callee_arg_names.index(name)
181+
arg_types = ctx.arg_types[idx]
182+
if len(arg_types) != 1:
183+
# Either an error or no value passed.
184+
return None
185+
return arg_types[0]
186+
187+
188+
def column_hook(ctx: FunctionContext) -> Type:
189+
"""Infer better types for Column calls.
190+
191+
Examples:
192+
Column(String) -> Column[Optional[str]]
193+
Column(String, primary_key=True) -> Column[str]
194+
Column(String, nullable=False) -> Column[str]
195+
Column(String, default=...) -> Column[str]
196+
Column(String, default=..., nullable=True) -> Column[Optional[str]]
197+
198+
TODO: check the type of 'default'.
199+
"""
200+
assert isinstance(ctx.default_return_type, Instance)
201+
202+
nullable_arg = get_argument_by_name(ctx, 'nullable')
203+
primary_arg = get_argument_by_name(ctx, 'primary_key')
204+
default_arg = get_argument_by_name(ctx, 'default')
205+
206+
if nullable_arg:
207+
nullable = parse_bool(nullable_arg)
208+
else:
209+
if primary_arg:
210+
nullable = not parse_bool(primary_arg)
211+
else:
212+
nullable = default_arg is None
213+
# TODO: Add support for literal types.
214+
215+
if not nullable:
216+
return ctx.default_return_type
217+
assert len(ctx.default_return_type.args) == 1
218+
arg_type = ctx.default_return_type.args[0]
219+
return Instance(ctx.default_return_type.type, [UnionType([arg_type, NoneTyp()])],
220+
line=ctx.default_return_type.line,
221+
column=ctx.default_return_type.column)
222+
223+
224+
def relationship_hook(ctx: FunctionContext) -> Type:
225+
"""Support basic use cases for relationships.
226+
227+
Examples:
228+
from sqlalchemy.orm import relationship
229+
230+
from one import OneModel
231+
if TYPE_CHECKING:
232+
from other import OtherModel
233+
234+
class User(Base):
235+
__tablename__ = 'users'
236+
id = Column(Integer(), primary_key=True)
237+
one = relationship(OneModel)
238+
other = relationship("OtherModel")
239+
240+
This also tries to infer the type argument for 'RelationshipProperty'
241+
using the 'uselist' flag.
242+
"""
243+
assert isinstance(ctx.default_return_type, Instance)
244+
original_type_arg = ctx.default_return_type.args[0]
245+
has_annotation = not isinstance(original_type_arg, UninhabitedType)
246+
247+
arg = get_argument_by_name(ctx, 'argument')
248+
arg_type = get_argtype_by_name(ctx, 'argument')
249+
250+
uselist_arg = get_argument_by_name(ctx, 'uselist')
251+
252+
if isinstance(arg, StrExpr):
253+
name = arg.value
254+
# Private API for local lookup, but probably needs to be public.
255+
try:
256+
sym = ctx.api.lookup_qualified(name) # type: Optional[SymbolTableNode]
257+
except (KeyError, AssertionError):
258+
sym = None
259+
if sym and isinstance(sym.node, TypeInfo):
260+
new_arg = fill_typevars_with_any(sym.node)
261+
else:
262+
ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context)
263+
ctx.api.note('Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context)
264+
new_arg = AnyType(TypeOfAny.from_error)
265+
else:
266+
if isinstance(arg_type, CallableType) and arg_type.is_type_obj():
267+
new_arg = fill_typevars_with_any(arg_type.type_object())
268+
else:
269+
# Something complex, stay silent for now.
270+
new_arg = AnyType(TypeOfAny.special_form)
271+
272+
# We figured out, the model type. Now check if we need to wrap it in Iterable
273+
if uselist_arg:
274+
if parse_bool(uselist_arg):
275+
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
276+
else:
277+
if has_annotation:
278+
# If there is an annotation we use it as a source of truth.
279+
# This will cause false negatives, but it is better than lots of false positives.
280+
new_arg = original_type_arg
281+
282+
return Instance(ctx.default_return_type.type, [new_arg],
283+
line=ctx.default_return_type.line,
284+
column=ctx.default_return_type.column)
285+
286+
287+
# We really need to add this to TypeChecker API
288+
def parse_bool(expr: Expression) -> Optional[bool]:
289+
if isinstance(expr, NameExpr):
290+
if expr.fullname == 'builtins.True':
291+
return True
292+
if expr.fullname == 'builtins.False':
293+
return False
294+
return None
295+
296+
297+
def plugin(version):
298+
return BasicSQLAlchemyPlugin
File renamed without changes.

test/sqlalchemy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[mypy]
2+
plugins = sqlmypy
3+

0 commit comments

Comments
 (0)