11import inspect
22import sys
3- from contextlib import contextmanager
3+ from contextlib import ExitStack , contextmanager
44from functools import partial , wraps
5- from types import FrameType
5+ from types import FrameType , MappingProxyType
66from typing import TYPE_CHECKING , Any , Callable , Dict , Iterator , TypeVar , Union
7+ from unittest import mock
78
89import pytest
910from typing_extensions import Final , final
1011
1112if TYPE_CHECKING :
1213 from returns .interfaces .specific .result import ResultLikeN
1314
14- _ERROR_HANDLERS : Final = (
15- 'lash' ,
16- )
17- _ERRORS_COPIERS : Final = (
18- 'map' ,
19- 'alt' ,
20- )
21-
2215# We keep track of errors handled by keeping a mapping of <object id>: object.
2316# If an error is handled, it is in the mapping.
2417# If it isn't in the mapping, the error is not handled.
2821# Also, the object itself cannot be (in) the key because
2922# (1) we cannot always assume hashability and
3023# (2) we need to track the object identity, not its value
31- _ERRORS_HANDLED : Final [ Dict [int , Any ]] = {} # noqa: WPS407
24+ _ErrorsHandled = Dict [int , Any ]
3225
3326_FunctionType = TypeVar ('_FunctionType' , bound = Callable )
3427_ReturnsResultType = TypeVar (
4134class ReturnsAsserts (object ):
4235 """Class with helpers assertions to check containers."""
4336
44- __slots__ = ()
37+ __slots__ = ('_errors_handled' , )
38+
39+ def __init__ (self , errors_handled : _ErrorsHandled ) -> None :
40+ """Constructor for this type."""
41+ self ._errors_handled = errors_handled
4542
4643 @staticmethod # noqa: WPS602
4744 def assert_equal ( # noqa: WPS602
@@ -55,10 +52,9 @@ def assert_equal( # noqa: WPS602
5552 from returns .primitives .asserts import assert_equal
5653 assert_equal (first , second , deps = deps , backend = backend )
5754
58- @staticmethod # noqa: WPS602
59- def is_error_handled (container ) -> bool : # noqa: WPS602
55+ def is_error_handled (self , container ) -> bool :
6056 """Ensures that container has its error handled in the end."""
61- return id (container ) in _ERRORS_HANDLED
57+ return id (container ) in self . _errors_handled
6258
6359 @staticmethod # noqa: WPS602
6460 @contextmanager
@@ -86,59 +82,6 @@ def assert_trace( # noqa: WPS602
8682 sys .settrace (old_tracer )
8783
8884
89- @pytest .fixture (scope = 'session' )
90- def returns (_patch_containers ) -> ReturnsAsserts :
91- """Returns our own class with helpers assertions to check containers."""
92- return ReturnsAsserts ()
93-
94-
95- @pytest .fixture (autouse = True )
96- def _clear_errors_handled ():
97- """Ensures the 'errors handled' registry doesn't leak memory."""
98- yield
99- _ERRORS_HANDLED .clear ()
100-
101-
102- def pytest_configure (config ) -> None :
103- """
104- Hook to be executed on import.
105-
106- We use it define custom markers.
107- """
108- config .addinivalue_line (
109- 'markers' ,
110- (
111- 'returns_lawful: all tests under `check_all_laws` ' +
112- 'is marked this way, ' +
113- 'use `-m "not returns_lawful"` to skip them.'
114- ),
115- )
116-
117-
118- @pytest .fixture (scope = 'session' )
119- def _patch_containers () -> None :
120- """
121- Fixture to add test specifics into our containers.
122-
123- Currently we inject:
124-
125- - Error handling state, this is required to test that ``Result``-based
126- containers do handle errors
127-
128- Even more things to come!
129- """
130- _patch_error_handling (_ERROR_HANDLERS , _PatchedContainer .error_handler )
131- _patch_error_handling (_ERRORS_COPIERS , _PatchedContainer .copy_handler )
132-
133-
134- def _patch_error_handling (methods , patch_handler ) -> None :
135- for container in _PatchedContainer .containers_to_patch ():
136- for method in methods :
137- original = getattr (container , method , None )
138- if original :
139- setattr (container , method , patch_handler (original ))
140-
141-
14285def _trace_function (
14386 trace_type : _ReturnsResultType ,
14487 function_to_search : _FunctionType ,
@@ -166,65 +109,107 @@ def _trace_function(
166109 raise _DesiredFunctionFound ()
167110
168111
169- @final
170- class _PatchedContainer (object ):
171- """Class with helper methods to patched containers."""
172-
173- __slots__ = ()
174-
175- @classmethod
176- def containers_to_patch (cls ) -> tuple :
177- """We need this method so coverage will work correctly."""
178- from returns .context import (
179- RequiresContextFutureResult ,
180- RequiresContextIOResult ,
181- RequiresContextResult ,
182- )
183- from returns .future import FutureResult
184- from returns .io import IOFailure , IOSuccess
185- from returns .result import Failure , Success
186-
187- return (
188- Success ,
189- Failure ,
190- IOSuccess ,
191- IOFailure ,
192- RequiresContextResult ,
193- RequiresContextIOResult ,
194- RequiresContextFutureResult ,
195- FutureResult ,
196- )
112+ class _DesiredFunctionFound (BaseException ): # noqa: WPS418
113+ """Exception to raise when expected function is found."""
197114
198- @classmethod
199- def error_handler (cls , original ):
200- if inspect .iscoroutinefunction (original ):
201- async def factory (self , * args , ** kwargs ):
202- original_result = await original (self , * args , ** kwargs )
203- _ERRORS_HANDLED [id (original_result )] = original_result
204- return original_result
205- else :
206- def factory (self , * args , ** kwargs ):
207- original_result = original (self , * args , ** kwargs )
208- _ERRORS_HANDLED [id (original_result )] = original_result
209- return original_result
210- return wraps (original )(factory )
211-
212- @classmethod
213- def copy_handler (cls , original ):
214- if inspect .iscoroutinefunction (original ):
215- async def factory (self , * args , ** kwargs ):
216- original_result = await original (self , * args , ** kwargs )
217- if id (self ) in _ERRORS_HANDLED :
218- _ERRORS_HANDLED [id (original_result )] = original_result
219- return original_result
220- else :
221- def factory (self , * args , ** kwargs ):
222- original_result = original (self , * args , ** kwargs )
223- if id (self ) in _ERRORS_HANDLED :
224- _ERRORS_HANDLED [id (original_result )] = original_result
225- return original_result
226- return wraps (original )(factory )
227115
116+ def pytest_configure (config ) -> None :
117+ """
118+ Hook to be executed on import.
228119
229- class _DesiredFunctionFound (BaseException ): # noqa: WPS418
230- """Exception to raise when expected function is found."""
120+ We use it define custom markers.
121+ """
122+ config .addinivalue_line (
123+ 'markers' ,
124+ (
125+ 'returns_lawful: all tests under `check_all_laws` ' +
126+ 'is marked this way, ' +
127+ 'use `-m "not returns_lawful"` to skip them.'
128+ ),
129+ )
130+
131+
132+ @pytest .fixture ()
133+ def returns () -> Iterator [ReturnsAsserts ]:
134+ """Returns class with helpers assertions to check containers."""
135+ with _spy_error_handling () as errors_handled :
136+ yield ReturnsAsserts (errors_handled )
137+
138+
139+ @contextmanager
140+ def _spy_error_handling () -> Iterator [_ErrorsHandled ]:
141+ """Track error handling of containers."""
142+ errs : _ErrorsHandled = {}
143+ with ExitStack () as cleanup :
144+ for container in _containers_to_patch ():
145+ for method , patch in _ERROR_HANDLING_PATCHERS .items ():
146+ cleanup .enter_context (mock .patch .object (
147+ container ,
148+ method ,
149+ patch (getattr (container , method ), errs = errs ),
150+ ))
151+ yield errs
152+
153+
154+ # delayed imports are needed to prevent messing up coverage
155+ def _containers_to_patch () -> tuple :
156+ from returns .context import (
157+ RequiresContextFutureResult ,
158+ RequiresContextIOResult ,
159+ RequiresContextResult ,
160+ )
161+ from returns .future import FutureResult
162+ from returns .io import IOFailure , IOSuccess
163+ from returns .result import Failure , Success
164+
165+ return (
166+ Success ,
167+ Failure ,
168+ IOSuccess ,
169+ IOFailure ,
170+ RequiresContextResult ,
171+ RequiresContextIOResult ,
172+ RequiresContextFutureResult ,
173+ FutureResult ,
174+ )
175+
176+
177+ def _patched_error_handler (
178+ original : _FunctionType , errs : _ErrorsHandled ,
179+ ) -> _FunctionType :
180+ if inspect .iscoroutinefunction (original ):
181+ async def wrapper (self , * args , ** kwargs ):
182+ original_result = await original (self , * args , ** kwargs )
183+ errs [id (original_result )] = original_result
184+ return original_result
185+ else :
186+ def wrapper (self , * args , ** kwargs ):
187+ original_result = original (self , * args , ** kwargs )
188+ errs [id (original_result )] = original_result
189+ return original_result
190+ return wraps (original )(wrapper ) # type: ignore
191+
192+
193+ def _patched_error_copier (
194+ original : _FunctionType , errs : _ErrorsHandled ,
195+ ) -> _FunctionType :
196+ if inspect .iscoroutinefunction (original ):
197+ async def wrapper (self , * args , ** kwargs ):
198+ original_result = await original (self , * args , ** kwargs )
199+ if id (self ) in errs :
200+ errs [id (original_result )] = original_result
201+ return original_result
202+ else :
203+ def wrapper (self , * args , ** kwargs ):
204+ original_result = original (self , * args , ** kwargs )
205+ if id (self ) in errs :
206+ errs [id (original_result )] = original_result
207+ return original_result
208+ return wraps (original )(wrapper ) # type: ignore
209+
210+
211+ _ERROR_HANDLING_PATCHERS : Final = MappingProxyType ({
212+ 'lash' : _patched_error_handler ,
213+ 'map' : _patched_error_copier ,
214+ 'alt' : _patched_error_copier ,
215+ })
0 commit comments