@@ -413,14 +413,12 @@ async def resolve(self, info: Info):
413413 return []
414414 else :
415415 return None
416- if isinstance (info .context , dict ):
416+ if isinstance (info .context , dict ):
417417 loader = info .context ["sqlalchemy_loader" ]
418418 else :
419419 loader = info .context .sqlalchemy_loader
420- related_objects = (
421- await loader
422- .loader_for (relationship )
423- .load (relationship_key )
420+ related_objects = await loader .loader_for (relationship ).load (
421+ relationship_key
424422 )
425423 return related_objects
426424
@@ -537,7 +535,10 @@ def _handle_columns(
537535 )
538536
539537 def type (
540- self , model : Type [BaseModelType ], make_interface = False
538+ self ,
539+ model : Type [BaseModelType ],
540+ make_interface = False ,
541+ use_federation = False ,
541542 ) -> Callable [[Type [object ]], Any ]:
542543 """
543544 Decorate a type with this to register it as a strawberry type
@@ -662,10 +663,11 @@ def convert(type_: Any) -> Any:
662663
663664 if make_interface :
664665 mapped_type = strawberry .interface (type_ )
665- self .mapped_interfaces [type_ .__name__ ] = mapped_type
666+ elif use_federation :
667+ mapped_type = strawberry .federation .type (type_ )
666668 else :
667669 mapped_type = strawberry .type (type_ )
668- self .mapped_types [type_ .__name__ ] = mapped_type
670+ self .mapped_types [type_ .__name__ ] = mapped_type
669671 setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
670672 setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
671673 return mapped_type
0 commit comments