@@ -372,6 +372,8 @@ def restore_contextvars():
372372class PytestAsyncioFunction (Function ):
373373 """Base class for all test functions managed by pytest-asyncio."""
374374
375+ loop_factory : Callable [[], AbstractEventLoop ] | None
376+
375377 @classmethod
376378 def item_subclass_for (cls , item : Function , / ) -> type [PytestAsyncioFunction ] | None :
377379 """
@@ -386,12 +388,18 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
386388 return None
387389
388390 @classmethod
389- def _from_function (cls , function : Function , / ) -> Function :
391+ def _from_function (
392+ cls ,
393+ function : Function ,
394+ loop_factory : Callable [[], AbstractEventLoop ] | None = None ,
395+ / ,
396+ ) -> Function :
390397 """
391398 Instantiates this specific PytestAsyncioFunction type from the specified
392399 Function item.
393400 """
394401 assert function .get_closest_marker ("asyncio" )
402+
395403 subclass_instance = cls .from_parent (
396404 function .parent ,
397405 name = function .name ,
@@ -401,6 +409,7 @@ def _from_function(cls, function: Function, /) -> Function:
401409 keywords = function .keywords ,
402410 originalname = function .originalname ,
403411 )
412+ subclass_instance .loop_factory = loop_factory
404413 subclass_instance .own_markers = function .own_markers
405414 assert subclass_instance .own_markers == function .own_markers
406415 return subclass_instance
@@ -525,9 +534,27 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
525534 node .config
526535 ) == Mode .AUTO and not node .get_closest_marker ("asyncio" ):
527536 node .add_marker ("asyncio" )
528- if node .get_closest_marker ("asyncio" ):
529- updated_item = specialized_item_class ._from_function (node )
530- updated_node_collection .append (updated_item )
537+ if asyncio_marker := node .get_closest_marker ("asyncio" ):
538+ if loop_factory := asyncio_marker .kwargs .get ("loop_factory" , None ):
539+ # multiply if loop_factory is an iterable object of factories
540+ if hasattr (loop_factory , "__iter__" ):
541+ updated_item = [
542+ specialized_item_class ._from_function (node , lf )
543+ for lf in loop_factory
544+ ]
545+ else :
546+ updated_item = specialized_item_class ._from_function (
547+ node , loop_factory
548+ )
549+ else :
550+ updated_item = specialized_item_class ._from_function (node )
551+
552+ # we could have multiple factroies to test if so,
553+ # multiply the number of functions for us...
554+ if isinstance (updated_item , list ):
555+ updated_node_collection .extend (updated_item )
556+ else :
557+ updated_node_collection .append (updated_item )
531558 hook_result .force_result (updated_node_collection )
532559
533560
@@ -546,20 +573,6 @@ def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[No
546573 _set_event_loop (old_loop )
547574
548575
549- @contextlib .contextmanager
550- def _temporary_event_loop (loop : AbstractEventLoop ):
551- try :
552- old_event_loop = asyncio .get_event_loop ()
553- except RuntimeError :
554- old_event_loop = None
555-
556- asyncio .set_event_loop (old_event_loop )
557- try :
558- yield
559- finally :
560- asyncio .set_event_loop (old_event_loop )
561-
562-
563576def _get_event_loop_policy () -> AbstractEventLoopPolicy :
564577 with warnings .catch_warnings ():
565578 warnings .simplefilter ("ignore" , DeprecationWarning )
@@ -669,12 +682,15 @@ def pytest_runtest_setup(item: pytest.Item) -> None:
669682 marker = item .get_closest_marker ("asyncio" )
670683 if marker is None :
671684 return
685+ getattr (marker , "loop_factory" , None )
672686 default_loop_scope = _get_default_test_loop_scope (item .config )
673687 loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
674688 runner_fixture_id = f"_{ loop_scope } _scoped_runner"
675- fixturenames = item .fixturenames # type: ignore[attr-defined]
689+ fixturenames : list [str ] = item .fixturenames # type: ignore[attr-defined]
690+
676691 if runner_fixture_id not in fixturenames :
677692 fixturenames .append (runner_fixture_id )
693+
678694 obj = getattr (item , "obj" , None )
679695 if not getattr (obj , "hypothesis" , False ) and getattr (
680696 obj , "is_hypothesis_test" , False
@@ -701,8 +717,15 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
701717 or default_loop_scope
702718 or fixturedef .scope
703719 )
720+ # XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
721+ loop_factory = getattr (fixturedef .func , "loop_factory" , None )
722+
723+ print (f"LOOP FACTORY: { loop_factory } { fixturedef .func } " )
724+ sys .stdout .flush ()
725+
704726 runner_fixture_id = f"_{ loop_scope } _scoped_runner"
705- runner = request .getfixturevalue (runner_fixture_id )
727+ runner : Runner = request .getfixturevalue (runner_fixture_id )
728+
706729 synchronizer = _fixture_synchronizer (fixturedef , runner , request )
707730 _make_asyncio_fixture_function (synchronizer , loop_scope )
708731 with MonkeyPatch .context () as c :
@@ -727,9 +750,12 @@ def _get_marked_loop_scope(
727750) -> _ScopeName :
728751 assert asyncio_marker .name == "asyncio"
729752 if asyncio_marker .args or (
730- asyncio_marker .kwargs and set (asyncio_marker .kwargs ) - {"loop_scope" , "scope" }
753+ asyncio_marker .kwargs
754+ and set (asyncio_marker .kwargs ) - {"loop_scope" , "scope" , "loop_factory" }
731755 ):
732- raise ValueError ("mark.asyncio accepts only a keyword argument 'loop_scope'." )
756+ raise ValueError (
757+ "mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
758+ )
733759 if "scope" in asyncio_marker .kwargs :
734760 if "loop_scope" in asyncio_marker .kwargs :
735761 raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
@@ -795,12 +821,6 @@ def _scoped_runner(
795821 )
796822
797823
798- @pytest .fixture (scope = "session" , autouse = True )
799- def new_event_loop () -> AbstractEventLoop :
800- """Creates a new eventloop for different tests being ran"""
801- return asyncio .new_event_loop ()
802-
803-
804824@pytest .fixture (scope = "session" , autouse = True )
805825def event_loop_policy () -> AbstractEventLoopPolicy :
806826 """Return an instance of the policy used to create asyncio event loops."""
0 commit comments