Skip to content

Commit 0e19564

Browse files
author
Mehdi
committed
Try to guess if a relationship is iterable
1 parent 3a9f71d commit 0e19564

File tree

1 file changed

+56
-1
lines changed

1 file changed

+56
-1
lines changed

sqlmypy.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
112112

113113
# We currently only handle setting __tablename__ as a class attribute, and not through a property.
114114
if stmt.lvalues[0].name == "__tablename__" and isinstance(stmt.rvalue, StrExpr):
115-
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['tablename'] = stmt.rvalue.value
115+
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['table_name'] = stmt.rvalue.value
116116

117117
if isinstance(stmt.rvalue, CallExpr) and stmt.rvalue.callee.fullname == COLUMN_NAME:
118118
# Save columns. The name of a column on the db side can be different from the one inside the SA model.
@@ -339,6 +339,54 @@ def column_hook(ctx: FunctionContext) -> Type:
339339
column=ctx.default_return_type.column)
340340

341341

342+
class IncompleteModelMetadata(Exception):
343+
pass
344+
345+
346+
def has_foreign_keys(local_model: TypeInfo, remote_model: TypeInfo) -> bool:
347+
"""Tells if `local_model` has a fk to `remote_model`.
348+
Will raise an `IncompleteModelMetadata` if some mandatory metadata is missing.
349+
"""
350+
local_metadata = local_model.metadata.get("sqlalchemy", {})
351+
remote_metadata = remote_model.metadata.get("sqlalchemy", {})
352+
353+
for fk in local_metadata.get("foreign_keys", {}).values():
354+
if 'model_fullname' in fk and remote_model.fullname == fk['model_fullname']:
355+
return True
356+
if 'table_name' in fk:
357+
if 'table_name' not in remote_metadata:
358+
raise IncompleteModelMetadata
359+
# TODO: handle different schemas
360+
if remote_metadata['table_name'] == fk['table_name']:
361+
return True
362+
363+
return False
364+
365+
366+
def is_relationship_iterable(ctx: FunctionContext, local_model: TypeInfo, remote_model: TypeInfo) -> bool:
367+
"""Tries to guess if the relationship is onetoone/onetomany/manytoone.
368+
369+
Currently we handle the most current case, where a model relates to the other one through a relationship.
370+
We also handle cases where secondaryjoin argument is provided.
371+
We don't handle advanced usecases (foreign keys on both sides, primaryjoin, etc.).
372+
"""
373+
secondaryjoin = get_argument_by_name(ctx, 'secondaryjoin')
374+
375+
if secondaryjoin is not None:
376+
return True
377+
378+
try:
379+
can_be_many_to_one = has_foreign_keys(local_model, remote_model)
380+
can_be_one_to_many = has_foreign_keys(remote_model, local_model)
381+
382+
if not can_be_many_to_one and can_be_one_to_many:
383+
return True
384+
except IncompleteModelMetadata:
385+
pass
386+
387+
return False # Assume relationship is not iterable, if we weren't able to guess better.
388+
389+
342390
def relationship_hook(ctx: FunctionContext) -> Type:
343391
"""Support basic use cases for relationships.
344392
@@ -391,10 +439,17 @@ class User(Base):
391439
# Something complex, stay silent for now.
392440
new_arg = AnyType(TypeOfAny.special_form)
393441

442+
current_model = ctx.api.scope.active_class()
443+
assert current_model is not None
444+
445+
# TODO: handle backref relationships
446+
394447
# We figured out, the model type. Now check if we need to wrap it in Iterable
395448
if uselist_arg:
396449
if parse_bool(uselist_arg):
397450
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
451+
elif not isinstance(new_arg, AnyType) and is_relationship_iterable(ctx, current_model, new_arg.type):
452+
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
398453
else:
399454
if has_annotation:
400455
# If there is an annotation we use it as a source of truth.

0 commit comments

Comments
 (0)