@@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
273273 return decorator (_cls , stacklevel = stacklevel + 1 )
274274
275275
276+ @overload
277+ def class_schema (
278+ clazz : type ,
279+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
280+ * ,
281+ globalns : Optional [Dict [str , Any ]] = None ,
282+ localns : Optional [Dict [str , Any ]] = None ,
283+ ) -> Type [marshmallow .Schema ]:
284+ ...
285+
286+
287+ @overload
288+ def class_schema (
289+ clazz : type ,
290+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
291+ clazz_frame : Optional [types .FrameType ] = None ,
292+ * ,
293+ globalns : Optional [Dict [str , Any ]] = None ,
294+ ) -> Type [marshmallow .Schema ]:
295+ ...
296+
297+
276298def class_schema (
277299 clazz : type ,
278300 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
279- clazz_frame : types .FrameType = None ,
301+ # FIXME: delete clazz_frame from API?
302+ clazz_frame : Optional [types .FrameType ] = None ,
303+ * ,
304+ globalns : Optional [Dict [str , Any ]] = None ,
305+ localns : Optional [Dict [str , Any ]] = None ,
280306) -> Type [marshmallow .Schema ]:
281307 """
282308 Convert a class to a marshmallow schema
@@ -398,24 +424,26 @@ def class_schema(
398424 """
399425 if not dataclasses .is_dataclass (clazz ):
400426 clazz = dataclasses .dataclass (clazz )
401- if not clazz_frame :
402- clazz_frame = _maybe_get_callers_frame (clazz )
403-
404- with _SchemaContext (clazz_frame ):
427+ if localns is None :
428+ if clazz_frame is None :
429+ clazz_frame = _maybe_get_callers_frame (clazz )
430+ if clazz_frame is not None :
431+ localns = clazz_frame .f_locals
432+ with _SchemaContext (globalns , localns ):
405433 return _internal_class_schema (clazz , base_schema )
406434
407435
408436class _SchemaContext :
409437 """Global context for an invocation of class_schema."""
410438
411- def __init__ (self , frame : Optional [types .FrameType ]):
439+ def __init__ (
440+ self ,
441+ globalns : Optional [Dict [str , Any ]] = None ,
442+ localns : Optional [Dict [str , Any ]] = None ,
443+ ):
412444 self .seen_classes : Dict [type , str ] = {}
413- self .frame = frame
414-
415- def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
416- frame = self .frame
417- localns = frame .f_locals if frame is not None else None
418- return get_type_hints (cls , localns = localns )
445+ self .globalns = globalns
446+ self .localns = localns
419447
420448 def __enter__ (self ) -> "_SchemaContext" :
421449 _schema_ctx_stack .push (self )
@@ -481,7 +509,9 @@ def _internal_class_schema(
481509 }
482510
483511 # Update the schema members to contain marshmallow fields instead of dataclass fields
484- type_hints = schema_ctx .get_type_hints (clazz )
512+ type_hints = get_type_hints (
513+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
514+ )
485515 attributes .update (
486516 (
487517 field .name ,
@@ -660,6 +690,7 @@ def field_for_schema(
660690 default = marshmallow .missing ,
661691 metadata : Mapping [str , Any ] = None ,
662692 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
693+ # FIXME: delete typ_frame from API?
663694 typ_frame : Optional [types .FrameType ] = None ,
664695) -> marshmallow .fields .Field :
665696 """
@@ -682,7 +713,7 @@ def field_for_schema(
682713 >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
683714 <class 'marshmallow.fields.Url'>
684715 """
685- with _SchemaContext (typ_frame ):
716+ with _SchemaContext (localns = typ_frame . f_locals if typ_frame is not None else None ):
686717 return _field_for_schema (typ , default , metadata , base_schema )
687718
688719
0 commit comments