@@ -110,26 +110,45 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
110110 if not (isinstance (stmt , AssignmentStmt ) and len (stmt .lvalues ) == 1 and isinstance (stmt .lvalues [0 ], NameExpr )):
111111 continue
112112
113+ # We currently only handle setting __tablename__ as a class attribute, and not through a property.
113114 if stmt .lvalues [0 ].name == "__tablename__" and isinstance (stmt .rvalue , StrExpr ):
114115 ctx .cls .info .metadata .setdefault ('sqlalchemy' , {})['tablename' ] = stmt .rvalue .value
115116
116117 if isinstance (stmt .rvalue , CallExpr ) and stmt .rvalue .callee .fullname == COLUMN_NAME :
117- colname = stmt .lvalues [0 ].name
118- has_explicit_colname = stmt .rvalue
119- ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('columns' , []).append (colname )
118+ # Save columns. The name of a column on the db side can be different from the one inside the SA model.
119+ sa_column_name = stmt .lvalues [0 ].name
120+
121+ db_column_name = None # type: Optional[str]
122+ if 'name' in stmt .rvalue .arg_names :
123+ name_str_expr = stmt .rvalue .args [stmt .rvalue .arg_names .index ('name' )]
124+ assert isinstance (name_str_expr , StrExpr )
125+ db_column_name = name_str_expr .value
126+ else :
127+ if len (stmt .rvalue .args ) >= 1 and isinstance (stmt .rvalue .args [0 ], StrExpr ):
128+ db_column_name = stmt .rvalue .args [0 ].value
129+
130+ ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('columns' , []).append (
131+ {"sa_name" : sa_column_name , "db_name" : db_column_name or sa_column_name }
132+ )
133+
134+ # Save foreign keys.
120135 for arg in stmt .rvalue .args :
121136 if isinstance (arg , CallExpr ) and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 :
122137 fk = arg .args [0 ]
123138 if isinstance (fk , StrExpr ):
124- * _ , parent_table , parent_col = fk .value .split ("." )
125- ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' , {})[colname ] = {
126- "column" : parent_col ,
127- "table" : parent_table
139+ * r , parent_table_name , parent_db_col_name = fk .value .split ("." )
140+ assert len (r ) <= 1
141+ ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' ,
142+ {})[sa_column_name ] = {
143+ "db_name" : parent_db_col_name ,
144+ "table_name" : parent_table_name ,
145+ "schema" : r [0 ] if r else None
128146 }
129147 elif isinstance (fk , MemberExpr ):
130- ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' , {})[colname ] = {
131- "column" : fk .name ,
132- "model" : fk .expr .fullname
148+ ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' ,
149+ {})[sa_column_name ] = {
150+ "sa_name" : fk .name ,
151+ "model_fullname" : fk .expr .fullname
133152 }
134153
135154 # Also add a selection of auto-generated attributes.
@@ -390,10 +409,10 @@ class User(Base):
390409# We really need to add this to TypeChecker API
391410def parse_bool (expr : Expression ) -> Optional [bool ]:
392411 if isinstance (expr , NameExpr ):
393- if expr .fullname == 'builtins.True' :
394- return True
395- if expr .fullname == 'builtins.False' :
396- return False
412+ if expr .fullname == 'builtins.True' :
413+ return True
414+ if expr .fullname == 'builtins.False' :
415+ return False
397416 return None
398417
399418
0 commit comments