Skip to content

Commit 5f7edb6

Browse files
Vizonexseifertm
authored andcommitted
Incomplete need to figure out how to get loop_factory / multiple into asyncio.Runner
1 parent 95d5930 commit 5f7edb6

File tree

3 files changed

+82
-32
lines changed

3 files changed

+82
-32
lines changed

pytest_asyncio/plugin.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,8 @@ def restore_contextvars():
417417
class PytestAsyncioFunction(Function):
418418
"""Base class for all test functions managed by pytest-asyncio."""
419419

420+
loop_factory: Callable[[], AbstractEventLoop] | None
421+
420422
@classmethod
421423
def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | None:
422424
"""
@@ -431,7 +433,12 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
431433
return None
432434

433435
@classmethod
434-
def _from_function(cls, function: Function, /) -> Function:
436+
def _from_function(
437+
cls,
438+
function: Function,
439+
loop_factory: Callable[[], AbstractEventLoop] | None = None,
440+
/,
441+
) -> Function:
435442
"""
436443
Instantiates this specific PytestAsyncioFunction type from the specified
437444
Function item.
@@ -447,6 +454,7 @@ def _from_function(cls, function: Function, /) -> Function:
447454
keywords=function.keywords,
448455
originalname=function.originalname,
449456
)
457+
subclass_instance.loop_factory = loop_factory
450458
subclass_instance.own_markers = function.own_markers
451459
assert subclass_instance.own_markers == function.own_markers
452460
return subclass_instance
@@ -610,9 +618,27 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
610618
node.config
611619
) == Mode.AUTO and not node.get_closest_marker("asyncio"):
612620
node.add_marker("asyncio")
613-
if node.get_closest_marker("asyncio"):
614-
updated_item = specialized_item_class._from_function(node)
615-
updated_node_collection.append(updated_item)
621+
if asyncio_marker := node.get_closest_marker("asyncio"):
622+
if loop_factory := asyncio_marker.kwargs.get("loop_factory", None):
623+
# multiply if loop_factory is an iterable object of factories
624+
if hasattr(loop_factory, "__iter__"):
625+
updated_item = [
626+
specialized_item_class._from_function(node, lf)
627+
for lf in loop_factory
628+
]
629+
else:
630+
updated_item = specialized_item_class._from_function(
631+
node, loop_factory
632+
)
633+
else:
634+
updated_item = specialized_item_class._from_function(node)
635+
636+
# we could have multiple factroies to test if so,
637+
# multiply the number of functions for us...
638+
if isinstance(updated_item, list):
639+
updated_node_collection.extend(updated_item)
640+
else:
641+
updated_node_collection.append(updated_item)
616642
hook_result.force_result(updated_node_collection)
617643

618644

@@ -631,20 +657,6 @@ def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[No
631657
_set_event_loop(old_loop)
632658

633659

634-
@contextlib.contextmanager
635-
def _temporary_event_loop(loop: AbstractEventLoop):
636-
try:
637-
old_event_loop = asyncio.get_event_loop()
638-
except RuntimeError:
639-
old_event_loop = None
640-
641-
asyncio.set_event_loop(old_event_loop)
642-
try:
643-
yield
644-
finally:
645-
asyncio.set_event_loop(old_event_loop)
646-
647-
648660
def _get_event_loop_policy() -> AbstractEventLoopPolicy:
649661
with warnings.catch_warnings():
650662
warnings.simplefilter("ignore", DeprecationWarning)
@@ -753,8 +765,15 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
753765
or default_loop_scope
754766
or fixturedef.scope
755767
)
768+
# XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
769+
loop_factory = getattr(fixturedef.func, "loop_factory", None)
770+
771+
print(f"LOOP FACTORY: {loop_factory} {fixturedef.func}")
772+
sys.stdout.flush()
773+
756774
runner_fixture_id = f"_{loop_scope}_scoped_runner"
757-
runner = request.getfixturevalue(runner_fixture_id)
775+
runner: Runner = request.getfixturevalue(runner_fixture_id)
776+
758777
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
759778
_make_asyncio_fixture_function(synchronizer, loop_scope)
760779
with MonkeyPatch.context() as c:
@@ -779,9 +798,12 @@ def _get_marked_loop_scope(
779798
) -> _ScopeName:
780799
assert asyncio_marker.name == "asyncio"
781800
if asyncio_marker.args or (
782-
asyncio_marker.kwargs and set(asyncio_marker.kwargs) - {"loop_scope", "scope"}
801+
asyncio_marker.kwargs
802+
and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
783803
):
784-
raise ValueError("mark.asyncio accepts only a keyword argument 'loop_scope'.")
804+
raise ValueError(
805+
"mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
806+
)
785807
if "scope" in asyncio_marker.kwargs:
786808
if "loop_scope" in asyncio_marker.kwargs:
787809
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)
@@ -849,12 +871,6 @@ def _scoped_runner(
849871
)
850872

