Skip to content

Commit ee7baf3

Browse files
author
Mehdi
committed
Try to guess if a relationship is iterable
1 parent c268644 commit ee7baf3

File tree

1 file changed

+56
-1
lines changed

1 file changed

+56
-1
lines changed

sqlmypy.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
117117

118118
# We currently only handle setting __tablename__ as a class attribute, and not through a property.
119119
if stmt.lvalues[0].name == "__tablename__" and isinstance(stmt.rvalue, StrExpr):
120-
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['tablename'] = stmt.rvalue.value
120+
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['table_name'] = stmt.rvalue.value
121121

122122
if isinstance(stmt.rvalue, CallExpr) and stmt.rvalue.callee.fullname == COLUMN_NAME:
123123
# Save columns. The name of a column on the db side can be different from the one inside the SA model.
@@ -363,6 +363,54 @@ def grouping_hook(ctx: FunctionContext) -> Type:
363363
return ctx.default_return_type
364364

365365

366+
class IncompleteModelMetadata(Exception):
367+
pass
368+
369+
370+
def has_foreign_keys(local_model: TypeInfo, remote_model: TypeInfo) -> bool:
371+
"""Tells if `local_model` has a fk to `remote_model`.
372+
Will raise an `IncompleteModelMetadata` if some mandatory metadata is missing.
373+
"""
374+
local_metadata = local_model.metadata.get("sqlalchemy", {})
375+
remote_metadata = remote_model.metadata.get("sqlalchemy", {})
376+
377+
for fk in local_metadata.get("foreign_keys", {}).values():
378+
if 'model_fullname' in fk and remote_model.fullname == fk['model_fullname']:
379+
return True
380+
if 'table_name' in fk:
381+
if 'table_name' not in remote_metadata:
382+
raise IncompleteModelMetadata
383+
# TODO: handle different schemas
384+
if remote_metadata['table_name'] == fk['table_name']:
385+
return True
386+
387+
return False
388+
389+
390+
def is_relationship_iterable(ctx: FunctionContext, local_model: TypeInfo, remote_model: TypeInfo) -> bool:
391+
"""Tries to guess if the relationship is onetoone/onetomany/manytoone.
392+
393+
Currently we handle the most current case, where a model relates to the other one through a relationship.
394+
We also handle cases where secondaryjoin argument is provided.
395+
We don't handle advanced usecases (foreign keys on both sides, primaryjoin, etc.).
396+
"""
397+
secondaryjoin = get_argument_by_name(ctx, 'secondaryjoin')
398+
399+
if secondaryjoin is not None:
400+
return True
401+
402+
try:
403+
can_be_many_to_one = has_foreign_keys(local_model, remote_model)
404+
can_be_one_to_many = has_foreign_keys(remote_model, local_model)
405+
406+
if not can_be_many_to_one and can_be_one_to_many:
407+
return True
408+
except IncompleteModelMetadata:
409+
pass
410+
411+
return False # Assume relationship is not iterable, if we weren't able to guess better.
412+
413+
366414
def relationship_hook(ctx: FunctionContext) -> Type:
367415
"""Support basic use cases for relationships.
368416
@@ -415,10 +463,17 @@ class User(Base):
415463
# Something complex, stay silent for now.
416464
new_arg = AnyType(TypeOfAny.special_form)
417465

466+
current_model = ctx.api.scope.active_class()
467+
assert current_model is not None
468+
469+
# TODO: handle backref relationships
470+
418471
# We figured out, the model type. Now check if we need to wrap it in Iterable
419472
if uselist_arg:
420473
if parse_bool(uselist_arg):
421474
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
475+
elif not isinstance(new_arg, AnyType) and is_relationship_iterable(ctx, current_model, new_arg.type):
476+
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
422477
else:
423478
if has_annotation:
424479
# If there is an annotation we use it as a source of truth.

0 commit comments

Comments
 (0)