@@ -36,7 +36,6 @@ class User:
3636"""
3737
3838import collections .abc
39- import copy
4039import dataclasses
4140import inspect
4241import sys
@@ -64,6 +63,12 @@ class User:
6463import typing_extensions
6564import typing_inspect
6665
66+ from marshmallow_dataclass .generic_resolver import (
67+ UnboundTypeVarError ,
68+ get_generic_dataclass_fields ,
69+ is_generic_alias ,
70+ is_generic_type ,
71+ )
6772from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
6873
6974if sys .version_info >= (3 , 9 ):
@@ -134,55 +139,10 @@ def _maybe_get_callers_frame(
134139 del frame
135140
136141
137- class UnboundTypeVarError (TypeError ):
138- """TypeVar instance can not be resolved to a type spec.
139-
140- This exception is raised when an unbound TypeVar is encountered.
141-
142- """
143-
144-
145- class InvalidStateError (Exception ):
146- """Raised when an operation is performed on a future that is not
147- allowed in the current state.
148- """
149-
150-
151- class _Future (Generic [_U ]):
152- """The _Future class allows deferred access to a result that is not
153- yet available.
154- """
155-
156- _done : bool
157- _result : _U
158-
159- def __init__ (self ) -> None :
160- self ._done = False
161-
162- def done (self ) -> bool :
163- """Return ``True`` if the value is available"""
164- return self ._done
165-
166- def result (self ) -> _U :
167- """Return the deferred value.
168-
169- Raises ``InvalidStateError`` if the value has not been set.
170- """
171- if self .done ():
172- return self ._result
173- raise InvalidStateError ("result has not been set" )
174-
175- def set_result (self , result : _U ) -> None :
176- if self .done ():
177- raise InvalidStateError ("result has already been set" )
178- self ._result = result
179- self ._done = True
180-
181-
182142def _check_decorated_type (cls : object ) -> None :
183143 if not isinstance (cls , type ):
184144 raise TypeError (f"expected a class not { cls !r} " )
185- if _is_generic_alias (cls ):
145+ if is_generic_alias (cls ):
186146 # A .Schema attribute doesn't make sense on a generic alias — there's
187147 # no way for it to know the generic parameters at run time.
188148 raise TypeError (
@@ -513,9 +473,7 @@ def class_schema(
513473 >>> class_schema(Custom)().load({})
514474 Custom(name=None)
515475 """
516- if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
517- clazz
518- ):
476+ if not dataclasses .is_dataclass (clazz ) and not is_generic_alias_of_dataclass (clazz ):
519477 clazz = dataclasses .dataclass (clazz )
520478 if localns is None :
521479 if clazz_frame is None :
@@ -791,8 +749,16 @@ def _field_for_annotated_type(
791749 marshmallow_annotations = [
792750 arg
793751 for arg in arguments [1 :]
794- if (inspect .isclass (arg ) and issubclass (arg , marshmallow .fields .Field ))
795- or isinstance (arg , marshmallow .fields .Field )
752+ if _is_marshmallow_field (arg )
753+ # Support `CustomGenericField[mf.String]`
754+ or (
755+ is_generic_type (arg )
756+ and _is_marshmallow_field (typing_extensions .get_origin (arg ))
757+ )
758+ # Support `partial(mf.List, mf.String)`
759+ or (isinstance (arg , partial ) and _is_marshmallow_field (arg .func ))
760+ # Support `lambda *args, **kwargs: mf.List(mf.String, *args, **kwargs)`
761+ or (_is_callable_marshmallow_field (arg ))
796762 ]
797763 if marshmallow_annotations :
798764 if len (marshmallow_annotations ) > 1 :
@@ -932,7 +898,7 @@ def _field_for_schema(
932898
933899 # i.e.: Literal['abc']
934900 if typing_inspect .is_literal_type (typ ):
935- arguments = typing_inspect .get_args (typ )
901+ arguments = typing_extensions .get_args (typ )
936902 return marshmallow .fields .Raw (
937903 validate = (
938904 marshmallow .validate .Equal (arguments [0 ])
@@ -944,7 +910,7 @@ def _field_for_schema(
944910
945911 # i.e.: Final[str] = 'abc'
946912 if typing_inspect .is_final_type (typ ):
947- arguments = typing_inspect .get_args (typ )
913+ arguments = typing_extensions .get_args (typ )
948914 if arguments :
949915 subtyp = arguments [0 ]
950916 elif default is not marshmallow .missing :
@@ -1061,14 +1027,14 @@ def _get_field_default(field: dataclasses.Field):
10611027 return field .default
10621028
10631029
1064- def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
1030+ def is_generic_alias_of_dataclass (clazz : type ) -> bool :
10651031 """
10661032 Check if given class is a generic alias of a dataclass, if the dataclass is
10671033 defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
10681034 """
10691035 is_generic = is_generic_type (clazz )
1070- type_arguments = typing_inspect .get_args (clazz )
1071- origin_class = typing_inspect .get_origin (clazz )
1036+ type_arguments = typing_extensions .get_args (clazz )
1037+ origin_class = typing_extensions .get_origin (clazz )
10721038 return (
10731039 is_generic
10741040 and len (type_arguments ) > 0
@@ -1107,136 +1073,30 @@ class X:
11071073 return _get_type_hints (X , schema_ctx )["x" ]
11081074
11091075
1110- def _is_generic_alias (clazz : type ) -> bool :
1111- """
1112- Check if given class is a generic alias of a class is
1113- defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1114- """
1115- is_generic = is_generic_type (clazz )
1116- type_arguments = typing_inspect .get_args (clazz )
1117- return is_generic and len (type_arguments ) > 0
1118-
1119-
1120- def is_generic_type (clazz : type ) -> bool :
1121- """
1122- typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar
1123- """
1124- return (
1125- isinstance (clazz , type )
1126- and issubclass (clazz , Generic ) # type: ignore[arg-type]
1127- or isinstance (clazz , typing_inspect .typingGenericAlias )
1128- )
1129-
1130-
1131- def _resolve_typevars (clazz : type ) -> Dict [type , Dict [TypeVar , _Future ]]:
1132- """
1133- Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
1134-
1135- Returns a dict of each base class and the resolved generics.
1136- """
1137- # Use Tuples so can zip (order matters)
1138- args_by_class : Dict [type , Tuple [Tuple [TypeVar , _Future ], ...]] = {}
1139- parent_class : Optional [type ] = None
1140- # Loop in reversed order and iteratively resolve types
1141- for subclass in reversed (clazz .mro ()):
1142- if issubclass (subclass , Generic ) and hasattr (subclass , "__orig_bases__" ): # type: ignore[arg-type]
1143- args = typing_inspect .get_args (subclass .__orig_bases__ [0 ])
1144-
1145- if parent_class and args_by_class .get (parent_class ):
1146- subclass_generic_params_to_args : List [Tuple [TypeVar , _Future ]] = []
1147- for (_arg , future ), potential_type in zip (
1148- args_by_class [parent_class ], args
1149- ):
1150- if isinstance (potential_type , TypeVar ):
1151- subclass_generic_params_to_args .append ((potential_type , future ))
1152- else :
1153- future .set_result (potential_type )
1154-
1155- args_by_class [subclass ] = tuple (subclass_generic_params_to_args )
1156-
1157- else :
1158- args_by_class [subclass ] = tuple ((arg , _Future ()) for arg in args )
1159-
1160- parent_class = subclass
1161-
1162- # clazz itself is a generic alias i.e.: A[int]. So it hold the last types.
1163- if _is_generic_alias (clazz ):
1164- origin = typing_inspect .get_origin (clazz )
1165- args = typing_inspect .get_args (clazz )
1166- for (_arg , future ), potential_type in zip (args_by_class [origin ], args ):
1167- if not isinstance (potential_type , TypeVar ):
1168- future .set_result (potential_type )
1169-
1170- # Convert to nested dict for easier lookup
1171- return {k : {typ : fut for typ , fut in args } for k , args in args_by_class .items ()}
1172-
1173-
1174- def _replace_typevars (
1175- clazz : type , resolved_generics : Optional [Dict [TypeVar , _Future ]] = None
1176- ) -> type :
1177- if not resolved_generics or inspect .isclass (clazz ) or not is_generic_type (clazz ):
1178- return clazz
1179-
1180- return clazz .copy_with ( # type: ignore
1181- tuple (
1182- (
1183- _replace_typevars (arg , resolved_generics )
1184- if is_generic_type (arg )
1185- else (
1186- resolved_generics [arg ].result () if arg in resolved_generics else arg
1187- )
1188- )
1189- for arg in typing_inspect .get_args (clazz )
1190- )
1191- )
1192-
1193-
11941076def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
11951077 if not is_generic_type (clazz ):
11961078 return dataclasses .fields (clazz )
11971079
11981080 else :
1199- unbound_fields = set ()
1200- # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and
1201- # looses the source class. Thus I don't know how to resolve this at later on.
1202- # Instead we recreate the type but with all known TypeVars resolved to their actual types.
1203- resolved_typevars = _resolve_typevars (clazz )
1204- # Dict[field_name, Tuple[original_field, resolved_field]]
1205- fields : Dict [str , Tuple [dataclasses .Field , dataclasses .Field ]] = {}
1206-
1207- for subclass in reversed (clazz .mro ()):
1208- if not dataclasses .is_dataclass (subclass ):
1209- continue
1210-
1211- for field in dataclasses .fields (subclass ):
1212- try :
1213- if field .name in fields and fields [field .name ][0 ] == field :
1214- continue # identical, so already resolved.
1215-
1216- # Either the first time we see this field, or it got overridden
1217- # If it's a class we handle it later as a Nested. Nothing to resolve now.
1218- new_field = field
1219- if not inspect .isclass (field .type ) and is_generic_type (field .type ):
1220- new_field = copy .copy (field )
1221- new_field .type = _replace_typevars (
1222- field .type , resolved_typevars [subclass ]
1223- )
1224- elif isinstance (field .type , TypeVar ):
1225- new_field = copy .copy (field )
1226- new_field .type = resolved_typevars [subclass ][
1227- field .type
1228- ].result ()
1229-
1230- fields [field .name ] = (field , new_field )
1231- except InvalidStateError :
1232- unbound_fields .add (field .name )
1233-
1234- if unbound_fields :
1235- raise UnboundTypeVarError (
1236- f"{ clazz .__name__ } has unbound fields: { ', ' .join (unbound_fields )} "
1237- )
1081+ return get_generic_dataclass_fields (clazz )
1082+
1083+
1084+ def _is_marshmallow_field (obj ) -> bool :
1085+ return (
1086+ inspect .isclass (obj ) and issubclass (obj , marshmallow .fields .Field )
1087+ ) or isinstance (obj , marshmallow .fields .Field )
1088+
1089+
1090+ def _is_callable_marshmallow_field (obj ) -> bool :
1091+ """Checks if the object is a callable and if the callable returns a marshmallow field"""
1092+ if callable (obj ):
1093+ try :
1094+ potential_field = obj ()
1095+ return _is_marshmallow_field (potential_field )
1096+ except Exception :
1097+ return False
12381098
1239- return tuple ( v [ 1 ] for v in fields . values ())
1099+ return False
12401100
12411101
12421102def NewType (
0 commit comments