diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 4db58460..a707019e 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -151,6 +151,14 @@ like this, which is much easier to read: Be sure to call pipe.multi() in the callable passed to Valkey.transaction prior to any write commands. +.. warning:: + + Transactions are retried an infinite amount of times by default, which + can lead to infinite loops - for example, if one were to accidentally + write ``r.set`` instead of ``pipe.set`` in the above example. You can + use the ``transaction`` arguments ``watch_delay`` and ``max_tries`` + to mitigate this risk. + Pipelines in clusters ~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 5021f91c..8ef1873c 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -329,6 +329,47 @@ async def my_transaction(pipe): assert result == [True] assert await r.get("c") == b"4" + @pytest.mark.onlynoncluster + async def test_transaction_loop(self, r): + await r.set("a", 0) + run_count = 0 + + async def my_transaction(pipe): + nonlocal run_count + run_count += 1 + if run_count > 10: + raise RuntimeError("Run too many times") + a_value = int(await pipe.get("a")) + pipe.multi() + await r.set("a", a_value + 1) # force WatchError + + # without max_tries (infinite loop) + with pytest.raises(RuntimeError) as ex: + await r.transaction(my_transaction, "a") + assert str(ex.value).startswith("Run too many times") + assert run_count == 11 + + run_count = 0 + # with max_tries + with pytest.raises(valkey.ValkeyError) as ex: + await r.transaction(my_transaction, "a", max_tries=3) + assert str(ex.value).startswith("Bailing out of transaction after 3 tries") + assert run_count == 3 + + run_count = 0 + # with max_tries=0 (same as without; infinite loop) + with pytest.raises(RuntimeError) as ex: + await r.transaction(my_transaction, "a", max_tries=0) + assert str(ex.value).startswith("Run too many times") + assert run_count == 11 + + run_count = 0 + # with negative max_tries (immediate error) + with pytest.raises(valkey.ValkeyError) as ex: + await r.transaction(my_transaction, "a", max_tries=-3) + assert str(ex.value).startswith("Bailing out of transaction after 0 tries") + assert run_count == 0 + @pytest.mark.onlynoncluster async def test_transaction_callable_returns_value_from_callable(self, r): async def callback(pipe): diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 065f898c..3b8339ca 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -330,6 +330,47 @@ def my_transaction(pipe): assert result == [True] assert r["c"] == b"4" + @pytest.mark.onlynoncluster + def test_transaction_loop(self, r): + r["a"] = 0 + run_count = 0 + + def my_transaction(pipe): + nonlocal run_count + run_count += 1 + if run_count > 10: + raise RuntimeError("Run too many times") + a_value = int(pipe.get("a")) + pipe.multi() + r.set("a", a_value + 1) # force WatchError + + # without max_tries (infinite loop) + with pytest.raises(RuntimeError) as ex: + r.transaction(my_transaction, "a") + assert str(ex.value).startswith("Run too many times") + assert run_count == 11 + + run_count = 0 + # with max_tries + with pytest.raises(valkey.ValkeyError) as ex: + r.transaction(my_transaction, "a", max_tries=3) + assert str(ex.value).startswith("Bailing out of transaction after 3 tries") + assert run_count == 3 + + run_count = 0 + # with max_tries=0 (same as without; infinite loop) + with pytest.raises(RuntimeError) as ex: + r.transaction(my_transaction, "a", max_tries=0) + assert str(ex.value).startswith("Run too many times") + assert run_count == 11 + + run_count = 0 + # with negative max_tries (immediate error) + with pytest.raises(valkey.ValkeyError) as ex: + r.transaction(my_transaction, "a", max_tries=-3) + assert str(ex.value).startswith("Bailing out of transaction after 0 tries") + assert run_count == 0 + @pytest.mark.onlynoncluster def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): diff --git a/valkey/asyncio/client.py b/valkey/asyncio/client.py index 10c47818..851b2cc1 100644 --- a/valkey/asyncio/client.py +++ b/valkey/asyncio/client.py @@ -437,15 +437,28 @@ async def transaction( shard_hint: Optional[str] = None, value_from_callable: bool = False, watch_delay: Optional[float] = None, + max_tries: Optional[int] = None, ): """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. The 'func' callable should expect a single argument which is a Pipeline object. + + :param watch_delay: Lets you specify a delay time in seconds after a + `WatchError` before the transaction is retried. Default is no delay. + :param max_tries: Lets you specify the maximum number of times the + transaction is retried. If the limit is reached, a `ValkeyError` + is raised. Default is 0, meaning an **infinite** number of retries! """ pipe: Pipeline async with self.pipeline(True, shard_hint) as pipe: + tries = 0 while True: + tries += 1 + if max_tries and tries > max_tries: + raise ValkeyError( + f"Bailing out of transaction after {tries - 1} tries" + ) try: if watches: await pipe.watch(*watches) diff --git a/valkey/client.py b/valkey/client.py index 88703129..df265ab4 100755 --- a/valkey/client.py +++ b/valkey/client.py @@ -404,12 +404,27 @@ def transaction( Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. The 'func' callable should expect a single argument which is a Pipeline object. + + :param watch_delay: This keyword-only argument lets you specify a + delay time in seconds after a `WatchError` before the transaction + is retried. Default is no delay. + :param max_tries: This keyword-only argument lets you specify the + maximum number to times the transaction should be retried. If the + limit is reached, a `ValkeyError` is raised. Default is 0, meaning + an **infinite** number of retries! """ shard_hint = kwargs.pop("shard_hint", None) value_from_callable = kwargs.pop("value_from_callable", False) watch_delay = kwargs.pop("watch_delay", None) + max_tries = kwargs.pop("max_tries", None) with self.pipeline(True, shard_hint) as pipe: + tries = 0 while True: + tries += 1 + if max_tries and tries > max_tries: + raise ValkeyError( + f"Bailing out of transaction after {tries - 1} tries" + ) try: if watches: pipe.watch(*watches)