55from promise import Promise , is_thenable
66from sqlalchemy .orm .query import Query
77
8+ from graphene import NonNull
89from graphene .relay import Connection , ConnectionField
910from graphene .relay .connection import PageInfo
1011from graphql_relay .connection .arrayconnection import connection_from_list_slice
@@ -19,19 +20,26 @@ def type(self):
1920 from .types import SQLAlchemyObjectType
2021
2122 _type = super (ConnectionField , self ).type
22- if issubclass (_type , Connection ):
23+ nullable_type = get_nullable_type (_type )
24+ if issubclass (nullable_type , Connection ):
2325 return _type
24- assert issubclass (_type , SQLAlchemyObjectType ), (
26+ assert issubclass (nullable_type , SQLAlchemyObjectType ), (
2527 "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
26- ).format (_type .__name__ )
27- assert _type .connection , "The type {} doesn't have a connection" .format (
28- _type .__name__
28+ ).format (nullable_type .__name__ )
29+ assert (
30+ nullable_type .connection
31+ ), "The type {} doesn't have a connection" .format (
32+ nullable_type .__name__
2933 )
30- return _type .connection
34+ assert _type == nullable_type , (
35+ "Passing a SQLAlchemyObjectType instance is deprecated. "
36+ "Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
37+ )
38+ return nullable_type .connection
3139
3240 @property
3341 def model (self ):
34- return self .type ._meta .node ._meta .model
42+ return get_nullable_type ( self .type ) ._meta .node ._meta .model
3543
3644 @classmethod
3745 def get_query (cls , model , info , ** args ):
@@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
7078 return on_resolve (resolved )
7179
7280 def get_resolver (self , parent_resolver ):
73- return partial (self .connection_resolver , parent_resolver , self .type , self .model )
81+ return partial (
82+ self .connection_resolver ,
83+ parent_resolver ,
84+ get_nullable_type (self .type ),
85+ self .model ,
86+ )
7487
7588
7689# TODO Rename this to SortableSQLAlchemyConnectionField
7790class SQLAlchemyConnectionField (UnsortedSQLAlchemyConnectionField ):
7891 def __init__ (self , type , * args , ** kwargs ):
79- if "sort" not in kwargs and issubclass (type , Connection ):
92+ nullable_type = get_nullable_type (type )
93+ if "sort" not in kwargs and issubclass (nullable_type , Connection ):
8094 # Let super class raise if type is not a Connection
8195 try :
82- kwargs .setdefault ("sort" , type .Edge .node ._type .sort_argument ())
96+ kwargs .setdefault ("sort" , nullable_type .Edge .node ._type .sort_argument ())
8397 except (AttributeError , TypeError ):
8498 raise TypeError (
8599 'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
86100 " to None to disabling the creation of the sort query argument" .format (
87- type .__name__
101+ nullable_type .__name__
88102 )
89103 )
90104 elif "sort" in kwargs and kwargs ["sort" ] is None :
@@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
108122 The API and behavior may change in future versions.
109123 Use at your own risk.
110124 """
125+
111126 def get_resolver (self , parent_resolver ):
112- return partial (self .connection_resolver , self .resolver , self .type , self .model )
127+ return partial (
128+ self .connection_resolver ,
129+ self .resolver ,
130+ get_nullable_type (self .type ),
131+ self .model ,
132+ )
113133
114134 @classmethod
115135 def from_relationship (cls , relationship , registry , ** field_kwargs ):
@@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory():
155175 )
156176 global __connectionFactory
157177 __connectionFactory = UnsortedSQLAlchemyConnectionField
178+
179+
180+ def get_nullable_type (_type ):
181+ if isinstance (_type , NonNull ):
182+ return _type .of_type
183+ return _type
0 commit comments