11from mypy .mro import calculate_mro , MroError
2- from mypy .plugin import Plugin , FunctionContext , ClassDefContext
2+ from mypy .plugin import (
3+ Plugin , FunctionContext , ClassDefContext , DynamicClassDefContext ,
4+ SemanticAnalyzerPluginInterface
5+ )
36from mypy .plugins .common import add_method
47from mypy .nodes import (
58 NameExpr , Expression , StrExpr , TypeInfo , ClassDef , Block , SymbolTable , SymbolTableNode , GDEF ,
1013)
1114from mypy .typevars import fill_typevars_with_any
1215
13- from typing import Optional , Callable , Dict , TYPE_CHECKING , List
16+ from typing import Optional , Callable , Dict , TYPE_CHECKING , List , Type as TypingType , TypeVar
1417if TYPE_CHECKING :
1518 from typing_extensions import Final
1619
20+ T = TypeVar ('T' )
21+ CB = Optional [Callable [[T ], None ]]
22+
1723COLUMN_NAME = 'sqlalchemy.sql.schema.Column' # type: Final
1824RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final
1925
@@ -54,17 +60,17 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
5460 return model_hook
5561 return None
5662
57- def get_dynamic_class_hook (self , fullname ) :
63+ def get_dynamic_class_hook (self , fullname : str ) -> CB [ DynamicClassDefContext ] :
5864 if fullname == 'sqlalchemy.ext.declarative.api.declarative_base' :
5965 return decl_info_hook
6066 return None
6167
62- def get_class_decorator_hook (self , fullname : str ) -> Optional [ Callable [[ ClassDefContext ], None ] ]:
68+ def get_class_decorator_hook (self , fullname : str ) -> CB [ ClassDefContext ]:
6369 if fullname == 'sqlalchemy.ext.declarative.api.as_declarative' :
6470 return decl_deco_hook
6571 return None
6672
67- def get_base_class_hook (self , fullname : str ) -> Optional [ Callable [[ ClassDefContext ], None ] ]:
73+ def get_base_class_hook (self , fullname : str ) -> CB [ ClassDefContext ]:
6874 sym = self .lookup_fully_qualified (fullname )
6975 if sym and isinstance (sym .node , TypeInfo ):
7076 if is_declarative (sym .node ):
@@ -109,9 +115,9 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
109115 add_var_to_class ('__table__' , typ , ctx .cls .info )
110116
111117
112- def add_metadata_var (ctx : ClassDefContext , info : TypeInfo ) -> None :
118+ def add_metadata_var (api : SemanticAnalyzerPluginInterface , info : TypeInfo ) -> None :
113119 """Add .metadata attribute to a declarative base."""
114- sym = ctx . api .lookup_fully_qualified_or_none ('sqlalchemy.sql.schema.MetaData' )
120+ sym = api .lookup_fully_qualified_or_none ('sqlalchemy.sql.schema.MetaData' )
115121 if sym :
116122 assert isinstance (sym .node , TypeInfo )
117123 typ = Instance (sym .node , []) # type: Type
@@ -131,10 +137,10 @@ class Base:
131137 ...
132138 """
133139 set_declarative (ctx .cls .info )
134- add_metadata_var (ctx , ctx .cls .info )
140+ add_metadata_var (ctx . api , ctx .cls .info )
135141
136142
137- def decl_info_hook (ctx ) :
143+ def decl_info_hook (ctx : DynamicClassDefContext ) -> None :
138144 """Support dynamically defining declarative bases.
139145
140146 For example:
@@ -177,7 +183,7 @@ def decl_info_hook(ctx):
177183 set_declarative (info )
178184
179185 # TODO: check what else is added.
180- add_metadata_var (ctx , info )
186+ add_metadata_var (ctx . api , info )
181187
182188
183189def model_hook (ctx : FunctionContext ) -> Type :
@@ -211,13 +217,15 @@ def model_hook(ctx: FunctionContext) -> Type:
211217 # TODO: support TypedDict?
212218 continue
213219 if actual_name not in expected_types :
214- ctx .api .fail ('Unexpected column "{}" for model "{}"' .format (actual_name , model .name ()),
220+ ctx .api .fail ('Unexpected column "{}" for model "{}"' .format (actual_name ,
221+ model .name ()),
215222 ctx .context )
216223 continue
217224 # Using private API to simplify life.
218- ctx .api .check_subtype (actual_type , expected_types [actual_name ],
225+ ctx .api .check_subtype (actual_type , expected_types [actual_name ], # type: ignore
219226 ctx .context ,
220- 'Incompatible type for "{}" of "{}"' .format (actual_name , model .name ()),
227+ 'Incompatible type for "{}" of "{}"' .format (actual_name ,
228+ model .name ()),
221229 'got' , 'expected' )
222230 return ctx .default_return_type
223231
@@ -315,16 +323,19 @@ class User(Base):
315323
316324 if isinstance (arg , StrExpr ):
317325 name = arg .value
318- # Private API for local lookup, but probably needs to be public.
326+ sym = None # type: Optional[SymbolTableNode]
319327 try :
320- sym = ctx .api .lookup_qualified (name ) # type: Optional[SymbolTableNode]
328+ # Private API for local lookup, but probably needs to be public.
329+ sym = ctx .api .lookup_qualified (name ) # type: ignore
321330 except (KeyError , AssertionError ):
322- sym = None
331+ pass
323332 if sym and isinstance (sym .node , TypeInfo ):
324- new_arg = fill_typevars_with_any (sym .node )
333+ new_arg = fill_typevars_with_any (sym .node ) # type: Type
325334 else :
326335 ctx .api .fail ('Cannot find model "{}"' .format (name ), ctx .context )
327- ctx .api .note ('Only imported models can be found; use "if TYPE_CHECKING: ..." to avoid import cycles' ,
336+ # TODO: Add note() to public API.
337+ ctx .api .note ('Only imported models can be found;' # type: ignore
338+ ' use "if TYPE_CHECKING: ..." to avoid import cycles' ,
328339 ctx .context )
329340 new_arg = AnyType (TypeOfAny .from_error )
330341 else :
@@ -359,5 +370,5 @@ def parse_bool(expr: Expression) -> Optional[bool]:
359370 return None
360371
361372
362- def plugin (version ) :
373+ def plugin (version : str ) -> TypingType [ Plugin ] :
363374 return BasicSQLAlchemyPlugin
0 commit comments