@@ -61,17 +61,17 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
6161 return model_hook
6262 return None
6363
64- def get_dynamic_class_hook (self , fullname : str ):
64+ def get_dynamic_class_hook (self , fullname : str ) -> CB [ DynamicClassDefContext ] :
6565 if fullname == 'sqlalchemy.ext.declarative.api.declarative_base' :
6666 return decl_info_hook
6767 return None
6868
69- def get_class_decorator_hook (self , fullname : str ):
69+ def get_class_decorator_hook (self , fullname : str ) -> CB [ ClassDefContext ] :
7070 if fullname == 'sqlalchemy.ext.declarative.api.as_declarative' :
7171 return decl_deco_hook
7272 return None
7373
74- def get_base_class_hook (self , fullname : str ):
74+ def get_base_class_hook (self , fullname : str ) -> CB [ ClassDefContext ] :
7575 sym = self .lookup_fully_qualified (fullname )
7676 if sym and isinstance (sym .node , TypeInfo ):
7777 if is_declarative (sym .node ):
@@ -114,7 +114,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
114114 if stmt .lvalues [0 ].name == "__tablename__" and isinstance (stmt .rvalue , StrExpr ):
115115 ctx .cls .info .metadata .setdefault ('sqlalchemy' , {})['table_name' ] = stmt .rvalue .value
116116
117- if isinstance (stmt .rvalue , CallExpr ) and stmt .rvalue .callee .fullname == COLUMN_NAME :
117+ if (isinstance (stmt .rvalue , CallExpr ) and isinstance (stmt .rvalue .callee , NameExpr )
118+ and stmt .rvalue .callee .fullname == COLUMN_NAME ):
118119 # Save columns. The name of a column on the db side can be different from the one inside the SA model.
119120 sa_column_name = stmt .lvalues [0 ].name
120121
@@ -133,7 +134,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
133134
134135 # Save foreign keys.
135136 for arg in stmt .rvalue .args :
136- if isinstance (arg , CallExpr ) and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 :
137+ if (isinstance (arg , CallExpr ) and isinstance (arg .callee , NameExpr )
138+ and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 ):
137139 fk = arg .args [0 ]
138140 if isinstance (fk , StrExpr ):
139141 * r , parent_table_name , parent_db_col_name = fk .value .split ("." )
@@ -144,7 +146,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
144146 "table_name" : parent_table_name ,
145147 "schema" : r [0 ] if r else None
146148 }
147- elif isinstance (fk , MemberExpr ):
149+ elif isinstance (fk , MemberExpr ) and isinstance ( fk . expr , NameExpr ) :
148150 ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' ,
149151 {})[sa_column_name ] = {
150152 "sa_name" : fk .name ,
@@ -439,7 +441,8 @@ class User(Base):
439441 # Something complex, stay silent for now.
440442 new_arg = AnyType (TypeOfAny .special_form )
441443
442- current_model = ctx .api .scope .active_class ()
444+ # use private api
445+ current_model = ctx .api .scope .active_class () # type: ignore # type: TypeInfo
443446 assert current_model is not None
444447
445448 # TODO: handle backref relationships
@@ -448,7 +451,7 @@ class User(Base):
448451 if uselist_arg :
449452 if parse_bool (uselist_arg ):
450453 new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
451- elif not isinstance (new_arg , AnyType ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
454+ elif isinstance (new_arg , Instance ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
452455 new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
453456 else :
454457 if has_annotation :
0 commit comments