@@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
273273 return decorator (_cls , stacklevel = stacklevel + 1 )
274274
275275
276+ @overload
277+ def class_schema (
278+ clazz : type ,
279+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
280+ * ,
281+ globalns : Optional [Dict [str , Any ]] = None ,
282+ localns : Optional [Dict [str , Any ]] = None ,
283+ ) -> Type [marshmallow .Schema ]:
284+ ...
285+
286+
287+ @overload
288+ def class_schema (
289+ clazz : type ,
290+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
291+ clazz_frame : Optional [types .FrameType ] = None ,
292+ * ,
293+ globalns : Optional [Dict [str , Any ]] = None ,
294+ ) -> Type [marshmallow .Schema ]:
295+ ...
296+
297+
276298def class_schema (
277299 clazz : type ,
278300 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
301+ # FIXME: delete clazz_frame from API?
279302 clazz_frame : Optional [types .FrameType ] = None ,
303+ * ,
304+ globalns : Optional [Dict [str , Any ]] = None ,
305+ localns : Optional [Dict [str , Any ]] = None ,
280306) -> Type [marshmallow .Schema ]:
281307 """
282308 Convert a class to a marshmallow schema
@@ -398,24 +424,26 @@ def class_schema(
398424 """
399425 if not dataclasses .is_dataclass (clazz ):
400426 clazz = dataclasses .dataclass (clazz )
401- if not clazz_frame :
402- clazz_frame = _maybe_get_callers_frame (clazz )
403-
404- with _SchemaContext (clazz_frame ):
427+ if localns is None :
428+ if clazz_frame is None :
429+ clazz_frame = _maybe_get_callers_frame (clazz )
430+ if clazz_frame is not None :
431+ localns = clazz_frame .f_locals
432+ with _SchemaContext (globalns , localns ):
405433 return _internal_class_schema (clazz , base_schema )
406434
407435
408436class _SchemaContext :
409437 """Global context for an invocation of class_schema."""
410438
411- def __init__ (self , frame : Optional [types .FrameType ]):
439+ def __init__ (
440+ self ,
441+ globalns : Optional [Dict [str , Any ]] = None ,
442+ localns : Optional [Dict [str , Any ]] = None ,
443+ ):
412444 self .seen_classes : Dict [type , str ] = {}
413- self .frame = frame
414-
415- def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
416- frame = self .frame
417- localns = frame .f_locals if frame is not None else None
418- return get_type_hints (cls , localns = localns )
445+ self .globalns = globalns
446+ self .localns = localns
419447
420448 def __enter__ (self ) -> "_SchemaContext" :
421449 _schema_ctx_stack .push (self )
@@ -486,7 +514,9 @@ def _internal_class_schema(
486514 }
487515
488516 # Update the schema members to contain marshmallow fields instead of dataclass fields
489- type_hints = schema_ctx .get_type_hints (clazz )
517+ type_hints = get_type_hints (
518+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
519+ )
490520 attributes .update (
491521 (
492522 field .name ,
@@ -670,6 +700,7 @@ def field_for_schema(
670700 default : Any = marshmallow .missing ,
671701 metadata : Optional [Mapping [str , Any ]] = None ,
672702 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
703+ # FIXME: delete typ_frame from API?
673704 typ_frame : Optional [types .FrameType ] = None ,
674705) -> marshmallow .fields .Field :
675706 """
@@ -692,7 +723,7 @@ def field_for_schema(
692723 >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
693724 <class 'marshmallow.fields.Url'>
694725 """
695- with _SchemaContext (typ_frame ):
726+ with _SchemaContext (localns = typ_frame . f_locals if typ_frame is not None else None ):
696727 return _field_for_schema (typ , default , metadata , base_schema )
697728
698729
0 commit comments