Skip to content

Commit 69bfe77

Browse files
Vizonexseifertm
authored andcommitted
inject at the runner instead however there was a side-effect so I made a comment explaining it.
1 parent aa95acd commit 69bfe77

File tree

2 files changed

+60
-36
lines changed

2 files changed

+60
-36
lines changed

pytest_asyncio/plugin.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def fixture(
135135
*,
136136
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
137137
loop_scope: _ScopeName | None = ...,
138-
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
138+
loop_factory: Callable[[], AbstractEventLoop] | None = ...,
139139
params: Iterable[object] | None = ...,
140140
autouse: bool = ...,
141141
ids: (
@@ -153,7 +153,7 @@ def fixture(
153153
*,
154154
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
155155
loop_scope: _ScopeName | None = ...,
156-
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
156+
loop_factory: Callable[[], AbstractEventLoop] | None = ...,
157157
params: Iterable[object] | None = ...,
158158
autouse: bool = ...,
159159
ids: (
@@ -168,7 +168,7 @@ def fixture(
168168
def fixture(
169169
fixture_function: FixtureFunction[_P, _R] | None = None,
170170
loop_scope: _ScopeName | None = None,
171-
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
171+
loop_factory: Callable[[], AbstractEventLoop] | None = None,
172172
**kwargs: Any,
173173
) -> (
174174
FixtureFunction[_P, _R]
@@ -192,7 +192,11 @@ def _is_asyncio_fixture_function(obj: Any) -> bool:
192192
return getattr(obj, "_force_asyncio_fixture", False)
193193

194194

195-
def _make_asyncio_fixture_function(obj: Any, loop_scope: _ScopeName | None, loop_factory: _ScopeName | None) -> None:
195+
def _make_asyncio_fixture_function(
196+
obj: Any,
197+
loop_scope: _ScopeName | None,
198+
loop_factory: Callable[[], AbstractEventLoop] | None,
199+
) -> None:
196200
if hasattr(obj, "__func__"):
197201
# instance method, check the function object
198202
obj = obj.__func__
@@ -285,14 +289,16 @@ def pytest_report_header(config: Config) -> list[str]:
285289

286290

287291
def _fixture_synchronizer(
288-
fixturedef: FixtureDef, runner: Runner, request: FixtureRequest, loop_factory: Callable[[], AbstractEventLoop]
292+
fixturedef: FixtureDef,
293+
runner: Runner,
294+
request: FixtureRequest,
289295
) -> Callable:
290296
"""Returns a synchronous function evaluating the specified fixture."""
291297
fixture_function = resolve_fixture_function(fixturedef, request)
292298
if inspect.isasyncgenfunction(fixturedef.func):
293-
return _wrap_asyncgen_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
299+
return _wrap_asyncgen_fixture(fixture_function, runner, request) # type: ignore[arg-type]
294300
elif inspect.iscoroutinefunction(fixturedef.func):
295-
return _wrap_async_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
301+
return _wrap_async_fixture(fixture_function, runner, request) # type: ignore[arg-type]
296302
else:
297303
return fixturedef.func
298304

@@ -307,7 +313,6 @@ def _wrap_asyncgen_fixture(
307313
],
308314
runner: Runner,
309315
request: FixtureRequest,
310-
loop_factory:Callable[[], AbstractEventLoop]
311316
) -> Callable[AsyncGenFixtureParams, AsyncGenFixtureYieldType]:
312317
@functools.wraps(fixture_function)
313318
def _asyncgen_fixture_wrapper(
@@ -337,9 +342,6 @@ async def async_finalizer() -> None:
337342
msg = "Async generator fixture didn't stop."
338343
msg += "Yield only once."
339344
raise ValueError(msg)
340-
if loop_factory:
341-
_loop = loop_factory()
342-
asyncio.set_event_loop(_loop)
343345

344346
runner.run(async_finalizer(), context=context)
345347
if reset_contextvars is not None:
@@ -361,9 +363,8 @@ def _wrap_async_fixture(
361363
],
362364
runner: Runner,
363365
request: FixtureRequest,
364-
loop_factory: Callable[[], AbstractEventLoop] | None = None
365366
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:
366-
@functools.wraps(fixture_function)
367+
@functools.wraps(fixture_function) # type: ignore[arg-type]
367368
def _async_fixture_wrapper(
368369
*args: AsyncFixtureParams.args,
369370
**kwargs: AsyncFixtureParams.kwargs,
@@ -374,10 +375,6 @@ async def setup():
374375

375376
context = contextvars.copy_context()
376377

377-
# ensure loop_factory gets ran before we start running...
378-
if loop_factory:
379-
asyncio.set_event_loop(loop_factory())
380-
381378
result = runner.run(setup(), context=context)
382379
# Copy the context vars modified by the setup task into the current
383380
# context, and (if needed) add a finalizer to reset them.
@@ -522,16 +519,6 @@ def _can_substitute(item: Function) -> bool:
522519
func = item.obj
523520
return inspect.iscoroutinefunction(func)
524521

525-
<<<<<<< HEAD
526-
=======
527-
def runtest(self) -> None:
528-
# print(self.obj.pytestmark[0].__dict__)
529-
synchronized_obj = wrap_in_sync(self.obj, self.obj.pytestmark[0].kwargs.get('loop_factory', None))
530-
with MonkeyPatch.context() as c:
531-
c.setattr(self, "obj", synchronized_obj)
532-
super().runtest()
533-
534-
>>>>>>> edfbfef (figured out loop_factory :))
535522

536523
class AsyncGenerator(PytestAsyncioFunction):
537524
"""Pytest item created by an asynchronous generator"""
@@ -641,16 +628,32 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
641628

642629

643630
@contextlib.contextmanager
644-
def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[None]:
631+
def _temporary_event_loop_policy(
632+
policy: AbstractEventLoopPolicy,
633+
loop_facotry: Callable[..., AbstractEventLoop] | None,
634+
) -> Iterator[None]:
635+
645636
old_loop_policy = _get_event_loop_policy()
646637
try:
647638
old_loop = _get_event_loop_no_warn()
648639
except RuntimeError:
649640
old_loop = None
641+
# XXX: For some reason this function can override runner's
642+
# _loop_factory (At least observed on backported versions of Runner)
643+
# so we need to re-override if existing...
644+
if loop_facotry:
645+
_loop = loop_facotry()
646+
_set_event_loop(_loop)
647+
else:
648+
_loop = None
649+
650650
_set_event_loop_policy(policy)
651651
try:
652652
yield
653653
finally:
654+
if _loop:
655+
# Do not let BaseEventLoop.__del__ complain!
656+
_loop.close()
654657
_set_event_loop_policy(old_loop_policy)
655658
_set_event_loop(old_loop)
656659

@@ -742,8 +745,6 @@ def _synchronize_coroutine(
742745
def inner(*args, **kwargs):
743746
coro = func(*args, **kwargs)
744747
runner.run(coro, context=context)
745-
746-
asyncio.set_event_loop(_last_loop)
747748
return inner
748749

749750

@@ -767,7 +768,7 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
767768

768769
runner_fixture_id = f"_{loop_scope}_scoped_runner"
769770
runner = request.getfixturevalue(runner_fixture_id)
770-
synchronizer = _fixture_synchronizer(fixturedef, runner, request, loop_factory)
771+
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
771772
_make_asyncio_fixture_function(synchronizer, loop_scope, loop_factory)
772773
with MonkeyPatch.context() as c:
773774
c.setattr(fixturedef, "func", synchronizer)
@@ -822,19 +823,32 @@ def _get_default_test_loop_scope(config: Config) -> Any:
822823
"""
823824

824825

826+
def _get_loop_facotry(
827+
request: FixtureRequest,
828+
) -> Callable[[], AbstractEventLoop] | None:
829+
if asyncio_mark := request._pyfuncitem.get_closest_marker("asyncio"):
830+
factory = asyncio_mark.kwargs.get("loop_factory", None)
831+
print(f"FACTORY {factory}")
832+
return factory
833+
else:
834+
return request.obj.__dict__.get("_loop_factory", None) # type: ignore[attr-defined]
835+
836+
825837
def _create_scoped_runner_fixture(scope: _ScopeName) -> Callable:
826838
@pytest.fixture(
827839
scope=scope,
828840
name=f"_{scope}_scoped_runner",
829841
)
830842
def _scoped_runner(
831-
event_loop_policy,
832-
request: FixtureRequest,
843+
event_loop_policy: AbstractEventLoopPolicy, request: FixtureRequest
833844
) -> Iterator[Runner]:
834845
new_loop_policy = event_loop_policy
835-
debug_mode = _get_asyncio_debug(request.config)
836-
with _temporary_event_loop_policy(new_loop_policy):
837-
runner = Runner(debug=debug_mode).__enter__()
846+
847+
# We need to get the factory now because
848+
# _temporary_event_loop_policy can override the Runner
849+
factory = _get_loop_facotry(request)
850+
with _temporary_event_loop_policy(new_loop_policy, factory):
851+
runner = Runner(loop_factory=factory).__enter__()
838852
try:
839853
yield runner
840854
except Exception as e:

tests/test_asyncio_mark.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,16 @@ class CustomEventLoop(asyncio.SelectorEventLoop):
249249
@pytest.mark.asyncio(loop_factory=CustomEventLoop)
250250
async def test_has_different_event_loop():
251251
assert type(asyncio.get_running_loop()).__name__ == "CustomEventLoop"
252+
253+
@pytest_asyncio.fixture(loop_factory=CustomEventLoop)
254+
async def custom_fixture():
255+
yield asyncio.get_running_loop()
256+
257+
async def test_with_fixture(custom_fixture):
258+
# Both of these should be the same...
259+
type(asyncio.get_running_loop()).__name__ == "CustomEventLoop"
260+
type(custom_fixture).__name__ == "CustomEventLoop"
261+
252262
"""
253263
)
254264
)

0 commit comments

Comments
 (0)