Skip to content

Commit 6c45757

Browse files
authored
Fix recursive references when using class_schema() (#189)
Fixes #188 The @marshmallow_dataclass.dataclass decorator relies on a lazily computed .Schema attribute/descriptor on the decorated class to solve this issue. For class_schema() this would probably not be a valid approach as the expectation is that its "clazz" argument would not be modified.
1 parent ae3bcdd commit 6c45757

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class User:
3737
import collections.abc
3838
import dataclasses
3939
import inspect
40+
import threading
4041
import types
4142
import warnings
4243
from enum import EnumMeta
@@ -77,6 +78,10 @@ class User:
7778
# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
7879
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
7980

81+
# Recursion guard for class_schema()
82+
_RECURSION_GUARD = threading.local()
83+
_RECURSION_GUARD.seen_classes = {}
84+
8085

8186
@overload
8287
def dataclass(
@@ -347,7 +352,10 @@ def class_schema(
347352
clazz_frame = current_frame.f_back
348353
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
349354
del current_frame
350-
return _internal_class_schema(clazz, base_schema, clazz_frame)
355+
try:
356+
return _internal_class_schema(clazz, base_schema, clazz_frame)
357+
finally:
358+
_RECURSION_GUARD.seen_classes.clear()
351359

352360

353361
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
@@ -356,6 +364,7 @@ def _internal_class_schema(
356364
base_schema: Optional[Type[marshmallow.Schema]] = None,
357365
clazz_frame: types.FrameType = None,
358366
) -> Type[marshmallow.Schema]:
367+
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
359368
try:
360369
# noinspection PyDataclass
361370
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
@@ -712,9 +721,11 @@ def field_for_schema(
712721

713722
# Nested dataclasses
714723
forward_reference = getattr(typ, "__forward_arg__", None)
724+
715725
nested = (
716726
nested_schema
717727
or forward_reference
728+
or _RECURSION_GUARD.seen_classes.get(typ)
718729
or _internal_class_schema(typ, base_schema, typ_frame)
719730
)
720731

tests/test_class_schema.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,39 @@ class J:
400400
[validator_a, validator_b, validator_c, validator_d],
401401
)
402402

403+
def test_recursive_reference(self):
404+
@dataclasses.dataclass
405+
class Tree:
406+
children: typing.List["Tree"] # noqa: F821
407+
408+
schema = class_schema(Tree)()
409+
410+
self.assertEqual(
411+
schema.load({"children": [{"children": []}]}),
412+
Tree(children=[Tree(children=[])]),
413+
)
414+
415+
def test_cyclic_reference(self):
416+
@dataclasses.dataclass
417+
class First:
418+
second: typing.Optional["Second"] # noqa: F821
419+
420+
@dataclasses.dataclass
421+
class Second:
422+
first: typing.Optional["First"]
423+
424+
first_schema = class_schema(First)()
425+
second_schema = class_schema(Second)()
426+
427+
self.assertEqual(
428+
first_schema.load({"second": {"first": None}}),
429+
First(second=Second(first=None)),
430+
)
431+
self.assertEqual(
432+
second_schema.dump(Second(first=First(second=Second(first=None)))),
433+
{"first": {"second": {"first": None}}},
434+
)
435+
403436

404437
if __name__ == "__main__":
405438
unittest.main()

0 commit comments

Comments
 (0)