851873

852-
@pytest.fixture(scope="session", autouse=True)
853-
def new_event_loop() -> AbstractEventLoop:
854-
"""Creates a new eventloop for different tests being ran"""
855-
return asyncio.new_event_loop()
856-
857-
858874
@pytest.fixture(scope="session", autouse=True)
859875
def event_loop_policy() -> AbstractEventLoopPolicy:
860876
"""Return an instance of the policy used to create asyncio event loops."""

tests/markers/test_invalid_arguments.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ async def test_anything():
4040
)
4141
result = pytester.runpytest_subprocess()
4242
result.assert_outcomes(errors=1)
43-
result.stdout.fnmatch_lines(
44-
["*ValueError: mark.asyncio accepts only a keyword argument*"]
45-
)
43+
result.stdout.fnmatch_lines([""])
4644

4745

4846
def test_error_when_wrong_keyword_argument_is_passed(
@@ -62,7 +60,9 @@ async def test_anything():
6260
result = pytester.runpytest_subprocess()
6361
result.assert_outcomes(errors=1)
6462
result.stdout.fnmatch_lines(
65-
["*ValueError: mark.asyncio accepts only a keyword argument 'loop_scope'*"]
63+
[
64+
"*ValueError: mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'*"
65+
]
6666
)
6767

6868

@@ -83,5 +83,7 @@ async def test_anything():
8383
result = pytester.runpytest_subprocess()
8484
result.assert_outcomes(errors=1)
8585
result.stdout.fnmatch_lines(
86-
["*ValueError: mark.asyncio accepts only a keyword argument*"]
86+
[
87+
"*ValueError: mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'*"
88+
]
8789
)

tests/test_asyncio_mark.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,35 @@ async def test_a(session_loop_fixture):
223223

224224
result = pytester.runpytest("--asyncio-mode=auto")
225225
result.assert_outcomes(passed=1)
226+
227+
228+
def test_asyncio_marker_event_loop_factories(pytester: Pytester):
229+
pytester.makeini(
230+
dedent(
231+
"""\
232+
[pytest]
233+
asyncio_default_fixture_loop_scope = function
234+
asyncio_default_test_loop_scope = module
235+
"""
236+
)
237+
)
238+
239+
pytester.makepyfile(
240+
dedent(
241+
"""\
242+
import asyncio
243+
import pytest_asyncio
244+
import pytest
245+
246+
class CustomEventLoop(asyncio.SelectorEventLoop):
247+
pass
248+
249+
@pytest.mark.asyncio(loop_factory=CustomEventLoop)
250+
async def test_has_different_event_loop():
251+
assert type(asyncio.get_running_loop()).__name__ == "CustomEventLoop"
252+
"""
253+
)
254+
)
255+
256+
result = pytester.runpytest("--asyncio-mode=auto")
257+
result.assert_outcomes(passed=1)

0 commit comments

Comments
 (0)