|
38 | 38 | Session, |
39 | 39 | StashKey, |
40 | 40 | ) |
| 41 | +from typing_extensions import Self |
41 | 42 |
|
42 | 43 | _R = TypeVar("_R") |
43 | 44 |
|
@@ -356,6 +357,21 @@ async def setup(): |
356 | 357 | class AsyncFunction(pytest.Function): |
357 | 358 | """Pytest item that is a coroutine or an asynchronous generator""" |
358 | 359 |
|
| 360 | + @classmethod |
| 361 | + def from_function(cls, function: pytest.Function, /) -> Self: |
| 362 | + """ |
| 363 | + Instantiates an AsyncFunction from the specified pytest.Function item. |
| 364 | + """ |
| 365 | + return cls.from_parent( |
| 366 | + function.parent, |
| 367 | + name=function.name, |
| 368 | + callspec=getattr(function, "callspec", None), |
| 369 | + callobj=function.obj, |
| 370 | + fixtureinfo=function._fixtureinfo, |
| 371 | + keywords=function.keywords, |
| 372 | + originalname=function.originalname, |
| 373 | + ) |
| 374 | + |
359 | 375 |
|
360 | 376 | _HOLDER: Set[FixtureDef] = set() |
361 | 377 |
|
@@ -396,27 +412,14 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass( |
396 | 412 | except TypeError: |
397 | 413 | # Treat single node as a single-element iterable |
398 | 414 | node_iterator = iter((node_or_list_of_nodes,)) |
399 | | - async_functions = [] |
400 | | - for collector_or_item in node_iterator: |
401 | | - if not ( |
402 | | - isinstance(collector_or_item, pytest.Function) |
403 | | - and _is_coroutine_or_asyncgen(obj) |
404 | | - ): |
405 | | - collector = collector_or_item |
406 | | - async_functions.append(collector) |
407 | | - continue |
408 | | - item = collector_or_item |
409 | | - async_function = AsyncFunction.from_parent( |
410 | | - item.parent, |
411 | | - name=item.name, |
412 | | - callspec=getattr(item, "callspec", None), |
413 | | - callobj=item.obj, |
414 | | - fixtureinfo=item._fixtureinfo, |
415 | | - keywords=item.keywords, |
416 | | - originalname=item.originalname, |
417 | | - ) |
418 | | - async_functions.append(async_function) |
419 | | - hook_result.force_result(async_functions) |
| 415 | + updated_node_collection = [] |
| 416 | + for node in node_iterator: |
| 417 | + if isinstance(node, pytest.Function) and _is_coroutine_or_asyncgen(obj): |
| 418 | + async_function = AsyncFunction.from_function(node) |
| 419 | + updated_node_collection.append(async_function) |
| 420 | + else: |
| 421 | + updated_node_collection.append(node) |
| 422 | + hook_result.force_result(updated_node_collection) |
420 | 423 |
|
421 | 424 |
|
422 | 425 | _event_loop_fixture_id = StashKey[str] |
|
0 commit comments