|
| 1 | +from mypy.plugin import Plugin, FunctionContext, ClassDefContext |
| 2 | +from mypy.plugins.common import add_method |
| 3 | +from mypy.nodes import( |
| 4 | + NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, |
| 5 | + Argument, Var, ARG_STAR2 |
| 6 | +) |
| 7 | +from mypy.types import ( |
| 8 | + UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType |
| 9 | +) |
| 10 | +from mypy.typevars import fill_typevars_with_any |
| 11 | + |
| 12 | +from typing import Optional, Callable, Dict, TYPE_CHECKING |
| 13 | +if TYPE_CHECKING: |
| 14 | + from typing_extensions import Final |
| 15 | + |
| 16 | +COLUMN_NAME = 'sqlalchemy.sql.schema.Column' # type: Final |
| 17 | +RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final |
| 18 | + |
| 19 | + |
| 20 | +def is_declarative(info: TypeInfo) -> bool: |
| 21 | + """Check if this is a subclass of a declarative base.""" |
| 22 | + if info.mro: |
| 23 | + for base in info.mro: |
| 24 | + metadata = base.metadata.get('sqlalchemy') |
| 25 | + if metadata and metadata.get('declarative_base'): |
| 26 | + return True |
| 27 | + return False |
| 28 | + |
| 29 | + |
| 30 | +def set_declarative(info: TypeInfo) -> None: |
| 31 | + """Record given class as a declarative base.""" |
| 32 | + info.metadata.setdefault('sqlalchemy', {})['declarative_base'] = True |
| 33 | + |
| 34 | + |
| 35 | +class BasicSQLAlchemyPlugin(Plugin): |
| 36 | + """Basic plugin to support simple operations with models. |
| 37 | +
|
| 38 | + Currently supported functionality: |
| 39 | + * Recognize dynamically defined declarative bases. |
| 40 | + * Add an __init__() method to models. |
| 41 | + * Provide better types for 'Column's and 'RelationshipProperty's |
| 42 | + using flags 'primary_key', 'nullable', 'uselist', etc. |
| 43 | + """ |
| 44 | + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: |
| 45 | + if fullname == COLUMN_NAME: |
| 46 | + return column_hook |
| 47 | + if fullname == RELATIONSHIP_NAME: |
| 48 | + return relationship_hook |
| 49 | + sym = self.lookup_fully_qualified(fullname) |
| 50 | + if sym and isinstance(sym.node, TypeInfo): |
| 51 | + # May be a model instantiation |
| 52 | + if is_declarative(sym.node): |
| 53 | + return model_hook |
| 54 | + return None |
| 55 | + |
| 56 | + def get_dynamic_class_hook(self, fullname): |
| 57 | + if fullname == 'sqlalchemy.ext.declarative.api.declarative_base': |
| 58 | + return decl_info_hook |
| 59 | + return None |
| 60 | + |
| 61 | + def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: |
| 62 | + if fullname == 'sqlalchemy.ext.declarative.api.as_declarative': |
| 63 | + return decl_deco_hook |
| 64 | + return None |
| 65 | + |
| 66 | + def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: |
| 67 | + sym = self.lookup_fully_qualified(fullname) |
| 68 | + if sym and isinstance(sym.node, TypeInfo): |
| 69 | + if is_declarative(sym.node): |
| 70 | + return add_init_hook |
| 71 | + return None |
| 72 | + |
| 73 | + |
| 74 | +def add_init_hook(ctx: ClassDefContext) -> None: |
| 75 | + """Add a dummy __init__() to a model and record it is generated. |
| 76 | +
|
| 77 | + Instantiation will be checked more precisely when we inferred types |
| 78 | + (using get_function_hook and model_hook). |
| 79 | + """ |
| 80 | + if '__init__' in ctx.cls.info.names: |
| 81 | + # Don't override existing definition. |
| 82 | + return |
| 83 | + typ = AnyType(TypeOfAny.special_form) |
| 84 | + var = Var('kwargs', typ) |
| 85 | + kw_arg = Argument(variable=var, type_annotation=typ, initializer=None, kind=ARG_STAR2) |
| 86 | + add_method(ctx, '__init__', [kw_arg], NoneTyp()) |
| 87 | + ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True |
| 88 | + |
| 89 | + |
| 90 | +def decl_deco_hook(ctx: ClassDefContext) -> None: |
| 91 | + """Support declaring base class as declarative with a decorator. |
| 92 | +
|
| 93 | + For example: |
| 94 | + from from sqlalchemy.ext.declarative import as_declarative |
| 95 | +
|
| 96 | + @as_declarative |
| 97 | + class Base: |
| 98 | + ... |
| 99 | + """ |
| 100 | + set_declarative(ctx.cls.info) |
| 101 | + |
| 102 | + |
| 103 | +def decl_info_hook(ctx): |
| 104 | + """Support dynamically defining declarative bases. |
| 105 | +
|
| 106 | + For example: |
| 107 | + from sqlalchemy.ext.declarative import declarative_base |
| 108 | +
|
| 109 | + Base = declarative_base() |
| 110 | + """ |
| 111 | + class_def = ClassDef(ctx.name, Block([])) |
| 112 | + class_def.fullname = ctx.api.qualified_name(ctx.name) |
| 113 | + |
| 114 | + info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) |
| 115 | + class_def.info = info |
| 116 | + obj = ctx.api.builtin_type('builtins.object') |
| 117 | + info.mro = [info, obj.type] |
| 118 | + info.bases = [obj] |
| 119 | + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) |
| 120 | + set_declarative(info) |
| 121 | + |
| 122 | + |
| 123 | +def model_hook(ctx: FunctionContext) -> Type: |
| 124 | + """More precise model instantiation check. |
| 125 | +
|
| 126 | + Note: sub-models are not supported. |
| 127 | + Note: this is still not perfect, since the context for inference of |
| 128 | + argument types is 'Any'. |
| 129 | + """ |
| 130 | + assert isinstance(ctx.default_return_type, Instance) |
| 131 | + model = ctx.default_return_type.type |
| 132 | + metadata = model.metadata.get('sqlalchemy') |
| 133 | + if not metadata or not metadata.get('generated_init'): |
| 134 | + return ctx.default_return_type |
| 135 | + |
| 136 | + # Collect column names and types defined in the model |
| 137 | + # TODO: cache this? |
| 138 | + expected_types = {} # type: Dict[str, Type] |
| 139 | + for name, sym in model.names.items(): |
| 140 | + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): |
| 141 | + tp = sym.node.type |
| 142 | + if tp.type.fullname() in (COLUMN_NAME, RELATIONSHIP_NAME): |
| 143 | + assert len(tp.args) == 1 |
| 144 | + expected_types[name] = tp.args[0] |
| 145 | + |
| 146 | + assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__ |
| 147 | + assert len(ctx.arg_types) == 1 |
| 148 | + for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): |
| 149 | + if actual_name not in expected_types: |
| 150 | + ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()), |
| 151 | + ctx.context) |
| 152 | + continue |
| 153 | + # Using private API to simplify life. |
| 154 | + ctx.api.check_subtype(actual_type, expected_types[actual_name], |
| 155 | + ctx.context, |
| 156 | + 'Incompatible type for "{}" of "{}"'.format(actual_name, model.name()), |
| 157 | + 'got', 'expected') |
| 158 | + return ctx.default_return_type |
| 159 | + |
| 160 | + |
| 161 | +def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]: |
| 162 | + """Return the expression for the specific argument. |
| 163 | +
|
| 164 | + This helper should only be used with non-star arguments. |
| 165 | + """ |
| 166 | + if name not in ctx.callee_arg_names: |
| 167 | + return None |
| 168 | + idx = ctx.callee_arg_names.index(name) |
| 169 | + args = ctx.args[idx] |
| 170 | + if len(args) != 1: |
| 171 | + # Either an error or no value passed. |
| 172 | + return None |
| 173 | + return args[0] |
| 174 | + |
| 175 | + |
| 176 | +def get_argtype_by_name(ctx: FunctionContext, name: str) -> Optional[Type]: |
| 177 | + """Same as above but for argument type.""" |
| 178 | + if name not in ctx.callee_arg_names: |
| 179 | + return None |
| 180 | + idx = ctx.callee_arg_names.index(name) |
| 181 | + arg_types = ctx.arg_types[idx] |
| 182 | + if len(arg_types) != 1: |
| 183 | + # Either an error or no value passed. |
| 184 | + return None |
| 185 | + return arg_types[0] |
| 186 | + |
| 187 | + |
| 188 | +def column_hook(ctx: FunctionContext) -> Type: |
| 189 | + """Infer better types for Column calls. |
| 190 | +
|
| 191 | + Examples: |
| 192 | + Column(String) -> Column[Optional[str]] |
| 193 | + Column(String, primary_key=True) -> Column[str] |
| 194 | + Column(String, nullable=False) -> Column[str] |
| 195 | + Column(String, default=...) -> Column[str] |
| 196 | + Column(String, default=..., nullable=True) -> Column[Optional[str]] |
| 197 | +
|
| 198 | + TODO: check the type of 'default'. |
| 199 | + """ |
| 200 | + assert isinstance(ctx.default_return_type, Instance) |
| 201 | + |
| 202 | + nullable_arg = get_argument_by_name(ctx, 'nullable') |
| 203 | + primary_arg = get_argument_by_name(ctx, 'primary_key') |
| 204 | + default_arg = get_argument_by_name(ctx, 'default') |
| 205 | + |
| 206 | + if nullable_arg: |
| 207 | + nullable = parse_bool(nullable_arg) |
| 208 | + else: |
| 209 | + if primary_arg: |
| 210 | + nullable = not parse_bool(primary_arg) |
| 211 | + else: |
| 212 | + nullable = default_arg is None |
| 213 | + # TODO: Add support for literal types. |
| 214 | + |
| 215 | + if not nullable: |
| 216 | + return ctx.default_return_type |
| 217 | + assert len(ctx.default_return_type.args) == 1 |
| 218 | + arg_type = ctx.default_return_type.args[0] |
| 219 | + return Instance(ctx.default_return_type.type, [UnionType([arg_type, NoneTyp()])], |
| 220 | + line=ctx.default_return_type.line, |
| 221 | + column=ctx.default_return_type.column) |
| 222 | + |
| 223 | + |
| 224 | +def relationship_hook(ctx: FunctionContext) -> Type: |
| 225 | + """Support basic use cases for relationships. |
| 226 | +
|
| 227 | + Examples: |
| 228 | + from sqlalchemy.orm import relationship |
| 229 | +
|
| 230 | + from one import OneModel |
| 231 | + if TYPE_CHECKING: |
| 232 | + from other import OtherModel |
| 233 | +
|
| 234 | + class User(Base): |
| 235 | + __tablename__ = 'users' |
| 236 | + id = Column(Integer(), primary_key=True) |
| 237 | + one = relationship(OneModel) |
| 238 | + other = relationship("OtherModel") |
| 239 | +
|
| 240 | + This also tries to infer the type argument for 'RelationshipProperty' |
| 241 | + using the 'uselist' flag. |
| 242 | + """ |
| 243 | + assert isinstance(ctx.default_return_type, Instance) |
| 244 | + original_type_arg = ctx.default_return_type.args[0] |
| 245 | + has_annotation = not isinstance(original_type_arg, UninhabitedType) |
| 246 | + |
| 247 | + arg = get_argument_by_name(ctx, 'argument') |
| 248 | + arg_type = get_argtype_by_name(ctx, 'argument') |
| 249 | + |
| 250 | + uselist_arg = get_argument_by_name(ctx, 'uselist') |
| 251 | + |
| 252 | + if isinstance(arg, StrExpr): |
| 253 | + name = arg.value |
| 254 | + # Private API for local lookup, but probably needs to be public. |
| 255 | + try: |
| 256 | + sym = ctx.api.lookup_qualified(name) # type: Optional[SymbolTableNode] |
| 257 | + except (KeyError, AssertionError): |
| 258 | + sym = None |
| 259 | + if sym and isinstance(sym.node, TypeInfo): |
| 260 | + new_arg = fill_typevars_with_any(sym.node) |
| 261 | + else: |
| 262 | + ctx.api.fail('Cannot find model "{}"'.format(name), ctx.context) |
| 263 | + ctx.api.note('Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles', ctx.context) |
| 264 | + new_arg = AnyType(TypeOfAny.from_error) |
| 265 | + else: |
| 266 | + if isinstance(arg_type, CallableType) and arg_type.is_type_obj(): |
| 267 | + new_arg = fill_typevars_with_any(arg_type.type_object()) |
| 268 | + else: |
| 269 | + # Something complex, stay silent for now. |
| 270 | + new_arg = AnyType(TypeOfAny.special_form) |
| 271 | + |
| 272 | + # We figured out, the model type. Now check if we need to wrap it in Iterable |
| 273 | + if uselist_arg: |
| 274 | + if parse_bool(uselist_arg): |
| 275 | + new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) |
| 276 | + else: |
| 277 | + if has_annotation: |
| 278 | + # If there is an annotation we use it as a source of truth. |
| 279 | + # This will cause false negatives, but it is better than lots of false positives. |
| 280 | + new_arg = original_type_arg |
| 281 | + |
| 282 | + return Instance(ctx.default_return_type.type, [new_arg], |
| 283 | + line=ctx.default_return_type.line, |
| 284 | + column=ctx.default_return_type.column) |
| 285 | + |
| 286 | + |
| 287 | +# We really need to add this to TypeChecker API |
| 288 | +def parse_bool(expr: Expression) -> Optional[bool]: |
| 289 | + if isinstance(expr, NameExpr): |
| 290 | + if expr.fullname == 'builtins.True': |
| 291 | + return True |
| 292 | + if expr.fullname == 'builtins.False': |
| 293 | + return False |
| 294 | + return None |
| 295 | + |
| 296 | + |
| 297 | +def plugin(version): |
| 298 | + return BasicSQLAlchemyPlugin |
0 commit comments