33from contextlib import contextmanager
44from functools import partial , wraps
55from types import FrameType
6- from typing import TYPE_CHECKING , Any , Callable , Iterator , TypeVar , Union
6+ from typing import TYPE_CHECKING , Any , Callable , Dict , Iterator , TypeVar , Union
77
88import pytest
99from typing_extensions import Final , final
1010
1111if TYPE_CHECKING :
1212 from returns .interfaces .specific .result import ResultLikeN
1313
14- _ERROR_FIELD : Final = '_error_handled'
1514_ERROR_HANDLERS : Final = (
1615 'lash' ,
1716)
2019 'alt' ,
2120)
2221
22+ # We keep track of errors handled by keeping a mapping of <object id>: object.
23+ # If an error is handled, it is in the mapping.
24+ # If it isn't in the mapping, the error is not handled.
25+ #
26+ # Note only storing object IDs would not work, as objects may be GC'ed
27+ # and their object id assigned to another object.
28+ # Also, the object itself cannot be (in) the key because
29+ # (1) we cannot always assume hashability and
30+ # (2) we need to track the object identity, not its value
31+ _ERRORS_HANDLED : Final [Dict [int , Any ]] = {} # noqa: WPS407
32+
2333_FunctionType = TypeVar ('_FunctionType' , bound = Callable )
2434_ReturnsResultType = TypeVar (
2535 '_ReturnsResultType' ,
@@ -33,8 +43,8 @@ class ReturnsAsserts(object):
3343
3444 __slots__ = ()
3545
36- def assert_equal (
37- self ,
46+ @ staticmethod # noqa: WPS602
47+ def assert_equal ( # noqa: WPS602
3848 first ,
3949 second ,
4050 * ,
@@ -45,13 +55,14 @@ def assert_equal(
4555 from returns .primitives .asserts import assert_equal
4656 assert_equal (first , second , deps = deps , backend = backend )
4757
48- def is_error_handled (self , container ) -> bool :
58+ @staticmethod # noqa: WPS602
59+ def is_error_handled (container ) -> bool : # noqa: WPS602
4960 """Ensures that container has its error handled in the end."""
50- return bool ( getattr ( container , _ERROR_FIELD , False ))
61+ return id ( container ) in _ERRORS_HANDLED
5162
63+ @staticmethod # noqa: WPS602
5264 @contextmanager
53- def assert_trace (
54- self ,
65+ def assert_trace ( # noqa: WPS602
5566 trace_type : _ReturnsResultType ,
5667 function_to_search : _FunctionType ,
5768 ) -> Iterator [None ]:
@@ -76,11 +87,18 @@ def assert_trace(
7687
7788
7889@pytest .fixture (scope = 'session' )
79- def returns (_patch_containers ) -> ReturnsAsserts : # noqa: WPS442
90+ def returns (_patch_containers ) -> ReturnsAsserts :
8091 """Returns our own class with helpers assertions to check containers."""
8192 return ReturnsAsserts ()
8293
8394
95+ @pytest .fixture (autouse = True )
96+ def _clear_errors_handled ():
97+ """Ensures the 'errors handled' registry doesn't leak memory."""
98+ yield
99+ _ERRORS_HANDLED .clear ()
100+
101+
84102def pytest_configure (config ) -> None :
85103 """
86104 Hook to be executed on import.
@@ -182,16 +200,12 @@ def error_handler(cls, original):
182200 if inspect .iscoroutinefunction (original ):
183201 async def factory (self , * args , ** kwargs ):
184202 original_result = await original (self , * args , ** kwargs )
185- object .__setattr__ (
186- original_result , _ERROR_FIELD , True , # noqa: WPS425
187- )
203+ _ERRORS_HANDLED [id (original_result )] = original_result
188204 return original_result
189205 else :
190206 def factory (self , * args , ** kwargs ):
191207 original_result = original (self , * args , ** kwargs )
192- object .__setattr__ (
193- original_result , _ERROR_FIELD , True , # noqa: WPS425
194- )
208+ _ERRORS_HANDLED [id (original_result )] = original_result
195209 return original_result
196210 return wraps (original )(factory )
197211
@@ -200,20 +214,14 @@ def copy_handler(cls, original):
200214 if inspect .iscoroutinefunction (original ):
201215 async def factory (self , * args , ** kwargs ):
202216 original_result = await original (self , * args , ** kwargs )
203- object .__setattr__ (
204- original_result ,
205- _ERROR_FIELD ,
206- getattr (self , _ERROR_FIELD , False ),
207- )
217+ if id (self ) in _ERRORS_HANDLED :
218+ _ERRORS_HANDLED [id (original_result )] = original_result
208219 return original_result
209220 else :
210221 def factory (self , * args , ** kwargs ):
211222 original_result = original (self , * args , ** kwargs )
212- object .__setattr__ (
213- original_result ,
214- _ERROR_FIELD ,
215- getattr (self , _ERROR_FIELD , False ),
216- )
223+ if id (self ) in _ERRORS_HANDLED :
224+ _ERRORS_HANDLED [id (original_result )] = original_result
217225 return original_result
218226 return wraps (original )(factory )
219227
0 commit comments