diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dca4a3b..65f7a8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,14 @@ repos: rev: v0.991 hooks: - id: mypy - additional_dependencies: [marshmallow-enum,typeguard,marshmallow] + additional_dependencies: + - marshmallow + - marshmallow-enum + - pytest + - typeguard + - types-cachetools + - types-setuptools + - typing-inspect args: [--show-error-codes] - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 106cde2..17b94a6 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -37,17 +37,26 @@ class User: import collections.abc import dataclasses import inspect +import sys import threading import types import warnings +from contextlib import contextmanager from enum import Enum -from functools import lru_cache, partial +from functools import partial from typing import ( Any, Callable, + ChainMap, + ClassVar, Dict, + Generic, + Hashable, + Iterable, + Iterator, List, Mapping, + MutableMapping, NewType as typing_NewType, Optional, Set, @@ -55,23 +64,69 @@ class User: Type, TypeVar, Union, - cast, get_type_hints, overload, Sequence, FrozenSet, ) +import cachetools import marshmallow import typing_inspect from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute - __all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema", "NewType"] + +if sys.version_info >= (3, 8): + from typing import get_args + from typing import get_origin +elif sys.version_info >= (3, 7): + from typing_extensions import get_args + from typing_extensions import get_origin +else: + + def get_args(tp): + return typing_inspect.get_args(tp, evaluate=True) + + def get_origin(tp): + TYPE_MAP = { + List: list, + Sequence: collections.abc.Sequence, + Set: set, + FrozenSet: frozenset, + Tuple: tuple, + Dict: dict, + Mapping: collections.abc.Mapping, + Generic: Generic, + } + + origin = typing_inspect.get_origin(tp) + if origin in TYPE_MAP: + return TYPE_MAP[origin] + elif origin is not tp: + return origin + return None + + +if sys.version_info >= (3, 8): + from typing import Protocol + from typing import final +else: + from typing_extensions import Protocol + from typing_extensions import final + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + + NoneType = type(None) _U = TypeVar("_U") +_V = TypeVar("_V") +_Field = TypeVar("_Field", bound=marshmallow.fields.Field) # Whitelist of dataclass members that will be copied to generated schema. MEMBERS_WHITELIST: Set[str] = {"Meta"} @@ -79,8 +134,154 @@ class User: # Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates. MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 -# Recursion guard for class_schema() -_RECURSION_GUARD = threading.local() + +class Error(TypeError): + """Class passed ``class_schema`` can not be converted to a Marshmallow schema. + + FIXME: Currently this inherits from TypeError for backward compatibility with + older versions of marshmallow_dataclass which always raised + TypeError(f"{name} is not a dataclass and cannot be turned into one.") + + """ + + +class InvalidClassError(ValueError, Error): + """Argument to ``class_schema`` can not be converted to a Marshmallow schema. + + This exception is raised when, while generating a Marshmallow schema for a + dataclass, a class is encountered for which a Marshmallow Schema can not + be generated. + + """ + + +class UnrecognizedFieldTypeError(Error): + """An unrecognized field type spec was encountered. + + This exception is raised when, while generating a Marshmallow schema for a + dataclass, a field is encountered for which a Marshmallow Field can not + be generated. + + """ + + +class UnboundTypeVarError(Error): + """TypeVar instance can not be resolved to a type spec. + + This exception is raised when an unbound TypeVar is encountered. + + """ + + +################################################################ +# Type aliases and type guards (FIXME: move these) + +if sys.version_info >= (3, 7): + _TypeVarType = TypeVar +else: + # py36: type.TypeVar does not work as a type annotation + # (⇒ "AttributeError: type object 'TypeVar' has no attribute '_gorg'") + _TypeVarType = typing_NewType("_TypeVarType", type) + + +def _is_type_var(obj: object) -> TypeGuard[_TypeVarType]: + return isinstance(obj, TypeVar) + + +TypeSpec = object +GenericAlias = typing_NewType("GenericAlias", object) +GenericAliasOfDataclass = typing_NewType("GenericAliasOfDataclass", GenericAlias) + + +def _is_generic_alias_of_dataclass( + cls: object, +) -> TypeGuard[GenericAliasOfDataclass]: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return _is_dataclass_type(get_origin(cls)) + + +_DataclassType = typing_NewType("_DataclassType", type) + + +def _is_dataclass_type(obj: object) -> TypeGuard[_DataclassType]: + return isinstance(obj, type) and dataclasses.is_dataclass(obj) + + +class _NewType(Protocol): + def __call__(self, obj: _U) -> _U: + ... + + @property + def __name__(self) -> str: + ... + + @property + def __supertype__(self) -> type: + ... + + +def _is_new_type(obj: object) -> TypeGuard[_NewType]: + return bool(typing_inspect.is_new_type(obj)) + + +def _maybe_get_callers_frame( + cls: Union[type, GenericAliasOfDataclass], stacklevel: int = 1 +) -> Optional[types.FrameType]: + """Return the caller's frame, but only if it will help resolve forward type references. + + We sometimes need the caller's frame to get access to the caller's + local namespace in order to be able to resolve forward type + references in dataclasses. + + Notes + ----- + + If the caller's locals are the same as the dataclass' module + globals — this is the case for the common case of dataclasses + defined at the module top-level — we don't need the locals. + (Typing.get_type_hints() knows how to check the class module + globals on its own.) + + In that case, we don't need the caller's frame. Not holding a + reference to the frame in our our lazy ``.Scheme`` class attribute + is a significant win, memory-wise. + + """ + try: + frame = inspect.currentframe() + for _ in range(stacklevel + 1): + if frame is None: + return None + frame = frame.f_back + + if frame is None: + return None + + globalns = getattr(sys.modules.get(cls.__module__), "__dict__", None) + if frame.f_locals is globalns: + # Locals are the globals + return None + + return frame + + finally: + # Paranoia, per https://docs.python.org/3/library/inspect.html#the-interpreter-stack + del frame + + +def _check_decorated_type(cls: object) -> None: + if typing_inspect.is_generic_type(cls): + # A .Schema attribute doesn't make sense on a generic type — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic types " + "(hint: use class_schema directly instead)" + ) + if not isinstance(cls, type): + raise TypeError(f"expected a class not {cls!r}") @overload @@ -125,6 +326,7 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, ) -> Union[Type[_U], Callable[[Type[_U]], Type[_U]]]: """ This decorator does the same as dataclasses.dataclass, but also applies :func:`add_schema`. @@ -151,30 +353,34 @@ def dataclass( >>> Point.Schema().load({'x':0, 'y':0}) # This line can be statically type checked Point(x=0.0, y=0.0) """ - # dataclass's typing doesn't expect it to be called as a function, so ignore type check - dc = dataclasses.dataclass( # type: ignore - _cls, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + dc = dataclasses.dataclass( + repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen ) - if not cls_frame: - current_frame = inspect.currentframe() - if current_frame: - cls_frame = current_frame.f_back - # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack - del current_frame + + def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + _check_decorated_type(cls) + dc(cls) + return add_schema( + cls, base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 + ) + if _cls is None: - return lambda cls: add_schema(dc(cls), base_schema, cls_frame=cls_frame) - return add_schema(dc, base_schema, cls_frame=cls_frame) + return decorator + return decorator(_cls, stacklevel=stacklevel + 1) -@overload -def add_schema(_cls: Type[_U]) -> Type[_U]: - ... +class _ClassDecorator(Protocol): + def __call__(self, cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + ... @overload def add_schema( + *, base_schema: Optional[Type[marshmallow.Schema]] = None, -) -> Callable[[Type[_U]], Type[_U]]: + cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, +) -> _ClassDecorator: ... @@ -183,11 +389,18 @@ def add_schema( _cls: Type[_U], base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, ) -> Type[_U]: ... -def add_schema(_cls=None, base_schema=None, cls_frame=None): +def add_schema( + _cls: Optional[Type[_U]] = None, + base_schema: Optional[Type[marshmallow.Schema]] = None, + cls_frame: Optional[types.FrameType] = None, + stacklevel: int = 1, + attr_name: str = "Schema", +) -> Union[Type[_U], _ClassDecorator]: """ This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass. It uses :func:`class_schema` internally. @@ -209,22 +422,50 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None): Artist(names=('Martin', 'Ramirez')) """ - def decorator(clazz: Type[_U]) -> Type[_U]: - # noinspection PyTypeHints - clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, cls_frame), - "Schema", - clazz.__name__, - ) - return clazz + def decorator(cls: Type[_V], stacklevel: int = stacklevel) -> Type[_V]: + nonlocal cls_frame + _check_decorated_type(cls) + if cls_frame is None: + cls_frame = _maybe_get_callers_frame(cls, stacklevel=stacklevel) + fget = partial(class_schema, cls, base_schema, cls_frame) + setattr(cls, attr_name, lazy_class_attribute(fget, attr_name)) + return cls + + if _cls is None: + return decorator + return decorator(_cls, stacklevel=stacklevel + 1) + - return decorator(_cls) if _cls else decorator +@overload +def class_schema( + clazz: type, + base_schema: Optional[Type[marshmallow.Schema]] = None, + *, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, +) -> Type[marshmallow.Schema]: + ... +@overload def class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, clazz_frame: Optional[types.FrameType] = None, + *, + globalns: Optional[Dict[str, Any]] = None, +) -> Type[marshmallow.Schema]: + ... + + +def class_schema( + clazz: object, + base_schema: Optional[Type[marshmallow.Schema]] = None, + # FIXME: delete clazz_frame from API? + clazz_frame: Optional[types.FrameType] = None, + *, + globalns: Optional[Dict[str, Any]] = None, + localns: Optional[Dict[str, Any]] = None, ) -> Type[marshmallow.Schema]: """ Convert a class to a marshmallow schema @@ -344,447 +585,679 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): - clazz = dataclasses.dataclass(clazz) - if not clazz_frame: - current_frame = inspect.currentframe() - if current_frame: - clazz_frame = current_frame.f_back - # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack - del current_frame - _RECURSION_GUARD.seen_classes = {} - try: - return _internal_class_schema(clazz, base_schema, clazz_frame) - finally: - _RECURSION_GUARD.seen_classes.clear() + if not (_is_dataclass_type(clazz) or _is_generic_alias_of_dataclass(clazz)): + raise InvalidClassError(f"{clazz} is not a dataclass") + if localns is None: + if clazz_frame is None: + clazz_frame = _maybe_get_callers_frame(clazz) + if clazz_frame is not None: + localns = clazz_frame.f_locals -@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) -def _internal_class_schema( - clazz: type, - base_schema: Optional[Type[marshmallow.Schema]] = None, - clazz_frame: Optional[types.FrameType] = None, -) -> Type[marshmallow.Schema]: - _RECURSION_GUARD.seen_classes[clazz] = clazz.__name__ - try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) - except TypeError: # Not a dataclass + if base_schema is None: + base_schema = marshmallow.Schema + + schema_ctx = _SchemaContext(globalns, localns, base_schema) + return schema_ctx.class_schema(clazz).result() + + +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + + def __init__(self) -> None: + self._done = False + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + +def _has_generic_base(cls: type) -> bool: + """Return True if cls has any generic base classes.""" + return any(typing_inspect.get_parameters(base) for base in cls.__mro__[1:]) + + +@final +@dataclasses.dataclass(frozen=True) +class _TypeVarBindings(Mapping[TypeSpec, TypeSpec]): + """A mapping of bindings of TypeVars to type specs.""" + + parameters: Sequence[_TypeVarType] = () + args: Sequence[TypeSpec] = () + + def __post_init__(self) -> None: + if len(self.parameters) != len(self.args): + raise ValueError("the 'parameters' and 'args' must be of the same length") + + @classmethod + def from_generic_alias(cls, generic_alias: GenericAlias) -> "_TypeVarBindings": + origin = get_origin(generic_alias) + parameters = typing_inspect.get_parameters(origin) + args = get_args(generic_alias) + return cls(parameters, args) + + def __getitem__(self, key: TypeSpec) -> TypeSpec: try: - warnings.warn( - "****** WARNING ****** " - f"marshmallow_dataclass was called on the class {clazz}, which is not a dataclass. " - "It is going to try and convert the class into a dataclass, which may have " - "undesirable side effects. To avoid this message, make sure all your classes and " - "all the classes of their fields are either explicitly supported by " - "marshmallow_dataclass, or define the schema explicitly using " - "field(metadata=dict(marshmallow_field=...)). For more information, see " - "https://github.com/lovasoa/marshmallow_dataclass/issues/51 " - "****** WARNING ******" - ) - created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema, clazz_frame) - except Exception as exc: - raise TypeError( - f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." - ) from exc - - # Copy all marshmallow hooks and whitelisted members of the dataclass to the schema. - attributes = { - k: v - for k, v in inspect.getmembers(clazz) - if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST - } + i = self.parameters.index(key) + except ValueError: + raise KeyError(key) from None + return self.args[i] - # Update the schema members to contain marshmallow fields instead of dataclass fields - type_hints = get_type_hints( - clazz, localns=clazz_frame.f_locals if clazz_frame else None - ) - attributes.update( - ( - field.name, - field_for_schema( - type_hints[field.name], - _get_field_default(field), - field.metadata, - base_schema, - clazz_frame, - ), - ) - for field in fields - if field.init - ) + def __iter__(self) -> Iterator[_TypeVarType]: + return iter(self.parameters) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) - return cast(Type[marshmallow.Schema], schema_class) + def __len__(self) -> int: + return len(self.parameters) + def compose(self, other: "_TypeVarBindings") -> "_TypeVarBindings": + """Compose TypeVar bindings. -def _field_by_type( - typ: Union[type, Any], base_schema: Optional[Type[marshmallow.Schema]] -) -> Optional[Type[marshmallow.fields.Field]]: - return ( - base_schema and base_schema.TYPE_MAPPING.get(typ) - ) or marshmallow.Schema.TYPE_MAPPING.get(typ) + Given: + def map(bindings, spec): + return bindings.get(spec, spec) -def _field_by_supertype( - typ: Type, - default: Any, - newtype_supertype: Type, - metadata: dict, - base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], -) -> marshmallow.fields.Field: - """ - Return a new field for fields based on a super field. (Usually spawned from NewType) - """ - # Add the information coming our custom NewType implementation - - typ_args = getattr(typ, "_marshmallow_args", {}) - - # Handle multiple validators from both `typ` and `metadata`. - # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 - new_validators: List[Callable] = [] - for meta_dict in (typ_args, metadata): - if "validate" in meta_dict: - if marshmallow.utils.is_iterable_but_not_string(meta_dict["validate"]): - new_validators.extend(meta_dict["validate"]) - elif callable(meta_dict["validate"]): - new_validators.append(meta_dict["validate"]) - metadata["validate"] = new_validators if new_validators else None - - metadata = {**typ_args, **metadata} - metadata.setdefault("metadata", {}).setdefault("description", typ.__name__) - field = getattr(typ, "_marshmallow_field", None) - if field: - return field(**metadata) - else: - return field_for_schema( - newtype_supertype, - metadata=metadata, - default=default, - base_schema=base_schema, - typ_frame=typ_frame, + composed = outer.compose(inner) + + Then, for all values of spec: + + map(composed, spec) == map(outer, map(inner, spec)) + + """ + mapped_args = tuple( + self.get(arg, arg) if _is_type_var(arg) else arg for arg in other.args ) + return _TypeVarBindings(other.parameters, mapped_args) -def _generic_type_add_any(typ: type) -> type: - """if typ is generic type without arguments, replace them by Any.""" - if typ is list or typ is List: - typ = List[Any] - elif typ is dict or typ is Dict: - typ = Dict[Any, Any] - elif typ is Mapping: - typ = Mapping[Any, Any] - elif typ is Sequence: - typ = Sequence[Any] - elif typ is set or typ is Set: - typ = Set[Any] - elif typ is frozenset or typ is FrozenSet: - typ = FrozenSet[Any] - return typ - - -def _field_for_generic_type( - typ: type, - base_schema: Optional[Type[marshmallow.Schema]], - typ_frame: Optional[types.FrameType], - **metadata: Any, -) -> Optional[marshmallow.fields.Field]: - """ - If the type is a generic interface, resolve the arguments and construct the appropriate Field. +@dataclasses.dataclass +class _SchemaContext: + """Global context for an invocation of class_schema. + + The _SchemaContext is not thread-safe — methods on a given _SchemaContext + instance should only be invoked from a single thread. (Other threads + can safely work with their own _SchemaContext instances.) + """ - origin = typing_inspect.get_origin(typ) - arguments = typing_inspect.get_args(typ, True) - if origin: - # Override base_schema.TYPE_MAPPING to change the class used for generic types below - type_mapping = base_schema.TYPE_MAPPING if base_schema else {} - if origin in (list, List): - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) - list_type = cast( - Type[marshmallow.fields.List], - type_mapping.get(List, marshmallow.fields.List), - ) - return list_type(child_type, **metadata) - if origin in (collections.abc.Sequence, Sequence) or ( - origin in (tuple, Tuple) - and len(arguments) == 2 - and arguments[1] is Ellipsis - ): - from . import collection_field + globalns: Optional[Dict[str, Any]] = None + localns: Optional[Dict[str, Any]] = None + base_schema: Type[marshmallow.Schema] = marshmallow.Schema - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) - return collection_field.Sequence(cls_or_instance=child_type, **metadata) - if origin in (set, Set): - from . import collection_field + typevar_bindings: _TypeVarBindings = dataclasses.field( + init=False, default_factory=_TypeVarBindings + ) - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) - return collection_field.Set( - cls_or_instance=child_type, frozen=False, **metadata + @contextmanager + def bind_type_vars(self, bindings: _TypeVarBindings) -> Iterator[None]: + outer_bindings = self.typevar_bindings + try: + self.typevar_bindings = outer_bindings.compose(bindings) + yield + finally: + self.typevar_bindings = outer_bindings + + def get_type_mapping( + self, use_mro: bool = False + ) -> Mapping[TypeSpec, Type[marshmallow.fields.Field]]: + """Get base_schema.TYPE_MAPPING. + + If use_mro is true, then merges the TYPE_MAPPINGs from + all bases in base_schema's MRO. + """ + base_schema = self.base_schema + if use_mro: + return ChainMap( + *(getattr(cls, "TYPE_MAPPING", {}) for cls in base_schema.__mro__) ) - if origin in (frozenset, FrozenSet): - from . import collection_field + return getattr(base_schema, "TYPE_MAPPING", {}) + + # We use two caches: + # + # 1. A global LRU cache. This cache is solely for the sake of efficiency + # + # 2. A context-local cache. Note that a new context is created for each + # call to the public marshmallow_dataclass.class_schema function. + # This context-local cache exists in order to avoid infinite + # recursion when working on a cyclic dataclass. + # + _global_cache: ClassVar[MutableMapping[Hashable, Any]] + _global_cache = cachetools.LRUCache(MAX_CLASS_SCHEMA_CACHE_SIZE) + + def _global_cache_key(self, clazz: Hashable) -> Hashable: + return clazz, self.base_schema + + _local_cache: MutableMapping[Hashable, Any] = dataclasses.field( + init=False, default_factory=dict + ) - child_type = field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ) - return collection_field.Set( - cls_or_instance=child_type, frozen=True, **metadata - ) - if origin in (tuple, Tuple): - children = tuple( - field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame) - for arg in arguments - ) - tuple_type = cast( - Type[marshmallow.fields.Tuple], - type_mapping.get( # type:ignore[call-overload] - Tuple, marshmallow.fields.Tuple - ), + def _get_local_cache(self) -> MutableMapping[Hashable, Any]: + return self._local_cache + + @cachetools.cached( + cache=_global_cache, key=_global_cache_key, lock=threading.Lock() + ) + @cachetools.cachedmethod(cache=_get_local_cache) + def class_schema(self, clazz: Hashable) -> _Future[Type[marshmallow.Schema]]: + # insert future result into cache to prevent recursion + future: _Future[Type[marshmallow.Schema]] + future = self._local_cache.setdefault((clazz,), _Future()) + + constructor: Callable[..., object] + + if self.is_simple_annotated_class(clazz): + class_name = clazz.__name__ + constructor = _simple_class_constructor(clazz) + attributes = self.schema_attrs_for_simple_class(clazz) + elif _is_generic_alias_of_dataclass(clazz): + origin = get_origin(clazz) + assert _is_dataclass_type(origin) + class_name = origin.__name__ + constructor = origin + with self.bind_type_vars(_TypeVarBindings.from_generic_alias(clazz)): + attributes = self.schema_attrs_for_dataclass(origin) + elif _is_dataclass_type(clazz): + class_name = clazz.__name__ + constructor = clazz + attributes = self.schema_attrs_for_dataclass(clazz) + else: + raise InvalidClassError( + f"{clazz} is not a dataclass or a simple annotated class" ) - return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): - dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) - return dict_type( - keys=field_for_schema( - arguments[0], base_schema=base_schema, typ_frame=typ_frame - ), - values=field_for_schema( - arguments[1], base_schema=base_schema, typ_frame=typ_frame - ), - **metadata, + + load_to_dict = self.base_schema.load + + def load( + self: marshmallow.Schema, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + *, + many: Optional[bool] = None, + unknown: Optional[str] = None, + **kwargs: Any, + ) -> Any: + many = self.many if many is None else bool(many) + loaded = load_to_dict(self, data, many=many, unknown=unknown, **kwargs) + if many: + return [constructor(**item) for item in loaded] + else: + return constructor(**loaded) + + attributes["load"] = load + + schema_class: Type[marshmallow.Schema] = type( + f"{class_name}Schema", (self.base_schema,), attributes + ) + + future.set_result(schema_class) + return future + + def schema_attrs_for_dataclass(self, clazz: _DataclassType) -> Dict[str, Any]: + if _has_generic_base(clazz): + raise InvalidClassError( + "class_schema does not support dataclasses with generic base classes" ) - if typing_inspect.is_union_type(typ): - if typing_inspect.is_optional_type(typ): - metadata["allow_none"] = metadata.get("allow_none", True) - metadata["dump_default"] = metadata.get("dump_default", None) + + type_hints = get_type_hints(clazz, globalns=self.globalns, localns=self.localns) + attrs = dict(_marshmallow_hooks(clazz)) + for field in dataclasses.fields(clazz): + if field.init: + typ = type_hints[field.name] + default = ( + field.default_factory + if field.default_factory is not dataclasses.MISSING + else field.default + if field.default is not dataclasses.MISSING + else marshmallow.missing + ) + attrs[field.name] = self.field_for_schema(typ, default, field.metadata) + return attrs + + _SimpleClass = typing_NewType("_SimpleClass", type) + + def is_simple_annotated_class(self, obj: object) -> TypeGuard[_SimpleClass]: + """Determine whether obj is a "simple annotated class". + + The ```class_schema``` function can generate schemas for + simple annotated classes (as well as for dataclasses). + """ + if not isinstance(obj, type): + return False + if getattr(obj, "__init__", None) is not object.__init__: + return False + if getattr(obj, "__new__", None) is not object.__new__: + return False + + type_hints = get_type_hints(obj, globalns=self.globalns, localns=self.localns) + return any(not typing_inspect.is_classvar(th) for th in type_hints.values()) + + def schema_attrs_for_simple_class(self, clazz: _SimpleClass) -> Dict[str, Any]: + type_hints = get_type_hints(clazz, globalns=self.globalns, localns=self.localns) + + attrs = dict(_marshmallow_hooks(clazz)) + for field_name, typ in type_hints.items(): + if not typing_inspect.is_classvar(typ): + default = getattr(clazz, field_name, marshmallow.missing) + attrs[field_name] = self.field_for_schema(typ, default) + return attrs + + def field_for_schema( + self, + typ: object, + default: Any = marshmallow.missing, + metadata: Optional[Mapping[str, Any]] = None, + ) -> marshmallow.fields.Field: + """ + Get a marshmallow Field corresponding to the given python type. + The metadata of the dataclass field is used as arguments to the marshmallow Field. + + This is an internal version of field_for_schema. + + :param typ: The type for which a field should be generated + :param default: value to use for (de)serialization when the field is missing + :param metadata: Additional parameters to pass to the marshmallow field constructor + + """ + + if _is_type_var(typ): + type_spec = self.typevar_bindings.get(typ, typ) + if _is_type_var(type_spec): + raise UnboundTypeVarError( + f"can not resolve type variable {type_spec.__name__}" + ) + return self.field_for_schema(type_spec, default, metadata) + + metadata = {} if metadata is None else dict(metadata) + + # If the field was already defined by the user + predefined_field = metadata.get("marshmallow_field") + if predefined_field: + if not isinstance(predefined_field, marshmallow.fields.Field): + raise TypeError( + "metadata['marshmallow_field'] must be set to a Field instance, " + f"not {predefined_field}" + ) + return predefined_field + + if default is not marshmallow.missing: + metadata.setdefault("dump_default", default) + # 'missing' must not be set for required fields. if not metadata.get("required"): - metadata["load_default"] = metadata.get("load_default", None) - metadata.setdefault("required", False) - subtypes = [t for t in arguments if t is not NoneType] # type: ignore - if len(subtypes) == 1: - return field_for_schema( - subtypes[0], + metadata.setdefault("load_default", default) + else: + metadata.setdefault("required", not typing_inspect.is_optional_type(typ)) + + if _is_builtin_collection_type(typ): + return self.field_for_builtin_collection_type(typ, metadata) + + # Base types + type_mapping = self.get_type_mapping(use_mro=True) + field = type_mapping.get(typ) + if field is not None: + return field(**metadata) + + if typ is Any: + metadata.setdefault("allow_none", True) + return marshmallow.fields.Raw(**metadata) + + if typing_inspect.is_literal_type(typ): + return self.field_for_literal_type(typ, metadata) + + if typing_inspect.is_final_type(typ): + return self.field_for_schema( + _get_subtype_for_final_type(typ, default), + default=default, metadata=metadata, - base_schema=base_schema, - typ_frame=typ_frame, ) - from . import union_field - return union_field.Union( - [ - ( - subtyp, - field_for_schema( - subtyp, - metadata={"required": True}, - base_schema=base_schema, - typ_frame=typ_frame, - ), - ) - for subtyp in subtypes - ], - **metadata, - ) - return None + if typing_inspect.is_union_type(typ): + return self.field_for_union_type(typ, metadata) + if _is_new_type(typ): + return self.field_for_new_type(typ, default, metadata) -def field_for_schema( - typ: type, - default=marshmallow.missing, - metadata: Optional[Mapping[str, Any]] = None, - base_schema: Optional[Type[marshmallow.Schema]] = None, - typ_frame: Optional[types.FrameType] = None, -) -> marshmallow.fields.Field: - """ - Get a marshmallow Field corresponding to the given python type. - The metadata of the dataclass field is used as arguments to the marshmallow Field. + # enumerations + if isinstance(typ, type) and issubclass(typ, Enum): + return self.field_for_enum(typ, metadata) - :param typ: The type for which a field should be generated - :param default: value to use for (de)serialization when the field is missing - :param metadata: Additional parameters to pass to the marshmallow field constructor - :param base_schema: marshmallow schema used as a base class when deriving dataclass schema - :param typ_frame: frame of type definition + # nested dataclasses + if ( + _is_dataclass_type(typ) + or _is_generic_alias_of_dataclass(typ) + or self.is_simple_annotated_class(typ) + ): + nested = self.schema_for_nested(typ) + # type spec for Nested.__init__ is not correct + return marshmallow.fields.Nested(nested, **metadata) # type: ignore[arg-type] + + raise UnrecognizedFieldTypeError(f"can not deduce field type for {typ}") + + def field_for_builtin_collection_type( + self, typ: object, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Handle builtin container types like list, tuple, set, etc. + """ + origin = get_origin(typ) + if origin is None: + origin = typ + assert len(get_args(typ)) == 0 + + args = get_args(typ) + + if origin is tuple and ( + len(args) == 0 or (len(args) == 2 and args[1] is Ellipsis) + ): + # Special case: homogeneous tuple — treat as Sequence + origin = collections.abc.Sequence + args = args[:1] - >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True)) - >>> int_field.__class__ - + # Override base_schema.TYPE_MAPPING to change the class used for generic types below + def get_field_type(type_spec: TypeSpec, default: Type[_Field]) -> Type[_Field]: + type_mapping = self.get_type_mapping(use_mro=False) + return type_mapping.get(type_spec, default) # type: ignore[return-value] - >>> int_field.dump_default - 9 + def get_field(i: int) -> marshmallow.fields.Field: + return self.field_for_schema(args[i] if args else Any) - >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ - - """ + if origin is tuple: + tuple_fields = tuple(self.field_for_schema(arg) for arg in args) + tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple) + return tuple_type(tuple_fields, **metadata) - metadata = {} if metadata is None else dict(metadata) + if origin in (dict, collections.abc.Mapping): + dict_type = get_field_type(Dict, default=marshmallow.fields.Dict) + return dict_type(keys=get_field(0), values=get_field(1), **metadata) - if default is not marshmallow.missing: - metadata.setdefault("dump_default", default) - # 'missing' must not be set for required fields. - if not metadata.get("required"): - metadata.setdefault("load_default", default) - else: - metadata.setdefault("required", not typing_inspect.is_optional_type(typ)) + if origin is list: + list_type = get_field_type(List, default=marshmallow.fields.List) + return list_type(get_field(0), **metadata) - # If the field was already defined by the user - predefined_field = metadata.get("marshmallow_field") - if predefined_field: - return predefined_field + if origin is collections.abc.Sequence: + from . import collection_field - # Generic types specified without type arguments - typ = _generic_type_add_any(typ) + return collection_field.Sequence(get_field(0), **metadata) - # Base types - field = _field_by_type(typ, base_schema) - if field: - return field(**metadata) + if origin in (set, frozenset): + from . import collection_field - if typ is Any: - metadata.setdefault("allow_none", True) - return marshmallow.fields.Raw(**metadata) + frozen = origin is frozenset + return collection_field.Set(get_field(0), frozen=frozen, **metadata) - if typing_inspect.is_literal_type(typ): - arguments = typing_inspect.get_args(typ) - return marshmallow.fields.Raw( - validate=( - marshmallow.validate.Equal(arguments[0]) - if len(arguments) == 1 - else marshmallow.validate.OneOf(arguments) - ), + raise ValueError(f"{typ} is not a builtin collection type") + + def field_for_union_type( + self, typ: object, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Construct the appropriate Field for a union or optional type. + """ + assert typing_inspect.is_union_type(typ) + subtypes = [t for t in get_args(typ) if t is not NoneType] + + if typing_inspect.is_optional_type(typ): + metadata = { + "allow_none": True, + "dump_default": None, + **metadata, + } + if not metadata.setdefault("required", False): + metadata.setdefault("load_default", None) + + if len(subtypes) == 1: + return self.field_for_schema(subtypes[0], metadata=metadata) + + from . import union_field + + return union_field.Union( + [ + (typ, self.field_for_schema(typ, metadata={"required": True})) + for typ in subtypes + ], **metadata, ) - if typing_inspect.is_final_type(typ): + @staticmethod + def field_for_literal_type( + typ: object, metadata: Dict[str, Any] + ) -> marshmallow.fields.Field: + """ + Construct the appropriate Field for a Literal type. + """ + validate: marshmallow.validate.Validator + + assert typing_inspect.is_literal_type(typ) arguments = typing_inspect.get_args(typ) - if arguments: - subtyp = arguments[0] - elif default is not marshmallow.missing: - if callable(default): - subtyp = Any - warnings.warn( - "****** WARNING ****** " - "marshmallow_dataclass was called on a dataclass with an " - 'attribute that is type-annotated with "Final" and uses ' - "dataclasses.field for specifying a default value using a " - "factory. The Marshmallow field type cannot be inferred from the " - "factory and will fall back to a raw field which is equivalent to " - 'the type annotation "Any" and will result in no validation. ' - "Provide a type to Final[...] to ensure accurate validation. " - "****** WARNING ******" - ) - else: - subtyp = type(default) - warnings.warn( - "****** WARNING ****** " - "marshmallow_dataclass was called on a dataclass with an " - 'attribute that is type-annotated with "Final" with a default ' - "value from which the Marshmallow field type is inferred. " - "Support for type inference from a default value is limited and " - "may result in inaccurate validation. Provide a type to " - "Final[...] to ensure accurate validation. " - "****** WARNING ******" - ) + if len(arguments) == 1: + validate = marshmallow.validate.Equal(arguments[0]) else: - subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema, typ_frame) - - # Generic types - generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata) - if generic_field: - return generic_field - - # typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a - # __supertype__ attribute - newtype_supertype = getattr(typ, "__supertype__", None) - if typing_inspect.is_new_type(typ) and newtype_supertype is not None: - return _field_by_supertype( - typ=typ, + validate = marshmallow.validate.OneOf(arguments) + return marshmallow.fields.Raw(validate=validate, **metadata) + + def field_for_new_type( + self, new_type: _NewType, default: Any, metadata: Mapping[str, Any] + ) -> marshmallow.fields.Field: + """ + Return a new field for fields based on a NewType. + """ + # Add the information coming our custom NewType implementation + + # Handle multiple validators from both `typ` and `metadata`. + # See https://github.com/lovasoa/marshmallow_dataclass/issues/91 + merged_metadata = _merge_metadata( + getattr(new_type, "_marshmallow_args", {}), + metadata, + ) + merged_metadata.setdefault("metadata", {}).setdefault( + "description", new_type.__name__ + ) + + field: Optional[Type[marshmallow.fields.Field]] = getattr( + new_type, "_marshmallow_field", None + ) + if field is not None: + return field(**merged_metadata) + + return self.field_for_schema( + new_type.__supertype__, default=default, - newtype_supertype=newtype_supertype, - metadata=metadata, - base_schema=base_schema, - typ_frame=typ_frame, + metadata=merged_metadata, ) - # enumerations - if issubclass(typ, Enum): - try: + @staticmethod + def field_for_enum(typ: type, metadata: Dict[str, Any]) -> marshmallow.fields.Field: + """ + Return a new field for an Enum field. + """ + if sys.version_info >= (3, 7): return marshmallow.fields.Enum(typ, **metadata) - except AttributeError: + else: # Remove this once support for python 3.6 is dropped. import marshmallow_enum return marshmallow_enum.EnumField(typ, **metadata) - # Nested marshmallow dataclass - # it would be just a class name instead of actual schema util the schema is not ready yet - nested_schema = getattr(typ, "Schema", None) + def schema_for_nested( + self, typ: object + ) -> Union[Type[marshmallow.Schema], Callable[[], Type[marshmallow.Schema]]]: + """ + Return a marshmallow.Schema for a nested dataclass (or simple annotated class) + """ + if isinstance(typ, type) and hasattr(typ, "Schema"): + # marshmallow_dataclass.dataclass + # Defer evaluation of .Schema attribute, to avoid forward reference issues + return partial(getattr, typ, "Schema") - # Nested dataclasses - forward_reference = getattr(typ, "__forward_arg__", None) + future = self.class_schema(typ) + deferred = future.result + return deferred() if future.done() else deferred - nested = ( - nested_schema - or forward_reference - or _RECURSION_GUARD.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, typ_frame) # type: ignore [arg-type] - ) - return marshmallow.fields.Nested(nested, **metadata) +def _merge_metadata(*args: Mapping[str, Any]) -> Dict[str, Any]: + """Merge mutiple metadata mappings into a single dict. + This is a standard dict merge, except that the "validate" field + is handled specially: validators specified in any of the args + are combined. -def _base_schema( - clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None -) -> Type[marshmallow.Schema]: - """ - Base schema factory that creates a schema for `clazz` derived either from `base_schema` - or `BaseSchema` """ + merged: Dict[str, Any] = {} + validators: List[Callable[[Any], Any]] = [] + + for metadata in args: + merged.update(metadata) + validate = metadata.get("validate") + if callable(validate): + validators.append(validate) + elif marshmallow.utils.is_iterable_but_not_string(validate): + assert isinstance(validate, Iterable) + validators.extend(validate) + elif validate is not None: + validators.append(validate) + + if not all(callable(validate) for validate in validators): + raise ValueError( + "the 'validate' parameter must be a callable or a collection of callables." + ) + + merged["validate"] = validators if validators else None + return merged - # Remove `type: ignore` when mypy handles dynamic base classes - # https://github.com/python/mypy/issues/2813 - class BaseSchema(base_schema or marshmallow.Schema): # type: ignore - def load(self, data: Mapping, *, many: Optional[bool] = None, **kwargs): - all_loaded = super().load(data, many=many, **kwargs) - many = self.many if many is None else bool(many) - if many: - return [clazz(**loaded) for loaded in all_loaded] - else: - return clazz(**all_loaded) - return BaseSchema +def _marshmallow_hooks(clazz: type) -> Iterator[Tuple[str, Any]]: + for name, attr in inspect.getmembers(clazz): + if hasattr(attr, "__marshmallow_hook__") or name in MEMBERS_WHITELIST: + yield name, attr -def _get_field_default(field: dataclasses.Field): +def _simple_class_constructor(clazz: Type[_U]) -> Callable[..., _U]: + def constructor(**kwargs: Any) -> _U: + obj = clazz.__new__(clazz) + for k, v in kwargs.items(): + setattr(obj, k, v) + return obj + + return constructor + + +def _is_builtin_collection_type(typ: object) -> bool: + origin = get_origin(typ) + if origin is None: + origin = typ + + return origin in { + list, + collections.abc.Sequence, + set, + frozenset, + tuple, + dict, + collections.abc.Mapping, + } + + +def _get_subtype_for_final_type(typ: object, default: Any) -> object: """ - Return a marshmallow default value given a dataclass default value + Construct the appropriate Field for a Final type. + """ + assert typing_inspect.is_final_type(typ) + arguments = typing_inspect.get_args(typ) + if arguments: + return arguments[0] + elif default is marshmallow.missing: + return Any + elif callable(default): + warnings.warn( + "****** WARNING ****** " + "marshmallow_dataclass was called on a dataclass with an " + 'attribute that is type-annotated with "Final" and uses ' + "dataclasses.field for specifying a default value using a " + "factory. The Marshmallow field type cannot be inferred from the " + "factory and will fall back to a raw field which is equivalent to " + 'the type annotation "Any" and will result in no validation. ' + "Provide a type to Final[...] to ensure accurate validation. " + "****** WARNING ******" + ) + return Any + warnings.warn( + "****** WARNING ****** " + "marshmallow_dataclass was called on a dataclass with an " + 'attribute that is type-annotated with "Final" with a default ' + "value from which the Marshmallow field type is inferred. " + "Support for type inference from a default value is limited and " + "may result in inaccurate validation. Provide a type to " + "Final[...] to ensure accurate validation. " + "****** WARNING ******" + ) + return type(default) + - >>> _get_field_default(dataclasses.field()) - +def field_for_schema( + typ: object, + default: Any = marshmallow.missing, + metadata: Optional[Mapping[str, Any]] = None, + base_schema: Optional[Type[marshmallow.Schema]] = None, + # FIXME: delete typ_frame from API? + typ_frame: Optional[types.FrameType] = None, +) -> marshmallow.fields.Field: + """ + Get a marshmallow Field corresponding to the given python type. + The metadata of the dataclass field is used as arguments to the marshmallow Field. + + :param typ: The type for which a field should be generated + :param default: value to use for (de)serialization when the field is missing + :param metadata: Additional parameters to pass to the marshmallow field constructor + :param base_schema: marshmallow schema used as a base class when deriving dataclass schema + :param typ_frame: frame of type definition + + >>> int_field = field_for_schema(int, default=9, metadata=dict(required=True)) + >>> int_field.__class__ + + + >>> int_field.dump_default + 9 + + >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__ + """ - # Remove `type: ignore` when https://github.com/python/mypy/issues/6910 is fixed - default_factory = field.default_factory # type: ignore - if default_factory is not dataclasses.MISSING: - return default_factory - elif field.default is dataclasses.MISSING: - return marshmallow.missing - return field.default + if base_schema is None: + base_schema = marshmallow.Schema + localns = typ_frame.f_locals if typ_frame is not None else None + schema_ctx = _SchemaContext(localns=localns, base_schema=base_schema) + return schema_ctx.field_for_schema(typ, default, metadata) def NewType( name: str, typ: Type[_U], field: Optional[Type[marshmallow.fields.Field]] = None, - **kwargs, -) -> Callable[[_U], _U]: + **kwargs: Any, +) -> type: """NewType creates simple unique types to which you can attach custom marshmallow attributes. All the keyword arguments passed to this function will be transmitted @@ -817,9 +1290,9 @@ def NewType( # noinspection PyTypeHints new_type = typing_NewType(name, typ) # type: ignore # noinspection PyTypeHints - new_type._marshmallow_field = field # type: ignore + new_type._marshmallow_field = field # noinspection PyTypeHints - new_type._marshmallow_args = kwargs # type: ignore + new_type._marshmallow_args = kwargs return new_type diff --git a/marshmallow_dataclass/collection_field.py b/marshmallow_dataclass/collection_field.py index 6823b72..0f9d4c2 100644 --- a/marshmallow_dataclass/collection_field.py +++ b/marshmallow_dataclass/collection_field.py @@ -29,16 +29,18 @@ class Set(marshmallow.fields.List): will be random. So if the order matters, use a List or Sequence ! """ + set_type: typing.Type[ + typing.Union[typing.FrozenSet[typing.Any], typing.Set[typing.Any]] + ] + def __init__( self, cls_or_instance: typing.Union[marshmallow.fields.Field, type], frozen: bool = False, - **kwargs, + **kwargs: typing.Any, ): super().__init__(cls_or_instance, **kwargs) - self.set_type: typing.Type[typing.Union[frozenset, set]] = ( - frozenset if frozen else set - ) + self.set_type = frozenset if frozen else set def _deserialize( # type: ignore[override] self, diff --git a/marshmallow_dataclass/lazy_class_attribute.py b/marshmallow_dataclass/lazy_class_attribute.py index 2dbe4a4..0555d7d 100644 --- a/marshmallow_dataclass/lazy_class_attribute.py +++ b/marshmallow_dataclass/lazy_class_attribute.py @@ -1,45 +1,39 @@ -from typing import Any, Callable, Optional +import threading +from typing import Callable, Generic, Optional, TypeVar __all__ = ("lazy_class_attribute",) -class LazyClassAttribute: - """Descriptor decorator implementing a class-level, read-only - property, which caches its results on the class(es) on which it - operates. - """ +_T_co = TypeVar("_T_co", covariant=True) - __slots__ = ("func", "name", "called", "forward_value") - def __init__( - self, - func: Callable[..., Any], - name: Optional[str] = None, - forward_value: Any = None, - ): - self.func = func - self.name = name - self.called = False - self.forward_value = forward_value +class LazyClassAttribute(Generic[_T_co]): + """Descriptor implementing a cached class property.""" - def __get__(self, instance, cls=None): - if not cls: - cls = type(instance) - - # avoid recursion - if self.called: - return self.forward_value + __slots__ = ("fget", "attr_name", "rlock", "called_from") - self.called = True + def __init__(self, fget: Callable[[], _T_co], attr_name: str): + self.fget = fget + self.attr_name = attr_name + self.rlock = threading.RLock() + self.called_from: Optional[threading.Thread] = None - setattr(cls, self.name, self.func()) - - # "getattr" is used to handle bounded methods - return getattr(cls, self.name) + def __get__(self, instance: object, cls: Optional[type] = None) -> _T_co: + if not cls: + cls = type(instance) - def __set_name__(self, owner, name): - self.name = self.name or name + with self.rlock: + if self.called_from is not None: + if self.called_from is not threading.current_thread(): + return getattr(cls, self.attr_name) # type: ignore[no-any-return] + raise AttributeError( + f"recursive evaluation of {cls.__name__}.{self.attr_name}" + ) + self.called_from = threading.current_thread() + value = self.fget() + setattr(cls, self.attr_name, value) + return value lazy_class_attribute = LazyClassAttribute diff --git a/marshmallow_dataclass/mypy.py b/marshmallow_dataclass/mypy.py index d33a5ad..a45cb19 100644 --- a/marshmallow_dataclass/mypy.py +++ b/marshmallow_dataclass/mypy.py @@ -2,8 +2,10 @@ from typing import Callable, Optional, Type from mypy import nodes -from mypy.plugin import DynamicClassDefContext, Plugin +from mypy.plugin import ClassDefContext, DynamicClassDefContext, Plugin from mypy.plugins import dataclasses +from mypy.plugins.common import add_attribute_to_class +from mypy.types import AnyType, TypeOfAny, TypeType import marshmallow_dataclass @@ -22,11 +24,30 @@ def get_dynamic_class_hook( return new_type_hook return None - def get_class_decorator_hook(self, fullname: str): + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: if fullname == "marshmallow_dataclass.dataclass": - return dataclasses.dataclass_class_maker_callback + return dataclasses.dataclass_tag_callback return None + def get_class_decorator_hook_2( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: + if fullname == "marshmallow_dataclass.dataclass": + return class_decorator_hook + return None + + +def class_decorator_hook(ctx: ClassDefContext) -> bool: + if not dataclasses.dataclass_class_maker_callback(ctx): + return False + any_type = AnyType(TypeOfAny.explicit) + schema_type = ctx.api.named_type_or_none("marshmallow.Schema") or any_type + schema_type_type = TypeType.make_normalized(schema_type) + add_attribute_to_class(ctx.api, ctx.cls, "Schema", schema_type_type) + return True + def new_type_hook(ctx: DynamicClassDefContext) -> None: """ @@ -66,6 +87,6 @@ def _get_arg_by_name( except TypeError: return None try: - return bound_args.arguments[name] + return bound_args.arguments[name] # type: ignore[no-any-return] except KeyError: return None diff --git a/marshmallow_dataclass/union_field.py b/marshmallow_dataclass/union_field.py index 6e87e29..c7875ec 100644 --- a/marshmallow_dataclass/union_field.py +++ b/marshmallow_dataclass/union_field.py @@ -1,5 +1,5 @@ import copy -from typing import List, Tuple, Any, Optional +from typing import Any, List, Mapping, Optional, Tuple import typeguard from marshmallow import fields, Schema, ValidationError @@ -26,21 +26,23 @@ class Union(fields.Field): :param kwargs: The same keyword arguments that :class:`Field` receives. """ - def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs): + def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs: Any): super().__init__(**kwargs) self.union_fields = union_fields def _bind_to_schema(self, field_name: str, schema: Schema) -> None: - super()._bind_to_schema(field_name, schema) + super()._bind_to_schema(field_name, schema) # type: ignore[no-untyped-call] new_union_fields = [] for typ, field in self.union_fields: field = copy.deepcopy(field) - field._bind_to_schema(field_name, self) + field._bind_to_schema(field_name, self) # type: ignore[no-untyped-call] new_union_fields.append((typ, field)) self.union_fields = new_union_fields - def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any: + def _serialize( + self, value: Any, attr: Optional[str], obj: Any, **kwargs: Any + ) -> Any: errors = [] if value is None: return value @@ -56,7 +58,13 @@ def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any: f"Unable to serialize value with any of the fields in the union: {errors}" ) - def _deserialize(self, value: Any, attr: Optional[str], data, **kwargs) -> Any: + def _deserialize( + self, + value: Any, + attr: Optional[str], + data: Optional[Mapping[str, Any]], + **kwargs: Any, + ) -> Any: errors = [] for typ, field in self.union_fields: try: diff --git a/mypy_plugin.py b/mypy_plugin.py new file mode 100644 index 0000000..f76a199 --- /dev/null +++ b/mypy_plugin.py @@ -0,0 +1,64 @@ +"""Shim to load the marshmallow_dataclass.mypy plugin. + +This shim is needed when running mypy from pre-commit. + +Pre-commit runs mypy from its own venv (into which we do not want +to install marshmallow_dataclass). Because of this, loading the plugin +by module name, e.g. + + [tool.mypy] + plugins = "marshmallow_dataclass.mypy" + +does not work. Mypy also supports specifying a path to the plugin +module source, which would normally get us out of this bind, however, +the fact that our plugin is in a file named "mypy.py" causes issues. + +If we set + + [tool.mypy] + plugins = "marshmallow_dataclass/mypy.py" + +mypy `attempts to load`__ the plugin module by temporarily prepending + ``marshmallow_dataclass`` to ``sys.path`` then importing the ``mypy`` +module. Sadly, mypy's ``mypy`` module has already been imported, +so this doesn't end well. + +__ https://github.com/python/mypy/blob/914901f14e0e6223077a8433388c367138717451/mypy/build.py#L450 + + +Our solution, here, is to manually load the plugin module (with a better +``sys.path``, and import the ``plugin`` from the real plugin module into this one. + +Now we can configure mypy to load this file, by path. + + [tool.mypy] + plugins = "mypy_plugin.py" + +""" +import importlib +import sys +from os import fspath +from pathlib import Path +from typing import Type +from warnings import warn + +from mypy.plugin import Plugin + + +def null_plugin(version: str) -> Type[Plugin]: + """A fallback do-nothing plugin hook""" + return Plugin + + +module_name = "marshmallow_dataclass.mypy" + +src = fspath(Path(__file__).parent) +sys.path.insert(0, src) +try: + plugin_module = importlib.import_module(module_name) + plugin = plugin_module.plugin +except Exception as exc: + warn(f"can not load {module_name} plugin: {exc}") + plugin = null_plugin +finally: + del sys.path[0] diff --git a/pyproject.toml b/pyproject.toml index 4112bd1..9364c51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,3 +6,46 @@ target-version = ['py36', 'py37', 'py38', 'py39', 'py310', 'py310'] filterwarnings = [ "error:::marshmallow_dataclass|test", ] + +[tool.coverage.report] +omit = [ + # pytest-mypy-plugins run mypy plugin tests get run in a subprocess, + # so we don't get coverage data + "marshmallow_dataclass/mypy.py", +] +exclude_lines = [ + "pragma: no cover", + '^\s*\.\.\.\s*$', +] + +[tool.mypy] +packages = [ + "marshmallow_dataclass", + "tests", +] +# XXX: Specifying the marshmallow_dataclass.mypy plugin directly by +# module name or by path does not work when running mypy from pre-commit. +# (See the docstring in mypy_plugin.py for more.) +plugins = "mypy_plugin.py" + +strict = true +warn_unreachable = true + +[[tool.mypy.overrides]] +# dependencies without type hints +module = [ + "marshmallow_enum", + "typing_inspect", +] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = [ + "tests.*", +] +disable_error_code = "annotation-unchecked" +check_untyped_defs = false +disallow_untyped_calls = false +disallow_untyped_defs = false +disallow_incomplete_defs = false + diff --git a/setup.cfg b/setup.cfg index 1216d89..d5a376b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,3 @@ ignore = E203, E266, E501, W503 max-line-length = 100 max-complexity = 18 select = B,C,E,F,W,T4,B9 - -[mypy] -ignore_missing_imports = true diff --git a/setup.py b/setup.py index 325b350..5f7a1bd 100644 --- a/setup.py +++ b/setup.py @@ -28,13 +28,10 @@ "docs": ["sphinx"], "tests": [ "pytest>=5.4", + "types-cachetools", # re: pypy: typed-ast (a dependency of mypy) fails to install on pypy # https://github.com/python/typed_ast/issues/111 "pytest-mypy-plugins>=1.2.0; implementation_name != 'pypy'", - # `Literal` was introduced in: - # - Python 3.8 (https://www.python.org/dev/peps/pep-0586) - # - typing-extensions 3.7.2 (https://github.com/python/typing/pull/591) - "typing-extensions>=3.7.2; python_version < '3.8'", ], } EXTRAS_REQUIRE["dev"] = ( @@ -62,8 +59,10 @@ license="MIT", python_requires=">=3.6", install_requires=[ + "cachetools>=4.2.4,<6.0", "marshmallow>=3.13.0,<4.0", "typing-inspect>=0.8.0", + "typing-extensions>=3.10; python_version < '3.8'", ], extras_require=EXTRAS_REQUIRE, package_data={"marshmallow_dataclass": ["py.typed"]}, diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index aa82975..3ac8efb 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -1,20 +1,27 @@ import inspect +import sys import typing import unittest -from typing import Any, cast, TYPE_CHECKING +from typing import Any, cast from uuid import UUID -try: - from typing import Final, Literal # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Final, Literal # type: ignore[assignment] +if sys.version_info >= (3, 8): + from typing import Final, Literal +else: + from typing_extensions import Final, Literal import dataclasses from marshmallow import Schema, ValidationError from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import ( + add_schema, + class_schema, + NewType, + UnboundTypeVarError, + _is_generic_alias_of_dataclass, +) class TestClassSchema(unittest.TestCase): @@ -226,21 +233,14 @@ class A: with self.assertRaises(ValidationError): schema.load({"data": data}) - def test_final_infers_type_from_default(self): - # @dataclasses.dataclass(frozen=True) + def test_final_infers_type_from_default(self) -> None: + @dataclasses.dataclass(frozen=True) class A: data: Final = "a" - # @dataclasses.dataclass + @dataclasses.dataclass class B: - data: Final = A() - - # NOTE: This workaround is needed to avoid a Mypy crash. - # See: https://github.com/python/mypy/issues/10090#issuecomment-865971891 - if not TYPE_CHECKING: - frozen_dataclass = dataclasses.dataclass(frozen=True) - A = frozen_dataclass(A) - B = dataclasses.dataclass(B) + data: Final = A() # type: ignore[misc] with self.assertWarns(Warning): schema_a = class_schema(A)() @@ -269,14 +269,9 @@ class B: schema_b.load({"data": data}) def test_final_infers_type_any_from_field_default_factory(self): - # @dataclasses.dataclass + @dataclasses.dataclass class A: - data: Final = dataclasses.field(default_factory=lambda: []) - - # NOTE: This workaround is needed to avoid a Mypy crash. - # See: https://github.com/python/mypy/issues/10090#issuecomment-866686096 - if not TYPE_CHECKING: - A = dataclasses.dataclass(A) + data: Final = dataclasses.field(default_factory=lambda: []) # type: ignore[misc] with self.assertWarns(Warning): schema = class_schema(A)() @@ -401,6 +396,173 @@ class J: [validator_a, validator_b, validator_c, validator_d], ) + def test_simple_annotated_class(self): + class Child: + x: int + + @dataclasses.dataclass + class Container: + child: Child + + schema = class_schema(Container)() + + loaded = schema.load({"child": {"x": "42"}}) + self.assertEqual(loaded.child.x, 42) + + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + y: BB[AA] + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + + def test_marshmallow_dataclass_decorator_raises_on_generics(self): + import marshmallow_dataclass + + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass) + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass[int]) + + def test_add_schema_raises_on_generics(self): + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass) + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass[int]) + + def test_deep_generic(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: typing.List[typing.Tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + + def test_generic_bases(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic base class"): + class_schema(TestClass[int]) + + def test_bound_generic_base(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[int]): + pass + + with self.assertRaisesRegex(TypeError, "generic base class"): + class_schema(TestClass) + + def test_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + def test_recursive_reference(self): @dataclasses.dataclass class Tree: diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 0e60f0b..e4bea21 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -3,12 +3,12 @@ import typing import unittest from enum import Enum -from typing import Dict, Optional, Union, Any, List, Tuple +from typing import Dict, Optional, Union, Any, List, Tuple, Iterable -try: - from typing import Final, Literal # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Final, Literal # type: ignore[assignment] +if sys.version_info >= (3, 8): + from typing import Final, Literal +else: + from typing_extensions import Final, Literal from marshmallow import fields, Schema, validate @@ -21,14 +21,18 @@ class TestFieldForSchema(unittest.TestCase): - def assertFieldsEqual(self, a: fields.Field, b: fields.Field): + def assertFieldsEqual( + self, a: fields.Field, b: fields.Field, *, ignore_attrs: Iterable[str] = () + ) -> None: + ignored = set(ignore_attrs) + self.assertEqual(a.__class__, b.__class__, "field class") def attrs(x): return { k: f"{v!r} ({v.__mro__!r})" if inspect.isclass(v) else repr(v) for k, v in x.__dict__.items() - if not k.startswith("_") + if not (k in ignored or k.startswith("_")) } self.assertEqual(attrs(a), attrs(b)) @@ -213,10 +217,12 @@ class NewSchema(Schema): class NewDataclass: pass + field = field_for_schema(NewDataclass, metadata=dict(required=False)) + self.assertFieldsEqual( - field_for_schema(NewDataclass, metadata=dict(required=False)), - fields.Nested(NewDataclass.Schema), + field, fields.Nested(NewDataclass.Schema), ignore_attrs=["nested"] ) + self.assertIs(type(field.schema), NewDataclass.Schema) def test_override_container_type_with_type_mapping(self): type_mapping = [ diff --git a/tests/test_forward_references.py b/tests/test_forward_references.py index fc05b12..2a2fa96 100644 --- a/tests/test_forward_references.py +++ b/tests/test_forward_references.py @@ -133,3 +133,19 @@ class B: B.Schema().load(dict(a=dict(c=1))) # marshmallow.exceptions.ValidationError: # {'a': {'d': ['Missing data for required field.'], 'c': ['Unknown field.']}} + + def test_locals_from_decoration_ns(self): + # Test that locals are picked-up at decoration-time rather + # than when the decorator is constructed. + @frozen_dataclass + class A: + b: "B" + + @frozen_dataclass + class B: + x: int + + assert A.Schema().load({"b": {"x": 42}}) == A(b=B(x=42)) + + +frozen_dataclass = dataclass(frozen=True) diff --git a/tests/test_lazy_class_attribute.py b/tests/test_lazy_class_attribute.py new file mode 100644 index 0000000..c6dc1fe --- /dev/null +++ b/tests/test_lazy_class_attribute.py @@ -0,0 +1,59 @@ +import threading +import time +from itertools import count + +import pytest + +from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute + + +def test_caching() -> None: + counter = count() + + def fget() -> str: + return f"value-{next(counter)}" + + class A: + x = lazy_class_attribute(fget, "x") + + assert A.x == "value-0" + assert A.x == "value-0" + + +def test_recursive_evaluation() -> None: + def fget() -> str: + return A.x + + class A: + x: str = lazy_class_attribute(fget, "x") # type: ignore[assignment] + + with pytest.raises(AttributeError, match="recursive evaluation of A.x"): + A.x + + +def test_threading() -> None: + counter = count() + lock = threading.Lock() + + def fget() -> str: + time.sleep(0.05) + with lock: + return f"value-{next(counter)}" + + class A: + x = lazy_class_attribute(fget, "x") + + n_threads = 4 + barrier = threading.Barrier(n_threads) + values = set() + + def run(): + barrier.wait() + values.add(A.x) + + threads = [threading.Thread(target=run) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + assert values == {"value-0"} diff --git a/tests/test_memory_leak.py b/tests/test_memory_leak.py new file mode 100644 index 0000000..31f56e0 --- /dev/null +++ b/tests/test_memory_leak.py @@ -0,0 +1,139 @@ +import gc +import inspect +import sys +import unittest +import weakref +from dataclasses import dataclass +from unittest import mock + +import marshmallow +import marshmallow_dataclass as md + + +class Referenceable: + pass + + +class TestMemoryLeak(unittest.TestCase): + """Test for memory leaks as decribed in `#198`_. + + .. _#198: https://github.com/lovasoa/marshmallow_dataclass/issues/198 + """ + + def setUp(self): + gc.collect() + gc.disable() + self.frame_collected = False + + def tearDown(self): + gc.enable() + + def trackFrame(self): + """Create a tracked local variable in the callers frame. + + We track these locals in the WeakSet self.livingLocals. + + When the callers frame is freed, the locals will be GCed as well. + In this way we can check that the callers frame has been collected. + """ + local = Referenceable() + weakref.finalize(local, self._set_frame_collected) + try: + frame = inspect.currentframe() + frame.f_back.f_locals["local_variable"] = local + finally: + del frame + + def _set_frame_collected(self): + self.frame_collected = True + + def assertFrameCollected(self): + """Check that all locals created by makeLocal have been GCed""" + if not hasattr(sys, "getrefcount"): + # pypy does not do reference counting + gc.collect(0) + self.assertTrue(self.frame_collected) + + def test_sanity(self): + """Test that our scheme for detecting leaked frames works.""" + frames = [] + + def f(): + frames.append(inspect.currentframe()) + self.trackFrame() + + f() + + gc.collect(0) + self.assertFalse( + self.frame_collected + ) # with frame leaked, f's locals are still alive + frames.clear() + self.assertFrameCollected() + + def test_class_schema(self): + def f(): + @dataclass + class Foo: + value: int + + md.class_schema(Foo) + + self.trackFrame() + + f() + self.assertFrameCollected() + + def test_md_dataclass_lazy_schema(self): + def f(): + @md.dataclass + class Foo: + value: int + + self.trackFrame() + + f() + # NB: The "lazy" Foo.Schema attribute descriptor holds a reference to f's frame, + # which, in turn, holds a reference to class Foo, thereby creating ref cycle. + # So, a gc pass is required to clean that up. + gc.collect(0) + self.assertFrameCollected() + + def test_md_dataclass(self): + def f(): + @md.dataclass + class Foo: + value: int + + self.assertIsInstance(Foo.Schema(), marshmallow.Schema) + self.trackFrame() + + f() + self.assertFrameCollected() + + def assertDecoratorDoesNotLeakFrame(self, decorator): + def f() -> None: + class Foo: + value: int + + self.trackFrame() + with self.assertRaisesRegex(Exception, "forced exception"): + decorator(Foo) + + with mock.patch( + "marshmallow_dataclass.setattr", side_effect=Exception("forced exception") + ) as m: + f() + + assert m.mock_calls == [mock.call(mock.ANY, "Schema", mock.ANY)] + # NB: The Mock holds a reference to its arguments, one of which is the + # lazy_class_attribute which holds a reference to the caller's frame + m.reset_mock() + + self.assertFrameCollected() + + def test_exception_in_dataclass(self): + self.assertDecoratorDoesNotLeakFrame(md.dataclass) + + def test_exception_in_add_schema(self): + self.assertDecoratorDoesNotLeakFrame(md.add_schema) diff --git a/tests/test_mypy.yml b/tests/test_mypy.yml index 55e5cb2..3b99545 100644 --- a/tests/test_mypy.yml +++ b/tests/test_mypy.yml @@ -42,6 +42,22 @@ name: str user = User(id=4, name='Johny') + +- case: dataclass_Schema_attribute + mypy_config: | + follow_imports = silent + plugins = marshmallow_dataclass.mypy + env: + - PYTHONPATH=. + main: | + from marshmallow_dataclass import dataclass + + @dataclass + class Test: + child: "Test" + + reveal_type(Test.Schema) # N: Revealed type is "Type[marshmallow.schema.Schema]" + - case: public_custom_types mypy_config: | follow_imports = silent @@ -63,5 +79,5 @@ website = Website(url="http://www.example.org", email="admin@example.org") reveal_type(website.url) # N: Revealed type is "builtins.str" reveal_type(website.email) # N: Revealed type is "builtins.str" - + Website(url=42, email="user@email.com") # E: Argument "url" to "Website" has incompatible type "int"; expected "str" [arg-type] diff --git a/tests/test_typevar_bindings.py b/tests/test_typevar_bindings.py new file mode 100644 index 0000000..7516bd6 --- /dev/null +++ b/tests/test_typevar_bindings.py @@ -0,0 +1,60 @@ +""" Tests for _TypeVarBindings """ +from dataclasses import dataclass +from typing import Generic +from typing import TypeVar + +import pytest + +from marshmallow_dataclass import _is_generic_alias_of_dataclass +from marshmallow_dataclass import _TypeVarBindings + + +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") + + +def test_default_init() -> None: + bindings = _TypeVarBindings() + assert len(bindings) == 0 + assert list(bindings) == [] + + +def test_init_raises_on_mismatched_args(): + with pytest.raises(ValueError): + _TypeVarBindings((T, U), (int, str, bool)) + + +def test_from_generic_alias() -> None: + @dataclass + class Gen(Generic[T, U]): + a: T + b: U + + generic_alias = Gen[str, int] + assert _is_generic_alias_of_dataclass(generic_alias) + bindings = _TypeVarBindings.from_generic_alias(generic_alias) + assert dict(bindings) == {T: str, U: int} + + +def test_getitem(): + bindings = _TypeVarBindings((T, U), (int, str)) + assert bindings[U] is str + + with pytest.raises(KeyError): + bindings[V] + with pytest.raises(KeyError): + bindings[str] + with pytest.raises(KeyError): + bindings[0] + + +def test_compose(): + b1 = _TypeVarBindings((T, U), (int, V)) + b2 = _TypeVarBindings((V, W), (U, T)) + + assert dict(b1.compose(b2)) == {V: V, W: int} + assert dict(b2.compose(b1)) == {T: int, U: U} + assert dict(b1.compose(b1)) == {T: int, U: V} + assert dict(b2.compose(b2)) == {V: U, W: T}