153 changes: 91 additions & 62 deletions asyncstdlib/itertools.py
@@ -8,7 +8,6 @@
Union,
Callable,
Optional,
Deque,
Generic,
Iterable,
Iterator,
@@ -17,7 +16,7 @@
overload,
AsyncGenerator,
)
from collections import deque
from typing_extensions import TypeAlias

from ._typing import ACloseable, R, T, AnyIterable, ADD
from ._utility import public_module
@@ -32,6 +31,7 @@
enumerate as aenumerate,
iter as aiter,
)
from itertools import count as _counter

S = TypeVar("S")
T_co = TypeVar("T_co", covariant=True)
@@ -346,57 +346,79 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
return None


async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
try:
while True:
if not buffer:
async with lock:
# Another peer produced an item while we were waiting for the lock.
# Proceed with the next loop iteration to yield the item.
if buffer:
continue
try:
item = await iterator.__anext__()
except StopAsyncIteration:
break
else:
# Append to all buffers, including our own. We'll fetch our
# item from the buffer again, instead of yielding it directly.
# This ensures the proper item ordering if any of our peers
# are fetching items concurrently. They may have buffered their
# item already.
for peer_buffer in peers:
peer_buffer.append(item)
yield buffer.popleft()
finally:
# this peer is done – remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)
break
# if we are the last peer, try and close the iterator
if not peers and isinstance(iterator, ACloseable):
await iterator.aclose()
_get_tee_index = _counter().__next__


_TeeNode: TypeAlias = "list[T | _TeeNode[T]]"


class TeePeer(Generic[T]):
def __init__(
self,
iterator: AsyncIterator[T],
buffer: "_TeeNode[T]",
lock: AsyncContextManager[Any],
tee_peers: "set[int]",
) -> None:
self._iterator = iterator
self._lock = lock
self._buffer: _TeeNode[T] = buffer
self._tee_peers = tee_peers
self._tee_idx = _get_tee_index()
self._tee_peers.add(self._tee_idx)

def __aiter__(self):
return self

async def __anext__(self) -> T:
# the buffer is a singly-linked list as [value, [value, [...]]] | []
next_node = self._buffer
value: T
# for any of the most advanced TeePeers, the node is just []
# fetch the next value so we can mutate the node to [value, [...]]
if not next_node:
async with self._lock:
# Check if another peer produced an item while we were waiting for the lock
if not next_node:
await self._extend_buffer(next_node)
# for any other TeePeer, the node is already some [value, [...]]
value, self._buffer = next_node # type: ignore
return value

async def _extend_buffer(self, next_node: "_TeeNode[T]") -> None:
"""Extend the buffer by fetching a new item from the iterable"""
try:
# another peer may fill the buffer while we wait here
next_value = await self._iterator.__anext__()
except StopAsyncIteration:
# no one else managed to fetch a value either
if not next_node:
raise
else:
# skip nodes that were filled in the meantime
while next_node:
_, next_node = next_node # type: ignore
next_node[:] = next_value, []

async def aclose(self) -> None:
self._tee_peers.discard(self._tee_idx)
if not self._tee_peers and isinstance(self._iterator, ACloseable):
await self._iterator.aclose()

def __del__(self) -> None:
self._tee_peers.discard(self._tee_idx)


@public_module(__name__, "tee")
class Tee(Generic[T]):
"""
r"""
Create ``n`` separate asynchronous iterators over ``iterable``

This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.
All child iterators may advance separately but share the same items
from ``iterable`` -- when the most advanced iterator retrieves an item,
it is buffered until the least advanced iterator has yielded it as well.
it is buffered until all other iterators have yielded it as well.
A ``tee`` works lazily and can handle an infinite ``iterable``, provided
that all iterators advance.

Expand All @@ -407,26 +429,25 @@ async def derivative(sensor_data):
await a.anext(previous) # advance one iterator
return a.map(operator.sub, previous, current)

Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
immediately closes all children, and it can be used in an ``async with`` context
for the same effect.

If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
provide these items. Also, ``tee`` must internally buffer each item until the
last iterator has yielded it; if the most and least advanced iterator differ
by most data, using a :py:class:`list` is more efficient (but not lazy).
If ``iterable`` is an iterator and read elsewhere, ``tee`` will generally *not*
provide these items. However, a ``tee`` of a ``tee`` shares its buffer with its parent,
sibling and child ``tee``\ s so that each sees the same items.

If the underlying iterable is concurrency safe (``anext`` may be awaited
concurrently) the resulting iterators are concurrency safe as well. Otherwise,
the iterators are safe if there is only ever one single "most advanced" iterator.
To enforce sequential use of ``anext``, provide a ``lock``
- e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
and access is automatically synchronised.

Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
immediately closes all children, and it can be used in an ``async with`` context
for the same effect.
"""

__slots__ = ("_iterator", "_buffers", "_children")
__slots__ = ("_children",)

def __init__(
self,
@@ -435,16 +456,24 @@ def __init__(
*,
lock: Optional[AsyncContextManager[Any]] = None,
):
self._iterator = aiter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
buffer: _TeeNode[T]
peers: set[int]
if not isinstance(iterable, TeePeer):
iterator = aiter(iterable)
buffer = []
peers = set()
else:
iterator = iterable._iterator # pyright: ignore[reportPrivateUsage]
buffer = iterable._buffer # pyright: ignore[reportPrivateUsage]
peers = iterable._tee_peers # pyright: ignore[reportPrivateUsage]
self._children = tuple(
tee_peer(
iterator=self._iterator,
buffer=buffer,
peers=self._buffers,
lock=lock if lock is not None else NoLock(),
TeePeer(
iterator,
buffer,
lock if lock is not None else NoLock(),
peers,
)
for buffer in self._buffers
for _ in range(n)
)

def __len__(self) -> int:
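The new ``TeePeer`` keeps pending items in one shared, singly-linked buffer of
``[value, next_node]`` cells instead of one ``deque`` per peer. For illustration
only, here is a synchronous sketch of that buffering idea; it leaves out the
locking and the concurrent-refill handling of ``_extend_buffer``, and the names
``Node``, ``make_consumers`` and ``consumer`` are invented for this example::

    from typing import Any, Iterator, List

    Node = List[Any]  # either [] (the shared end of the buffer) or [value, next_node]

    def make_consumers(source: Iterator[Any], n: int):
        """Create ``n`` synchronous consumers sharing one linked buffer over ``source``."""
        tail: Node = []  # every consumer starts at the same, initially empty node

        def consumer(node: Node):
            while True:
                if not node:  # we are (one of) the most advanced: fetch a new item
                    try:
                        value = next(source)
                    except StopIteration:
                        return
                    node[:] = value, []  # mutate the node in place so every peer sees it
                value, node = node  # unpack the value and follow the link
                yield value

        return [consumer(tail) for _ in range(n)]

    first, second = make_consumers(iter("xyz"), 2)
    assert next(first) == "x" and next(first) == "y"  # runs ahead, filling the buffer
    assert list(second) == ["x", "y", "z"]            # replays the shared nodes
    assert list(first) == ["z"]

Once every consumer has advanced past a node, nothing references it anymore and it
can be garbage collected, which lets the most and least advanced peers drift apart
without per-peer queues.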
4 changes: 4 additions & 0 deletions docs/source/api/itertools.rst
@@ -85,6 +85,10 @@ Iterator splitting

The ``lock`` keyword parameter.

.. versionchanged:: 3.13.2

``tee``\ s share their buffer with parents, siblings and children.

.. autofunction:: pairwise(iterable: (async) iter T)
:async-for: :(T, T)

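For reference, a minimal usage sketch of the behaviour described in the note above,
with a nested ``tee`` sharing one buffer with its parent. It assumes ``asyncstdlib``
is imported as ``a``; the source ``numbers`` and the child names are made up for
illustration::

    import asyncio
    import asyncstdlib as a

    async def main() -> None:
        async def numbers():
            for i in range(4):
                yield i

        # split the source once, then split one child again; with shared
        # buffers every resulting iterator sees the same four items
        first, second = a.tee(numbers(), n=2, lock=asyncio.Lock())
        third, fourth = a.tee(second, n=2)
        try:
            assert await a.list(first) == [0, 1, 2, 3]
            assert await a.list(third) == [0, 1, 2, 3]
            assert await a.list(fourth) == [0, 1, 2, 3]
        finally:
            for child in (first, second, third, fourth):
                # the underlying iterator is only closed with the last peer
                await child.aclose()

    asyncio.run(main())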
38 changes: 37 additions & 1 deletion unittests/test_itertools.py
@@ -1,3 +1,4 @@
from typing import AsyncIterator
import itertools
import sys
import platform
@@ -341,7 +342,7 @@ async def test_tee():

@sync
async def test_tee_concurrent_locked():
"""Test that properly uses a lock for synchronisation"""
"""Test that tee properly uses a lock for synchronisation"""
items = [1, 2, 3, -5, 12, 78, -1, 111]

async def iter_values():
@@ -393,6 +394,41 @@ async def test_peer(peer_tee):
await test_peer(this)


@pytest.mark.parametrize("size", [2, 3, 5, 9, 12])
@sync
async def test_tee_concurrent_ordering(size: int):
"""Test that tee respects concurrent ordering for all peers"""

class ConcurrentInvertedIterable:
"""Helper that concurrently iterates with earlier items taking longer"""

def __init__(self, count: int) -> None:
self.count = count
self._counter = itertools.count()

def __aiter__(self):
return self

async def __anext__(self):
value = next(self._counter)
if value >= self.count:
raise StopAsyncIteration()
await Switch(self.count - value)
return value

async def test_peer(peer_tee: AsyncIterator[int]):
# consume items from the tee with a delay so that slower items can arrive
seen_items: list[int] = []
async for item in peer_tee:
seen_items.append(item)
await Switch()
assert seen_items == expected_items

expected_items = list(range(size)[::-1])
peers = a.tee(ConcurrentInvertedIterable(size), n=size)
await Schedule(*map(test_peer, peers))


@sync
async def test_pairwise():
assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)]
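The helper above hands out the values ``0 .. size-1`` to the concurrently waiting
peers but makes earlier values resolve last, so the items reach the ``tee`` in
reverse order, e.g. ``[2, 1, 0]`` for ``size == 3``; this is why
``expected_items`` is ``list(range(size)[::-1])``. A looser, plain-``asyncio``
sketch of several consumers draining one ``tee`` concurrently looks like this
(it uses a lock and a well-behaved source, so it does not reproduce the inverted
arrival order; ``slow_numbers`` and ``consume_in_order`` are invented names)::

    import asyncio
    import asyncstdlib as a

    async def slow_numbers(count: int):
        for value in range(count):
            await asyncio.sleep(0)  # yield control between items
            yield value

    async def consume_in_order(peer) -> None:
        seen = [item async for item in peer]
        assert seen == list(range(4))  # every peer observes the source order

    async def main() -> None:
        # with a lock, the peers may call ``anext`` concurrently
        tee_peers = a.tee(slow_numbers(4), n=4, lock=asyncio.Lock())
        async with tee_peers:  # closes all children on exit
            await asyncio.gather(*(consume_in_order(peer) for peer in tee_peers))

    asyncio.run(main())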