Skip to content

Commit aa95acd

Browse files
Vizonexseifertm
authored andcommitted
figured out loop_factory :)
1 parent 5f7edb6 commit aa95acd

File tree

1 file changed

+41
-51
lines changed

1 file changed

+41
-51
lines changed

pytest_asyncio/plugin.py

Lines changed: 41 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
PytestPluginManager,
5050
)
5151

52+
from typing import Callable
5253
if sys.version_info >= (3, 10):
5354
from typing import ParamSpec
5455
else:
@@ -134,6 +135,7 @@ def fixture(
134135
*,
135136
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
136137
loop_scope: _ScopeName | None = ...,
138+
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
137139
params: Iterable[object] | None = ...,
138140
autouse: bool = ...,
139141
ids: (
@@ -151,6 +153,7 @@ def fixture(
151153
*,
152154
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
153155
loop_scope: _ScopeName | None = ...,
156+
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
154157
params: Iterable[object] | None = ...,
155158
autouse: bool = ...,
156159
ids: (
@@ -165,20 +168,21 @@ def fixture(
165168
def fixture(
166169
fixture_function: FixtureFunction[_P, _R] | None = None,
167170
loop_scope: _ScopeName | None = None,
171+
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
168172
**kwargs: Any,
169173
) -> (
170174
FixtureFunction[_P, _R]
171175
| Callable[[FixtureFunction[_P, _R]], FixtureFunction[_P, _R]]
172176
):
173177
if fixture_function is not None:
174-
_make_asyncio_fixture_function(fixture_function, loop_scope)
178+
_make_asyncio_fixture_function(fixture_function, loop_scope, loop_factory)
175179
return pytest.fixture(fixture_function, **kwargs)
176180

177181
else:
178182

179183
@functools.wraps(fixture)
180184
def inner(fixture_function: FixtureFunction[_P, _R]) -> FixtureFunction[_P, _R]:
181-
return fixture(fixture_function, loop_scope=loop_scope, **kwargs)
185+
return fixture(fixture_function, loop_factory=loop_factory, loop_scope=loop_scope, **kwargs)
182186

183187
return inner
184188

@@ -188,12 +192,13 @@ def _is_asyncio_fixture_function(obj: Any) -> bool:
188192
return getattr(obj, "_force_asyncio_fixture", False)
189193

190194

191-
def _make_asyncio_fixture_function(obj: Any, loop_scope: _ScopeName | None) -> None:
195+
def _make_asyncio_fixture_function(obj: Any, loop_scope: _ScopeName | None, loop_factory: _ScopeName | None) -> None:
192196
if hasattr(obj, "__func__"):
193197
# instance method, check the function object
194198
obj = obj.__func__
195199
obj._force_asyncio_fixture = True
196200
obj._loop_scope = loop_scope
201+
obj._loop_factory = loop_factory
197202

198203

199204
def _is_coroutine_or_asyncgen(obj: Any) -> bool:
@@ -280,14 +285,14 @@ def pytest_report_header(config: Config) -> list[str]:
280285

281286

282287
def _fixture_synchronizer(
283-
fixturedef: FixtureDef, runner: Runner, request: FixtureRequest
288+
fixturedef: FixtureDef, runner: Runner, request: FixtureRequest, loop_factory: Callable[[], AbstractEventLoop]
284289
) -> Callable:
285290
"""Returns a synchronous function evaluating the specified fixture."""
286291
fixture_function = resolve_fixture_function(fixturedef, request)
287292
if inspect.isasyncgenfunction(fixturedef.func):
288-
return _wrap_asyncgen_fixture(fixture_function, runner, request) # type: ignore[arg-type]
293+
return _wrap_asyncgen_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
289294
elif inspect.iscoroutinefunction(fixturedef.func):
290-
return _wrap_async_fixture(fixture_function, runner, request) # type: ignore[arg-type]
295+
return _wrap_async_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
291296
else:
292297
return fixturedef.func
293298

@@ -302,6 +307,7 @@ def _wrap_asyncgen_fixture(
302307
],
303308
runner: Runner,
304309
request: FixtureRequest,
310+
loop_factory:Callable[[], AbstractEventLoop]
305311
) -> Callable[AsyncGenFixtureParams, AsyncGenFixtureYieldType]:
306312
@functools.wraps(fixture_function)
307313
def _asyncgen_fixture_wrapper(
@@ -331,6 +337,9 @@ async def async_finalizer() -> None:
331337
msg = "Async generator fixture didn't stop."
332338
msg += "Yield only once."
333339
raise ValueError(msg)
340+
if loop_factory:
341+
_loop = loop_factory()
342+
asyncio.set_event_loop(_loop)
334343

335344
runner.run(async_finalizer(), context=context)
336345
if reset_contextvars is not None:
@@ -352,6 +361,7 @@ def _wrap_async_fixture(
352361
],
353362
runner: Runner,
354363
request: FixtureRequest,
364+
loop_factory: Callable[[], AbstractEventLoop] | None = None
355365
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:
356366
@functools.wraps(fixture_function)
357367
def _async_fixture_wrapper(
@@ -363,8 +373,12 @@ async def setup():
363373
return res
364374

365375
context = contextvars.copy_context()
366-
result = runner.run(setup(), context=context)
367376

377+
# ensure loop_factory gets ran before we start running...
378+
if loop_factory:
379+
asyncio.set_event_loop(loop_factory())
380+
381+
result = runner.run(setup(), context=context)
368382
# Copy the context vars modified by the setup task into the current
369383
# context, and (if needed) add a finalizer to reset them.
370384
#
@@ -417,8 +431,6 @@ def restore_contextvars():
417431
class PytestAsyncioFunction(Function):
418432
"""Base class for all test functions managed by pytest-asyncio."""
419433

420-
loop_factory: Callable[[], AbstractEventLoop] | None
421-
422434
@classmethod
423435
def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | None:
424436
"""
@@ -433,12 +445,7 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
433445
return None
434446

435447
@classmethod
436-
def _from_function(
437-
cls,
438-
function: Function,
439-
loop_factory: Callable[[], AbstractEventLoop] | None = None,
440-
/,
441-
) -> Function:
448+
def _from_function(cls, function: Function, /) -> Function:
442449
"""
443450
Instantiates this specific PytestAsyncioFunction type from the specified
444451
Function item.
@@ -454,7 +461,6 @@ def _from_function(
454461
keywords=function.keywords,
455462
originalname=function.originalname,
456463
)
457-
subclass_instance.loop_factory = loop_factory
458464
subclass_instance.own_markers = function.own_markers
459465
assert subclass_instance.own_markers == function.own_markers
460466
return subclass_instance
@@ -516,6 +522,16 @@ def _can_substitute(item: Function) -> bool:
516522
func = item.obj
517523
return inspect.iscoroutinefunction(func)
518524

525+
<<<<<<< HEAD
526+
=======
527+
def runtest(self) -> None:
528+
# print(self.obj.pytestmark[0].__dict__)
529+
synchronized_obj = wrap_in_sync(self.obj, self.obj.pytestmark[0].kwargs.get('loop_factory', None))
530+
with MonkeyPatch.context() as c:
531+
c.setattr(self, "obj", synchronized_obj)
532+
super().runtest()
533+
534+
>>>>>>> edfbfef (figured out loop_factory :))
519535

520536
class AsyncGenerator(PytestAsyncioFunction):
521537
"""Pytest item created by an asynchronous generator"""
@@ -618,27 +634,9 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
618634
node.config
619635
) == Mode.AUTO and not node.get_closest_marker("asyncio"):
620636
node.add_marker("asyncio")
621-
if asyncio_marker := node.get_closest_marker("asyncio"):
622-
if loop_factory := asyncio_marker.kwargs.get("loop_factory", None):
623-
# multiply if loop_factory is an iterable object of factories
624-
if hasattr(loop_factory, "__iter__"):
625-
updated_item = [
626-
specialized_item_class._from_function(node, lf)
627-
for lf in loop_factory
628-
]
629-
else:
630-
updated_item = specialized_item_class._from_function(
631-
node, loop_factory
632-
)
633-
else:
634-
updated_item = specialized_item_class._from_function(node)
635-
636-
# we could have multiple factroies to test if so,
637-
# multiply the number of functions for us...
638-
if isinstance(updated_item, list):
639-
updated_node_collection.extend(updated_item)
640-
else:
641-
updated_node_collection.append(updated_item)
637+
if node.get_closest_marker("asyncio"):
638+
updated_item = specialized_item_class._from_function(node)
639+
updated_node_collection.append(updated_item)
642640
hook_result.force_result(updated_node_collection)
643641

644642

@@ -740,12 +738,12 @@ def _synchronize_coroutine(
740738
Return a sync wrapper around a coroutine executing it in the
741739
specified runner and context.
742740
"""
743-
744741
@functools.wraps(func)
745742
def inner(*args, **kwargs):
746743
coro = func(*args, **kwargs)
747744
runner.run(coro, context=context)
748745

746+
asyncio.set_event_loop(_last_loop)
749747
return inner
750748

751749

@@ -765,17 +763,12 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
765763
or default_loop_scope
766764
or fixturedef.scope
767765
)
768-
# XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
769766
loop_factory = getattr(fixturedef.func, "loop_factory", None)
770767

771-
print(f"LOOP FACTORY: {loop_factory} {fixturedef.func}")
772-
sys.stdout.flush()
773-
774768
runner_fixture_id = f"_{loop_scope}_scoped_runner"
775-
runner: Runner = request.getfixturevalue(runner_fixture_id)
776-
777-
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
778-
_make_asyncio_fixture_function(synchronizer, loop_scope)
769+
runner = request.getfixturevalue(runner_fixture_id)
770+
synchronizer = _fixture_synchronizer(fixturedef, runner, request, loop_factory)
771+
_make_asyncio_fixture_function(synchronizer, loop_scope, loop_factory)
779772
with MonkeyPatch.context() as c:
780773
c.setattr(fixturedef, "func", synchronizer)
781774
hook_result = yield
@@ -798,12 +791,9 @@ def _get_marked_loop_scope(
798791
) -> _ScopeName:
799792
assert asyncio_marker.name == "asyncio"
800793
if asyncio_marker.args or (
801-
asyncio_marker.kwargs
802-
and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
794+
asyncio_marker.kwargs and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
803795
):
804-
raise ValueError(
805-
"mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
806-
)
796+
raise ValueError("mark.asyncio accepts only a keyword arguments 'loop_scope' or 'loop_factory'")
807797
if "scope" in asyncio_marker.kwargs:
808798
if "loop_scope" in asyncio_marker.kwargs:
809799
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)

0 commit comments

Comments
 (0)