66import json
77import os
88from collections import defaultdict
9- from typing import TYPE_CHECKING , Any , cast
9+ from pathlib import Path
10+ from typing import TYPE_CHECKING , Any , Callable , Generic , cast
1011
1112import pytest
1213
13- from redux .basic_types import FinishEvent
14+ from redux .basic_types import FinishEvent , State
1415
1516if TYPE_CHECKING :
16- from pathlib import Path
17-
1817 from _pytest .fixtures import SubRequest
1918
2019 from redux .main import Store
2120
2221
23- class StoreSnapshot :
22+ class StoreSnapshot ( Generic [ State ]) :
2423 """Context object for tests taking snapshots of the store."""
2524
2625 def __init__ (
@@ -32,11 +31,14 @@ def __init__(
3231 store : Store ,
3332 ) -> None :
3433 """Create a new store snapshot context."""
34+ self ._is_failed = False
3535 self ._is_closed = False
3636 self .override = override
3737 self .test_counter : dict [str | None , int ] = defaultdict (int )
3838 file = path .with_suffix ('' ).name
39- self .results_dir = path .parent / 'results' / file / test_id .split ('::' )[- 1 ][5 :]
39+ self .results_dir = Path (
40+ path .parent / 'results' / file / test_id .split ('::' )[- 1 ][5 :],
41+ )
4042 if self .results_dir .exists ():
4143 for file in self .results_dir .glob (
4244 'store-*.jsonc' if override else 'store-*.mismatch.jsonc' ,
@@ -47,12 +49,15 @@ def __init__(
4749 self .store = store
4850 store .subscribe_event (FinishEvent , self .close )
4951
50- @property
51- def json_snapshot (self : StoreSnapshot ) -> str :
52+ def json_snapshot (
53+ self : StoreSnapshot [State ],
54+ * ,
55+ selector : Callable [[State ], Any ] = lambda state : state ,
56+ ) -> str :
5257 """Return the snapshot of the current state of the store."""
5358 return (
5459 json .dumps (
55- self .store .snapshot ,
60+ self .store .serialize_value ( selector ( self . store . _state )), # noqa: SLF001
5661 indent = 2 ,
5762 sort_keys = True ,
5863 ensure_ascii = False ,
@@ -61,13 +66,18 @@ def json_snapshot(self: StoreSnapshot) -> str:
6166 else ''
6267 )
6368
64- def get_filename (self : StoreSnapshot , title : str | None ) -> str :
69+ def get_filename (self : StoreSnapshot [ State ] , title : str | None ) -> str :
6570 """Get the filename for the snapshot."""
6671 if title :
6772 return f"""store-{ title } -{ self .test_counter [title ]:03d} """
6873 return f"""store-{ self .test_counter [title ]:03d} """
6974
70- def take (self : StoreSnapshot , * , title : str | None = None ) -> None :
75+ def take (
76+ self : StoreSnapshot [State ],
77+ * ,
78+ title : str | None = None ,
79+ selector : Callable [[State ], Any ] = lambda state : state ,
80+ ) -> None :
7181 """Take a snapshot of the current window."""
7282 if self ._is_closed :
7383 msg = (
@@ -81,29 +91,39 @@ def take(self: StoreSnapshot, *, title: str | None = None) -> None:
8191 json_path = path .with_suffix ('.jsonc' )
8292 mismatch_path = path .with_suffix ('.mismatch.jsonc' )
8393
84- new_snapshot = self .json_snapshot
94+ new_snapshot = self .json_snapshot ( selector = selector )
8595 if self .override :
8696 json_path .write_text (f'// { filename } \n { new_snapshot } \n ' ) # pragma: no cover
8797 else :
8898 old_snapshot = None
8999 if json_path .exists ():
90100 old_snapshot = json_path .read_text ().split ('\n ' , 1 )[1 ][:- 1 ]
91101 if old_snapshot != new_snapshot :
102+ self ._is_failed = True
92103 mismatch_path .write_text ( # pragma: no cover
93104 f'// MISMATCH: { filename } \n { new_snapshot } \n ' ,
94105 )
95106 assert new_snapshot == old_snapshot , f'Store snapshot mismatch - { filename } '
96107
97108 self .test_counter [title ] += 1
98109
99- def close (self : StoreSnapshot ) -> None :
110+ def monitor (self : StoreSnapshot [State ], selector : Callable [[State ], Any ]) -> None :
111+ """Monitor the state of the store and take snapshots."""
112+
113+ @self .store .autorun (selector = selector )
114+ def _ (state : State ) -> None :
115+ self .take (selector = lambda _ : state )
116+
117+ def close (self : StoreSnapshot [State ]) -> None :
100118 """Close the snapshot context."""
119+ self ._is_closed = True
120+ if self ._is_failed :
121+ return
101122 for title in self .test_counter :
102123 filename = self .get_filename (title )
103124 json_path = (self .results_dir / filename ).with_suffix ('.jsonc' )
104125
105126 assert not json_path .exists (), f'Snapshot { filename } not taken'
106- self ._is_closed = True
107127
108128
109129@pytest .fixture ()
0 commit comments