diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index febd591..e47b4de 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -8,7 +8,6 @@ Union, Callable, Optional, - Deque, Generic, Iterable, Iterator, @@ -17,7 +16,7 @@ overload, AsyncGenerator, ) -from collections import deque +from typing_extensions import TypeAlias from ._typing import ACloseable, R, T, AnyIterable, ADD from ._utility import public_module @@ -32,6 +31,7 @@ enumerate as aenumerate, iter as aiter, ) +from itertools import count as _counter S = TypeVar("S") T_co = TypeVar("T_co", covariant=True) @@ -346,57 +346,79 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: return None -async def tee_peer( - iterator: AsyncIterator[T], - # the buffer specific to this peer - buffer: Deque[T], - # the buffers of all peers, including our own - peers: List[Deque[T]], - lock: AsyncContextManager[Any], -) -> AsyncGenerator[T, None]: - """An individual iterator of a :py:func:`~.tee`""" - try: - while True: - if not buffer: - async with lock: - # Another peer produced an item while we were waiting for the lock. - # Proceed with the next loop iteration to yield the item. - if buffer: - continue - try: - item = await iterator.__anext__() - except StopAsyncIteration: - break - else: - # Append to all buffers, including our own. We'll fetch our - # item from the buffer again, instead of yielding it directly. - # This ensures the proper item ordering if any of our peers - # are fetching items concurrently. They may have buffered their - # item already. - for peer_buffer in peers: - peer_buffer.append(item) - yield buffer.popleft() - finally: - # this peer is done – remove its buffer - for idx, peer_buffer in enumerate(peers): # pragma: no branch - if peer_buffer is buffer: - peers.pop(idx) - break - # if we are the last peer, try and close the iterator - if not peers and isinstance(iterator, ACloseable): - await iterator.aclose() +_get_tee_index = _counter().__next__ + + +_TeeNode: TypeAlias = "list[T | _TeeNode[T]]" + + +class TeePeer(Generic[T]): + def __init__( + self, + iterator: AsyncIterator[T], + buffer: "_TeeNode[T]", + lock: AsyncContextManager[Any], + tee_peers: "set[int]", + ) -> None: + self._iterator = iterator + self._lock = lock + self._buffer: _TeeNode[T] = buffer + self._tee_peers = tee_peers + self._tee_idx = _get_tee_index() + self._tee_peers.add(self._tee_idx) + + def __aiter__(self): + return self + + async def __anext__(self) -> T: + # the buffer is a singly-linked list as [value, [value, [...]]] | [] + next_node = self._buffer + value: T + # for any most advanced TeePeer, the node is just [] + # fetch the next value so we can mutate the node to [value, [...]] + if not next_node: + async with self._lock: + # Check if another peer produced an item while we were waiting for the lock + if not next_node: + await self._extend_buffer(next_node) + # for any other TeePeer, the node is already some [value, [...]] + value, self._buffer = next_node # type: ignore + return value + + async def _extend_buffer(self, next_node: "_TeeNode[T]") -> None: + """Extend the buffer by fetching a new item from the iterable""" + try: + # another peer may fill the buffer while we wait here + next_value = await self._iterator.__anext__() + except StopAsyncIteration: + # no one else managed to fetch a value either + if not next_node: + raise + else: + # skip nodes that were filled in the meantime + while next_node: + _, next_node = next_node # type: ignore + next_node[:] = next_value, [] + + async def aclose(self) -> None: + self._tee_peers.discard(self._tee_idx) + if not self._tee_peers and isinstance(self._iterator, ACloseable): + await self._iterator.aclose() + + def __del__(self) -> None: + self._tee_peers.discard(self._tee_idx) @public_module(__name__, "tee") class Tee(Generic[T]): - """ + r""" Create ``n`` separate asynchronous iterators over ``iterable`` This splits a single ``iterable`` into multiple iterators, each providing the same items in the same order. All child iterators may advance separately but share the same items from ``iterable`` -- when the most advanced iterator retrieves an item, - it is buffered until the least advanced iterator has yielded it as well. + it is buffered until all other iterators have yielded it as well. A ``tee`` works lazily and can handle an infinite ``iterable``, provided that all iterators advance. @@ -407,16 +429,9 @@ async def derivative(sensor_data): await a.anext(previous) # advance one iterator return a.map(operator.sub, previous, current) - Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead - of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked - to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method - immediately closes all children, and it can be used in an ``async with`` context - for the same effect. - - If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not* - provide these items. Also, ``tee`` must internally buffer each item until the - last iterator has yielded it; if the most and least advanced iterator differ - by most data, using a :py:class:`list` is more efficient (but not lazy). + If ``iterable`` is an iterator and read elsewhere, ``tee`` will generally *not* + provide these items. However, a ``tee`` of a ``tee`` shares its buffer with parent, + sibling and child ``tee``\ s so that each sees the same items. If the underlying iterable is concurrency safe (``anext`` may be awaited concurrently) the resulting iterators are concurrency safe as well. Otherwise, @@ -424,9 +439,15 @@ async def derivative(sensor_data): To enforce sequential use of ``anext``, provide a ``lock`` - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application - and access is automatically synchronised. + + Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead + of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked + to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method + immediately closes all children, and it can be used in an ``async with`` context + for the same effect. """ - __slots__ = ("_iterator", "_buffers", "_children") + __slots__ = ("_children",) def __init__( self, @@ -435,16 +456,24 @@ def __init__( *, lock: Optional[AsyncContextManager[Any]] = None, ): - self._iterator = aiter(iterable) - self._buffers: List[Deque[T]] = [deque() for _ in range(n)] + buffer: _TeeNode[T] + peers: set[int] + if not isinstance(iterable, TeePeer): + iterator = aiter(iterable) + buffer = [] + peers = set() + else: + iterator = iterable._iterator # pyright: ignore[reportPrivateUsage] + buffer = iterable._buffer # pyright: ignore[reportPrivateUsage] + peers = iterable._tee_peers # pyright: ignore[reportPrivateUsage] self._children = tuple( - tee_peer( - iterator=self._iterator, - buffer=buffer, - peers=self._buffers, - lock=lock if lock is not None else NoLock(), + TeePeer( + iterator, + buffer, + lock if lock is not None else NoLock(), + peers, ) - for buffer in self._buffers + for _ in range(n) ) def __len__(self) -> int: diff --git a/docs/source/api/itertools.rst b/docs/source/api/itertools.rst index c61b750..ccfc50d 100644 --- a/docs/source/api/itertools.rst +++ b/docs/source/api/itertools.rst @@ -85,6 +85,10 @@ Iterator splitting The ``lock`` keyword parameter. + .. versionchanged:: 3.13.2 + + ``tee``\ s share their buffer with parents, siblings and children. + .. autofunction:: pairwise(iterable: (async) iter T) :async-for: :(T, T) diff --git a/unittests/test_itertools.py b/unittests/test_itertools.py index 82e4e7a..4e3b913 100644 --- a/unittests/test_itertools.py +++ b/unittests/test_itertools.py @@ -1,3 +1,4 @@ +from typing import AsyncIterator import itertools import sys import platform @@ -341,7 +342,7 @@ async def test_tee(): @sync async def test_tee_concurrent_locked(): - """Test that properly uses a lock for synchronisation""" + """Test that tee properly uses a lock for synchronisation""" items = [1, 2, 3, -5, 12, 78, -1, 111] async def iter_values(): @@ -393,6 +394,41 @@ async def test_peer(peer_tee): await test_peer(this) +@pytest.mark.parametrize("size", [2, 3, 5, 9, 12]) +@sync +async def test_tee_concurrent_ordering(size: int): + """Test that tee respects concurrent ordering for all peers""" + + class ConcurrentInvertedIterable: + """Helper that concurrently iterates with earlier items taking longer""" + + def __init__(self, count: int) -> None: + self.count = count + self._counter = itertools.count() + + def __aiter__(self): + return self + + async def __anext__(self): + value = next(self._counter) + if value >= self.count: + raise StopAsyncIteration() + await Switch(self.count - value) + return value + + async def test_peer(peer_tee: AsyncIterator[int]): + # consume items from the tee with a delay so that slower items can arrive + seen_items: list[int] = [] + async for item in peer_tee: + seen_items.append(item) + await Switch() + assert seen_items == expected_items + + expected_items = list(range(size)[::-1]) + peers = a.tee(ConcurrentInvertedIterable(size), n=size) + await Schedule(*map(test_peer, peers)) + + @sync async def test_pairwise(): assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)]