Skip to content

Commit c52cce0

Browse files
committed
Add ability to pass explicit localns (and globalns) to class_schema
When class_schema is called, it doesn't need the caller's whole stack frame. What it really wants is a `localns` to pass to `typing.get_type_hints` to be used to resolve type references. Here we add the ability to pass an explicit `localns` parameter to `class_schema`. We also add the ability to pass an explicit `globalns`, because ... might as well — it might come in useful. (Since we need these only to pass to `get_type_hints`, we might as well match `get_type_hints` API as closely as possible.)
1 parent 70aee99 commit c52cce0

File tree

1 file changed

+45
-14
lines changed

1 file changed

+45
-14
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
273273
return decorator(_cls, stacklevel=stacklevel + 1)
274274

275275

276+
@overload
277+
def class_schema(
278+
clazz: type,
279+
base_schema: Optional[Type[marshmallow.Schema]] = None,
280+
*,
281+
globalns: Optional[Dict[str, Any]] = None,
282+
localns: Optional[Dict[str, Any]] = None,
283+
) -> Type[marshmallow.Schema]:
284+
...
285+
286+
287+
@overload
288+
def class_schema(
289+
clazz: type,
290+
base_schema: Optional[Type[marshmallow.Schema]] = None,
291+
clazz_frame: Optional[types.FrameType] = None,
292+
*,
293+
globalns: Optional[Dict[str, Any]] = None,
294+
) -> Type[marshmallow.Schema]:
295+
...
296+
297+
276298
def class_schema(
277299
clazz: type,
278300
base_schema: Optional[Type[marshmallow.Schema]] = None,
279-
clazz_frame: types.FrameType = None,
301+
# FIXME: delete clazz_frame from API?
302+
clazz_frame: Optional[types.FrameType] = None,
303+
*,
304+
globalns: Optional[Dict[str, Any]] = None,
305+
localns: Optional[Dict[str, Any]] = None,
280306
) -> Type[marshmallow.Schema]:
281307
"""
282308
Convert a class to a marshmallow schema
@@ -398,24 +424,26 @@ def class_schema(
398424
"""
399425
if not dataclasses.is_dataclass(clazz):
400426
clazz = dataclasses.dataclass(clazz)
401-
if not clazz_frame:
402-
clazz_frame = _maybe_get_callers_frame(clazz)
403-
404-
with _SchemaContext(clazz_frame):
427+
if localns is None:
428+
if clazz_frame is None:
429+
clazz_frame = _maybe_get_callers_frame(clazz)
430+
if clazz_frame is not None:
431+
localns = clazz_frame.f_locals
432+
with _SchemaContext(globalns, localns):
405433
return _internal_class_schema(clazz, base_schema)
406434

407435

408436
class _SchemaContext:
409437
"""Global context for an invocation of class_schema."""
410438

411-
def __init__(self, frame: Optional[types.FrameType]):
439+
def __init__(
440+
self,
441+
globalns: Optional[Dict[str, Any]] = None,
442+
localns: Optional[Dict[str, Any]] = None,
443+
):
412444
self.seen_classes: Dict[type, str] = {}
413-
self.frame = frame
414-
415-
def get_type_hints(self, cls: Type) -> Dict[str, Any]:
416-
frame = self.frame
417-
localns = frame.f_locals if frame is not None else None
418-
return get_type_hints(cls, localns=localns)
445+
self.globalns = globalns
446+
self.localns = localns
419447

420448
def __enter__(self) -> "_SchemaContext":
421449
_schema_ctx_stack.push(self)
@@ -481,7 +509,9 @@ def _internal_class_schema(
481509
}
482510

483511
# Update the schema members to contain marshmallow fields instead of dataclass fields
484-
type_hints = schema_ctx.get_type_hints(clazz)
512+
type_hints = get_type_hints(
513+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
514+
)
485515
attributes.update(
486516
(
487517
field.name,
@@ -660,6 +690,7 @@ def field_for_schema(
660690
default=marshmallow.missing,
661691
metadata: Mapping[str, Any] = None,
662692
base_schema: Optional[Type[marshmallow.Schema]] = None,
693+
# FIXME: delete typ_frame from API?
663694
typ_frame: Optional[types.FrameType] = None,
664695
) -> marshmallow.fields.Field:
665696
"""
@@ -682,7 +713,7 @@ def field_for_schema(
682713
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
683714
<class 'marshmallow.fields.Url'>
684715
"""
685-
with _SchemaContext(typ_frame):
716+
with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None):
686717
return _field_for_schema(typ, default, metadata, base_schema)
687718

688719

0 commit comments

Comments
 (0)