|
1 | 1 | # ruff: noqa: D100, D101, D102, D103, D104, D105, D107 |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +import dataclasses |
4 | 5 | import inspect |
5 | 6 | import queue |
6 | 7 | import threading |
7 | 8 | import weakref |
8 | 9 | from asyncio import create_task, iscoroutine |
9 | 10 | from collections import defaultdict |
| 11 | +from enum import IntEnum, StrEnum |
10 | 12 | from inspect import signature |
11 | 13 | from threading import Lock |
| 14 | +from types import NoneType |
12 | 15 | from typing import Any, Callable, Coroutine, Generic, cast |
13 | 16 |
|
| 17 | +from immutable import Immutable, is_immutable |
| 18 | + |
14 | 19 | from redux.autorun import Autorun |
15 | 20 | from redux.basic_types import ( |
16 | 21 | Action, |
|
32 | 37 | InitAction, |
33 | 38 | ReducerType, |
34 | 39 | SelectorOutput, |
| 40 | + SnapshotAtom, |
35 | 41 | State, |
36 | 42 | is_complete_reducer_result, |
37 | 43 | is_state_reducer_result, |
@@ -68,6 +74,8 @@ def run(self: _SideEffectRunnerThread[Event]) -> None: |
68 | 74 |
|
69 | 75 |
|
70 | 76 | class Store(Generic[State, Action, Event]): |
| 77 | + custom_serializer = None |
| 78 | + |
71 | 79 | def __init__( |
72 | 80 | self: Store[State, Action, Event], |
73 | 81 | reducer: ReducerType[State, Action, Event], |
@@ -276,3 +284,42 @@ def decorator( |
276 | 284 | ) |
277 | 285 |
|
278 | 286 | return decorator |
| 287 | + |
| 288 | + def set_custom_serializer( |
| 289 | + self: Store, |
| 290 | + serializer: Callable[[object | type], SnapshotAtom], |
| 291 | + ) -> None: |
| 292 | + """Set a custom serializer for the store snapshot.""" |
| 293 | + self.custom_serializer = serializer |
| 294 | + |
| 295 | + @property |
| 296 | + def snapshot(self: Store[State, Action, Event]) -> SnapshotAtom: |
| 297 | + return self._serialize_value(self._state) |
| 298 | + |
| 299 | + def _serialize_value(self: Store, obj: object | type) -> SnapshotAtom: |
| 300 | + if self.custom_serializer: |
| 301 | + return self.custom_serializer(obj) |
| 302 | + if is_immutable(obj): |
| 303 | + return self._serialize_dataclass_to_dict(obj) |
| 304 | + if isinstance(obj, (list, tuple)): |
| 305 | + return [self._serialize_value(i) for i in obj] |
| 306 | + if callable(obj): |
| 307 | + return self._serialize_value(obj()) |
| 308 | + if isinstance(obj, StrEnum): |
| 309 | + return str(obj) |
| 310 | + if isinstance(obj, IntEnum): |
| 311 | + return int(obj) |
| 312 | + if isinstance(obj, (int, float, str, bool, NoneType)): |
| 313 | + return obj |
| 314 | + msg = f'Unable to serialize object with type {type(obj)}.' |
| 315 | + raise ValueError(msg) |
| 316 | + |
| 317 | + def _serialize_dataclass_to_dict( |
| 318 | + self: Store, |
| 319 | + obj: Immutable, |
| 320 | + ) -> dict[str, Any]: |
| 321 | + result = {} |
| 322 | + for field in dataclasses.fields(obj): |
| 323 | + value = self._serialize_value(getattr(obj, field.name)) |
| 324 | + result[field.name] = value |
| 325 | + return result |
0 commit comments