@@ -66,17 +66,17 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
6666 return model_hook
6767 return None
6868
69- def get_dynamic_class_hook (self , fullname : str ):
69+ def get_dynamic_class_hook (self , fullname : str ) -> CB [ DynamicClassDefContext ] :
7070 if fullname == 'sqlalchemy.ext.declarative.api.declarative_base' :
7171 return decl_info_hook
7272 return None
7373
74- def get_class_decorator_hook (self , fullname : str ):
74+ def get_class_decorator_hook (self , fullname : str ) -> CB [ ClassDefContext ] :
7575 if fullname == 'sqlalchemy.ext.declarative.api.as_declarative' :
7676 return decl_deco_hook
7777 return None
7878
79- def get_base_class_hook (self , fullname : str ):
79+ def get_base_class_hook (self , fullname : str ) -> CB [ ClassDefContext ] :
8080 sym = self .lookup_fully_qualified (fullname )
8181 if sym and isinstance (sym .node , TypeInfo ):
8282 if is_declarative (sym .node ):
@@ -119,7 +119,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
119119 if stmt .lvalues [0 ].name == "__tablename__" and isinstance (stmt .rvalue , StrExpr ):
120120 ctx .cls .info .metadata .setdefault ('sqlalchemy' , {})['table_name' ] = stmt .rvalue .value
121121
122- if isinstance (stmt .rvalue , CallExpr ) and stmt .rvalue .callee .fullname == COLUMN_NAME :
122+ if (isinstance (stmt .rvalue , CallExpr ) and isinstance (stmt .rvalue .callee , NameExpr )
123+ and stmt .rvalue .callee .fullname == COLUMN_NAME ):
123124 # Save columns. The name of a column on the db side can be different from the one inside the SA model.
124125 sa_column_name = stmt .lvalues [0 ].name
125126
@@ -138,7 +139,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
138139
139140 # Save foreign keys.
140141 for arg in stmt .rvalue .args :
141- if isinstance (arg , CallExpr ) and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 :
142+ if (isinstance (arg , CallExpr ) and isinstance (arg .callee , NameExpr )
143+ and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 ):
142144 fk = arg .args [0 ]
143145 if isinstance (fk , StrExpr ):
144146 * r , parent_table_name , parent_db_col_name = fk .value .split ("." )
@@ -149,7 +151,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
149151 "table_name" : parent_table_name ,
150152 "schema" : r [0 ] if r else None
151153 }
152- elif isinstance (fk , MemberExpr ):
154+ elif isinstance (fk , MemberExpr ) and isinstance ( fk . expr , NameExpr ) :
153155 ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' ,
154156 {})[sa_column_name ] = {
155157 "sa_name" : fk .name ,
@@ -463,7 +465,8 @@ class User(Base):
463465 # Something complex, stay silent for now.
464466 new_arg = AnyType (TypeOfAny .special_form )
465467
466- current_model = ctx .api .scope .active_class ()
468+ # use private api
469+ current_model = ctx .api .scope .active_class () # type: ignore # type: TypeInfo
467470 assert current_model is not None
468471
469472 # TODO: handle backref relationships
@@ -472,7 +475,7 @@ class User(Base):
472475 if uselist_arg :
473476 if parse_bool (uselist_arg ):
474477 new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
475- elif not isinstance (new_arg , AnyType ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
478+ elif isinstance (new_arg , Instance ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
476479 new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
477480 else :
478481 if has_annotation :
0 commit comments