33import inspect
44import sys
55from typing import (
6+ Any ,
67 Dict ,
8+ ForwardRef ,
79 Generic ,
810 List ,
911 Optional ,
1517
1618if sys .version_info >= (3 , 9 ):
1719 from typing import Annotated , get_args , get_origin
20+
21+ def eval_forward_ref (t : ForwardRef , globalns , localns , recursive_guard = frozenset ()):
22+ return t ._evaluate (globalns , localns , recursive_guard )
23+
1824else :
1925 from typing_extensions import Annotated , get_args , get_origin
2026
27+ def eval_forward_ref (t : ForwardRef , globalns , localns ):
28+ return t ._evaluate (globalns , localns )
29+
30+
2131_U = TypeVar ("_U" )
2232
2333
@@ -99,7 +109,35 @@ def may_contain_typevars(clazz: type) -> bool:
99109 )
100110
101111
102- def _resolve_typevars (clazz : type ) -> Dict [type , Dict [TypeVar , _Future ]]:
112+ def _get_namespaces (
113+ clazz : type ,
114+ globalns : Optional [Dict [str , Any ]] = None ,
115+ localns : Optional [Dict [str , Any ]] = None ,
116+ ) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
117+ # region - Copied from typing.get_type_hints
118+ if globalns is None :
119+ base_globals = getattr (sys .modules .get (clazz .__module__ , None ), "__dict__" , {})
120+ else :
121+ base_globals = globalns
122+ base_locals = dict (vars (clazz )) if localns is None else localns
123+ if localns is None and globalns is None :
124+ # This is surprising, but required. Before Python 3.10,
125+ # get_type_hints only evaluated the globalns of
126+ # a class. To maintain backwards compatibility, we reverse
127+ # the globalns and localns order so that eval() looks into
128+ # *base_globals* first rather than *base_locals*.
129+ # This only affects ForwardRefs.
130+ base_globals , base_locals = base_locals , base_globals
131+ # endregion - Copied from typing.get_type_hints
132+
133+ return base_globals , base_locals
134+
135+
136+ def _resolve_typevars (
137+ clazz : type ,
138+ globalns : Optional [Dict [str , Any ]] = None ,
139+ localns : Optional [Dict [str , Any ]] = None ,
140+ ) -> Dict [type , Dict [TypeVar , _Future ]]:
103141 """
104142 Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
105143
@@ -110,6 +148,7 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
110148 parent_class : Optional [type ] = None
111149 # Loop in reversed order and iteratively resolve types
112150 for subclass in reversed (clazz .mro ()):
151+ base_globals , base_locals = _get_namespaces (subclass , globalns , localns )
113152 if issubclass (subclass , Generic ) and hasattr (subclass , "__orig_bases__" ): # type: ignore[arg-type]
114153 args = get_args (subclass .__orig_bases__ [0 ])
115154
@@ -121,10 +160,17 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
121160 if isinstance (potential_type , TypeVar ):
122161 subclass_generic_params_to_args .append ((potential_type , future ))
123162 else :
124- future .set_result (potential_type )
163+ future .set_result (
164+ eval_forward_ref (
165+ potential_type ,
166+ globalns = base_globals ,
167+ localns = base_locals ,
168+ )
169+ if isinstance (potential_type , ForwardRef )
170+ else potential_type
171+ )
125172
126173 args_by_class [subclass ] = tuple (subclass_generic_params_to_args )
127-
128174 else :
129175 args_by_class [subclass ] = tuple ((arg , _Future ()) for arg in args )
130176
@@ -136,7 +182,11 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
136182 args = get_args (clazz )
137183 for (_arg , future ), potential_type in zip (args_by_class [origin ], args ): # type: ignore[index]
138184 if not isinstance (potential_type , TypeVar ):
139- future .set_result (potential_type )
185+ future .set_result (
186+ eval_forward_ref (potential_type , globalns = globalns , localns = localns )
187+ if isinstance (potential_type , ForwardRef )
188+ else potential_type
189+ )
140190
141191 # Convert to nested dict for easier lookup
142192 return {k : {typ : fut for typ , fut in args } for k , args in args_by_class .items ()}
@@ -166,12 +216,16 @@ def _replace_typevars(
166216 )
167217
168218
169- def get_generic_dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
219+ def get_resolved_dataclass_fields (
220+ clazz : type ,
221+ globalns : Optional [Dict [str , Any ]] = None ,
222+ localns : Optional [Dict [str , Any ]] = None ,
223+ ) -> Tuple [dataclasses .Field , ...]:
170224 unbound_fields = set ()
171225 # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and
172226 # looses the source class. Thus I don't know how to resolve this at later on.
173227 # Instead we recreate the type but with all known TypeVars resolved to their actual types.
174- resolved_typevars = _resolve_typevars (clazz )
228+ resolved_typevars = _resolve_typevars (clazz , globalns = globalns , localns = localns )
175229 # Dict[field_name, Tuple[original_field, resolved_field]]
176230 fields : Dict [str , Tuple [dataclasses .Field , dataclasses .Field ]] = {}
177231
@@ -190,14 +244,34 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
190244 if not inspect .isclass (field .type ) and may_contain_typevars (field .type ):
191245 new_field = copy .copy (field )
192246 new_field .type = _replace_typevars (
193- field .type , resolved_typevars [ subclass ]
247+ field .type , resolved_typevars . get ( subclass )
194248 )
195249 elif isinstance (field .type , TypeVar ):
196250 new_field = copy .copy (field )
197251 new_field .type = resolved_typevars [subclass ][field .type ].result ()
252+ elif isinstance (field .type , ForwardRef ):
253+ base_globals , base_locals = _get_namespaces (
254+ subclass , globalns , localns
255+ )
256+ new_field = copy .copy (field )
257+ new_field .type = eval_forward_ref (
258+ field .type , globalns = base_globals , localns = base_locals
259+ )
260+ elif isinstance (field .type , str ):
261+ base_globals , base_locals = _get_namespaces (
262+ subclass , globalns , localns
263+ )
264+ new_field = copy .copy (field )
265+ new_field .type = eval_forward_ref (
266+ ForwardRef (field .type , is_argument = False , is_class = True )
267+ if sys .version_info >= (3 , 9 )
268+ else ForwardRef (field .type , is_argument = False ),
269+ globalns = base_globals ,
270+ localns = base_locals ,
271+ )
198272
199273 fields [field .name ] = (field , new_field )
200- except InvalidStateError :
274+ except ( InvalidStateError , KeyError ) :
201275 unbound_fields .add (field .name )
202276
203277 if unbound_fields :
0 commit comments