@@ -112,7 +112,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
112112
113113 # We currently only handle setting __tablename__ as a class attribute, and not through a property.
114114 if stmt .lvalues [0 ].name == "__tablename__" and isinstance (stmt .rvalue , StrExpr ):
115- ctx .cls .info .metadata .setdefault ('sqlalchemy' , {})['tablename ' ] = stmt .rvalue .value
115+ ctx .cls .info .metadata .setdefault ('sqlalchemy' , {})['table_name ' ] = stmt .rvalue .value
116116
117117 if isinstance (stmt .rvalue , CallExpr ) and stmt .rvalue .callee .fullname == COLUMN_NAME :
118118 # Save columns. The name of a column on the db side can be different from the one inside the SA model.
@@ -339,6 +339,54 @@ def column_hook(ctx: FunctionContext) -> Type:
339339 column = ctx .default_return_type .column )
340340
341341
342+ class IncompleteModelMetadata (Exception ):
343+ pass
344+
345+
346+ def has_foreign_keys (local_model : TypeInfo , remote_model : TypeInfo ) -> bool :
347+ """Tells if `local_model` has a fk to `remote_model`.
348+ Will raise an `IncompleteModelMetadata` if some mandatory metadata is missing.
349+ """
350+ local_metadata = local_model .metadata .get ("sqlalchemy" , {})
351+ remote_metadata = remote_model .metadata .get ("sqlalchemy" , {})
352+
353+ for fk in local_metadata .get ("foreign_keys" , {}).values ():
354+ if 'model_fullname' in fk and remote_model .fullname == fk ['model_fullname' ]:
355+ return True
356+ if 'table_name' in fk :
357+ if 'table_name' not in remote_metadata :
358+ raise IncompleteModelMetadata
359+ # TODO: handle different schemas
360+ if remote_metadata ['table_name' ] == fk ['table_name' ]:
361+ return True
362+
363+ return False
364+
365+
366+ def is_relationship_iterable (ctx : FunctionContext , local_model : TypeInfo , remote_model : TypeInfo ) -> bool :
367+ """Tries to guess if the relationship is onetoone/onetomany/manytoone.
368+
369+ Currently we handle the most current case, where a model relates to the other one through a relationship.
370+ We also handle cases where secondaryjoin argument is provided.
371+ We don't handle advanced usecases (foreign keys on both sides, primaryjoin, etc.).
372+ """
373+ secondaryjoin = get_argument_by_name (ctx , 'secondaryjoin' )
374+
375+ if secondaryjoin is not None :
376+ return True
377+
378+ try :
379+ can_be_many_to_one = has_foreign_keys (local_model , remote_model )
380+ can_be_one_to_many = has_foreign_keys (remote_model , local_model )
381+
382+ if not can_be_many_to_one and can_be_one_to_many :
383+ return True
384+ except IncompleteModelMetadata :
385+ pass
386+
387+ return False # Assume relationship is not iterable, if we weren't able to guess better.
388+
389+
342390def relationship_hook (ctx : FunctionContext ) -> Type :
343391 """Support basic use cases for relationships.
344392
@@ -391,10 +439,17 @@ class User(Base):
391439 # Something complex, stay silent for now.
392440 new_arg = AnyType (TypeOfAny .special_form )
393441
442+ current_model = ctx .api .scope .active_class ()
443+ assert current_model is not None
444+
445+ # TODO: handle backref relationships
446+
394447 # We figured out, the model type. Now check if we need to wrap it in Iterable
395448 if uselist_arg :
396449 if parse_bool (uselist_arg ):
397450 new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
451+ elif not isinstance (new_arg , AnyType ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
452+ new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
398453 else :
399454 if has_annotation :
400455 # If there is an annotation we use it as a source of truth.
0 commit comments