@@ -151,6 +151,8 @@ def __init__(
151151 extra_sqlalchemy_type_to_strawberry_type_map : Optional [
152152 Mapping [Type [TypeEngine ], Type [Any ]]
153153 ] = None ,
154+ edge_type : Type = None ,
155+ connection_type : Type = None ,
154156 ) -> None :
155157 if model_to_type_name is None :
156158 model_to_type_name = self ._default_model_to_type_name
@@ -172,6 +174,9 @@ def __init__(
172174 self ._related_type_models = set ()
173175 self ._related_interface_models = set ()
174176
177+ self .edge_type = edge_type
178+ self .connection_type = connection_type
179+
175180 @staticmethod
176181 def _default_model_to_type_name (model : Type [BaseModelType ]) -> str :
177182 return model .__name__
@@ -211,6 +216,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
211216 Get or create a corresponding Edge model for the given type
212217 (to support future pagination)
213218 """
219+ if self .edge_type is not None :
220+ return self .edge_type
214221 edge_name = f"{ type_name } Edge"
215222 if edge_name not in self .edge_types :
216223 self .edge_types [edge_name ] = edge_type = strawberry .type (
@@ -229,6 +236,8 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
229236 Get or create a corresponding Connection model for the given type
230237 (to support future pagination)
231238 """
239+ if self .connection_type is not None :
240+ return self .connection_type [ForwardRef (type_name )]
232241 connection_name = f"{ type_name } Connection"
233242 if connection_name not in self .connection_types :
234243 self .connection_types [connection_name ] = connection_type = strawberry .type (
@@ -259,7 +268,7 @@ def _convert_column_to_strawberry_type(
259268 corresponding strawberry type.
260269 """
261270 if isinstance (column .type , Enum ):
262- type_annotation = column .type .python_type
271+ type_annotation = strawberry . enum ( column .type .python_type )
263272 elif isinstance (column .type , ARRAY ):
264273 item_type = self ._convert_column_to_strawberry_type (
265274 Column (column .type .item_type , nullable = False )
@@ -404,7 +413,11 @@ async def resolve(self, info: Info):
404413 else :
405414 relationship_key = tuple (
406415 [
407- getattr (self , local .key )
416+ [
417+ getattr (self , k )
418+ for k , column in self .__mapper__ .c .items ()
419+ if local .key == column .key
420+ ][0 ]
408421 for local , _ in relationship .local_remote_pairs
409422 ]
410423 )
@@ -539,6 +552,7 @@ def type(
539552 model : Type [BaseModelType ],
540553 make_interface = False ,
541554 use_federation = False ,
555+ ** kwargs ,
542556 ) -> Callable [[Type [object ]], Any ]:
543557 """
544558 Decorate a type with this to register it as a strawberry type
@@ -560,128 +574,158 @@ class Employee:
560574 ```
561575 """
562576
563- def convert (type_ : Any ) -> Any :
564- old_annotations = getattr (type_ , "__annotations__" , {})
565- type_ .__annotations__ = {}
566- mapper : Mapper = inspect (model )
567- generated_field_keys = []
568-
569- excluded_keys = getattr (type_ , "__exclude__" , [])
570-
571- # if the type inherits from another mapped type, then it may have
572- # generated resolvers. These will be treated by dataclasses as having
573- # a default value, which will likely cause issues because of keys
574- # that don't have default values. To fix this, we wrap them in
575- # `strawberry.field()` (like when they were originally made), so
576- # dataclasses will ignore them.
577- # TODO: Potentially raise/fix this issue upstream
578- for key in dir (type_ ):
579- val = getattr (type_ , key )
580- if getattr (val , _IS_GENERATED_RESOLVER_KEY , False ):
581- setattr (type_ , key , strawberry .field (resolver = val ))
582- generated_field_keys .append (key )
583-
584- self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
585- for key , relationship in mapper .relationships .items ():
586- relationship : RelationshipProperty
587- if (
588- key in excluded_keys
589- or key in type_ .__annotations__
590- or hasattr (type_ , key )
591- ):
592- continue
593- strawberry_type = self ._convert_relationship_to_strawberry_type (
594- relationship
595- )
596- self ._add_annotation (
597- type_ ,
598- key ,
577+ def do_conversion (type_ ):
578+ return self .convert (
579+ type_ ,
580+ model ,
581+ make_interface ,
582+ use_federation ,
583+ )
584+
585+ return do_conversion
586+
587+ def convert (
588+ self ,
589+ type_ : Any ,
590+ model : Type [BaseModelType ],
591+ make_interface = False ,
592+ use_federation = False ,
593+ ) -> Any :
594+ """
595+ Do type conversion. Usually accessed using typical .type decorator. But
596+ can also be used as standalone function.
597+ """
598+ old_annotations = getattr (type_ , "__annotations__" , {})
599+ type_ .__annotations__ = {}
600+ mapper : Mapper = inspect (model )
601+ generated_field_keys = []
602+
603+ excluded_keys = getattr (type_ , "__exclude__" , [])
604+
605+ # if the type inherits from another mapped type, then it may have
606+ # generated resolvers. These will be treated by dataclasses as having
607+ # a default value, which will likely cause issues because of keys
608+ # that don't have default values. To fix this, we wrap them in
609+ # `strawberry.field()` (like when they were originally made), so
610+ # dataclasses will ignore them.
611+ # TODO: Potentially raise/fix this issue upstream
612+ for key in dir (type_ ):
613+ val = getattr (type_ , key )
614+ if getattr (val , _IS_GENERATED_RESOLVER_KEY , False ):
615+ setattr (type_ , key , strawberry .field (resolver = val ))
616+ generated_field_keys .append (key )
617+
618+ self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
619+ for key , relationship in mapper .relationships .items ():
620+ relationship : RelationshipProperty
621+ if (
622+ key in excluded_keys
623+ or key in type_ .__annotations__
624+ or hasattr (type_ , key )
625+ ):
626+ continue
627+ strawberry_type = self ._convert_relationship_to_strawberry_type (
628+ relationship
629+ )
630+ self ._add_annotation (
631+ type_ ,
632+ key ,
633+ strawberry_type ,
634+ generated_field_keys ,
635+ )
636+ field = strawberry .field (
637+ resolver = self .connection_resolver_for (relationship )
638+ )
639+ assert not field .init
640+ setattr (
641+ type_ ,
642+ key ,
643+ field ,
644+ )
645+ for key , descriptor in mapper .all_orm_descriptors .items ():
646+ if (
647+ key in excluded_keys
648+ or key in type_ .__annotations__
649+ or hasattr (type_ , key )
650+ ):
651+ continue
652+ if key in mapper .columns or key in mapper .relationships :
653+ continue
654+ if key in model .__annotations__ :
655+ annotation = eval (model .__annotations__ [key ])
656+ for (
657+ sqlalchemy_type ,
599658 strawberry_type ,
600- generated_field_keys ,
659+ ) in self .sqlalchemy_type_to_strawberry_type_map .items ():
660+ if isinstance (annotation , sqlalchemy_type ):
661+ self ._add_annotation (
662+ type_ , key , strawberry_type , generated_field_keys
663+ )
664+ break
665+ elif isinstance (descriptor , AssociationProxy ):
666+ strawberry_type = self ._get_association_proxy_annotation (
667+ mapper , key , descriptor
601668 )
669+ if strawberry_type is SkipTypeSentinel :
670+ continue
671+ self ._add_annotation (type_ , key , strawberry_type , generated_field_keys )
602672 field = strawberry .field (
603- resolver = self .connection_resolver_for (relationship )
673+ resolver = self .association_proxy_resolver_for (
674+ mapper , descriptor , strawberry_type
675+ )
604676 )
605677 assert not field .init
606- setattr (
678+ setattr (type_ , key , field )
679+ elif isinstance (descriptor , hybrid_property ):
680+ if (
681+ not hasattr (descriptor , "__annotations__" )
682+ or "return" not in descriptor .__annotations__
683+ ):
684+ raise HybridPropertyNotAnnotated (key )
685+ annotation = descriptor .__annotations__ ["return" ]
686+ if isinstance (annotation , str ):
687+ try :
688+ if "typing" in annotation :
689+ # Try to evaluate from existing typing imports
690+ annotation = annotation [7 :]
691+ annotation = eval (annotation )
692+ except NameError :
693+ raise UnsupportedDescriptorType (key )
694+ self ._add_annotation (
607695 type_ ,
608696 key ,
609- field ,
697+ annotation ,
698+ generated_field_keys ,
610699 )
611- for key , descriptor in mapper .all_orm_descriptors .items ():
612- if (
613- key in excluded_keys
614- or key in type_ .__annotations__
615- or hasattr (type_ , key )
616- ):
617- continue
618- if key in mapper .columns or key in mapper .relationships :
619- continue
620- if isinstance (descriptor , AssociationProxy ):
621- strawberry_type = self ._get_association_proxy_annotation (
622- mapper , key , descriptor
623- )
624- if strawberry_type is SkipTypeSentinel :
625- continue
626- self ._add_annotation (
627- type_ , key , strawberry_type , generated_field_keys
628- )
629- field = strawberry .field (
630- resolver = self .association_proxy_resolver_for (
631- mapper , descriptor , strawberry_type
632- )
633- )
634- assert not field .init
635- setattr (type_ , key , field )
636- elif isinstance (descriptor , hybrid_property ):
637- if (
638- not hasattr (descriptor , "__annotations__" )
639- or "return" not in descriptor .__annotations__
640- ):
641- raise HybridPropertyNotAnnotated (key )
642- annotation = descriptor .__annotations__ ["return" ]
643- if isinstance (annotation , str ):
644- try :
645- if "typing" in annotation :
646- # Try to evaluate from existing typing imports
647- annotation = annotation [7 :]
648- annotation = eval (annotation )
649- except NameError :
650- raise UnsupportedDescriptorType (key )
651- self ._add_annotation (
652- type_ ,
653- key ,
654- annotation ,
655- generated_field_keys ,
656- )
657- else :
658- raise UnsupportedDescriptorType (key )
700+ else :
701+ raise UnsupportedDescriptorType (key )
659702
660- # ignore inherited `is_type_of`
661- if "is_type_of" not in type_ .__dict__ :
662- type_ .is_type_of = (
663- lambda obj , info : type (obj ) == model or type (obj ) == type_
664- )
703+ # ignore inherited `is_type_of`
704+ if "is_type_of" not in type_ .__dict__ :
705+ type_ .is_type_of = (
706+ lambda obj , info : type (obj ) == model or type (obj ) == type_
707+ )
665708
666- # need to make fields that are already in the type
667- # (prior to mapping) appear *after* the mapped fields
668- # because the pre-existing fields might have default values,
669- # which will cause the mapped fields to fail
670- # (because they may not have default values)
671- type_ .__annotations__ .update (old_annotations )
672-
673- if make_interface :
674- mapped_type = strawberry .interface (type_ )
675- elif use_federation :
676- mapped_type = strawberry .federation .type (type_ )
677- else :
678- mapped_type = strawberry .type (type_ )
679- self .mapped_types [type_ .__name__ ] = mapped_type
680- setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
681- setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
682- return mapped_type
709+ # need to make fields that are already in the type
710+ # (prior to mapping) appear *after* the mapped fields
711+ # because the pre-existing fields might have default values,
712+ # which will cause the mapped fields to fail
713+ # (because they may not have default values)
714+ type_ .__annotations__ .update (old_annotations )
683715
684- return convert
716+ if make_interface :
717+ type_name = self .model_to_interface_name (type_ )
718+ mapped_type = strawberry .interface (type_ , name = type_name )
719+ else :
720+ type_name = self .model_to_type_name (type_ )
721+ if use_federation :
722+ mapped_type = strawberry .federation .type (type_ , name = type_name )
723+ else :
724+ mapped_type = strawberry .type (type_ , name = type_name )
725+ self .mapped_types [type_name ] = mapped_type
726+ setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
727+ setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
728+ return mapped_type
685729
686730 def interface (self , model : Type [BaseModelType ]) -> Callable [[Type [object ]], Any ]:
687731 """
0 commit comments