Skip to content

Commit a2c31d3

Browse files
committed
Add max_tries to transaction()
This new `max_tries` argument in the `transaction` helper allows one to more easily prevent (accidental) infinite loops. Commit includes documentation and tests. Signed-off-by: Hauke Daempfling <haukex@zero-g.net>
1 parent c291080 commit a2c31d3

File tree

5 files changed

+80
-0
lines changed

5 files changed

+80
-0
lines changed

docs/advanced_features.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ like this, which is much easier to read:
151151
Be sure to call pipe.multi() in the callable passed to Valkey.transaction
152152
prior to any write commands.
153153

154+
.. warning::
155+
156+
Transactions are retried an infinite amount of times by default, which
157+
can lead to infinite loops - for example, if one were to accidentally
158+
write ``r.set`` instead of ``pipe.set`` in the above example. You can
159+
use the ``transaction`` arguments ``watch_delay`` and ``max_tries``
160+
to mitigate this risk.
161+
154162
Pipelines in clusters
155163
~~~~~~~~~~~~~~~~~~~~~
156164

tests/test_asyncio/test_pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,28 @@ async def my_transaction(pipe):
329329
assert result == [True]
330330
assert await r.get("c") == b"4"
331331

332+
@pytest.mark.onlynoncluster
333+
async def test_transaction_loop(self, r):
334+
await r.set("a", 0)
335+
run_count = 0
336+
337+
async def my_transaction(pipe):
338+
nonlocal run_count
339+
run_count += 1
340+
if run_count > 10:
341+
raise RuntimeError("Run too many times")
342+
a_value = int(await pipe.get("a"))
343+
pipe.multi()
344+
await r.set("a", a_value + 1) # force WatchError
345+
346+
with pytest.raises(RuntimeError) as ex:
347+
await r.transaction(my_transaction, "a")
348+
assert str(ex.value).startswith("Run too many times")
349+
run_count = 0
350+
with pytest.raises(valkey.ValkeyError) as ex:
351+
await r.transaction(my_transaction, "a", max_tries=3)
352+
assert str(ex.value).startswith("Bailing out of transaction after 3 tries")
353+
332354
@pytest.mark.onlynoncluster
333355
async def test_transaction_callable_returns_value_from_callable(self, r):
334356
async def callback(pipe):

tests/test_pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,28 @@ def my_transaction(pipe):
330330
assert result == [True]
331331
assert r["c"] == b"4"
332332

333+
@pytest.mark.onlynoncluster
334+
def test_transaction_loop(self, r):
335+
r["a"] = 0
336+
run_count = 0
337+
338+
def my_transaction(pipe):
339+
nonlocal run_count
340+
run_count += 1
341+
if run_count > 10:
342+
raise RuntimeError("Run too many times")
343+
a_value = int(pipe.get("a"))
344+
pipe.multi()
345+
r.set("a", a_value + 1) # force WatchError
346+
347+
with pytest.raises(RuntimeError) as ex:
348+
r.transaction(my_transaction, "a")
349+
assert str(ex.value).startswith("Run too many times")
350+
run_count = 0
351+
with pytest.raises(valkey.ValkeyError) as ex:
352+
r.transaction(my_transaction, "a", max_tries=3)
353+
assert str(ex.value).startswith("Bailing out of transaction after 3 tries")
354+
333355
@pytest.mark.onlynoncluster
334356
def test_transaction_callable_returns_value_from_callable(self, r):
335357
def callback(pipe):

valkey/asyncio/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,15 +437,28 @@ async def transaction(
437437
shard_hint: Optional[str] = None,
438438
value_from_callable: bool = False,
439439
watch_delay: Optional[float] = None,
440+
max_tries: Optional[int] = None,
440441
):
441442
"""
442443
Convenience method for executing the callable `func` as a transaction
443444
while watching all keys specified in `watches`. The 'func' callable
444445
should expect a single argument which is a Pipeline object.
446+
447+
:param watch_delay: Lets you specify a delay time in seconds after a
448+
`WatchError` before the transaction is retried. Default is no delay.
449+
:param max_tries: Lets you specify the maximum number of times the
450+
transaction is retried. If the limit is reached, a `ValkeyError`
451+
is raised. Default is an **infinite** number of retries!
445452
"""
446453
pipe: Pipeline
447454
async with self.pipeline(True, shard_hint) as pipe:
455+
tries = 0
448456
while True:
457+
tries += 1
458+
if max_tries and max_tries > 0 and tries > max_tries:
459+
raise ValkeyError(
460+
f"Bailing out of transaction after {tries - 1} tries"
461+
)
449462
try:
450463
if watches:
451464
await pipe.watch(*watches)

valkey/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,12 +404,27 @@ def transaction(
404404
Convenience method for executing the callable `func` as a transaction
405405
while watching all keys specified in `watches`. The 'func' callable
406406
should expect a single argument which is a Pipeline object.
407+
408+
:param watch_delay: This keyword-only argument lets you specify a
409+
delay time in seconds after a `WatchError` before the transaction
410+
is retried. Default is no delay.
411+
:param max_tries: This keyword-only argument lets you specify the
412+
maximum number to times the transaction should be retried. If the
413+
limit is reached, a `ValkeyError` is raised. Default is an
414+
**infinite** number of retries!
407415
"""
408416
shard_hint = kwargs.pop("shard_hint", None)
409417
value_from_callable = kwargs.pop("value_from_callable", False)
410418
watch_delay = kwargs.pop("watch_delay", None)
419+
max_tries = kwargs.pop("max_tries", None)
411420
with self.pipeline(True, shard_hint) as pipe:
421+
tries = 0
412422
while True:
423+
tries += 1
424+
if max_tries and max_tries > 0 and tries > max_tries:
425+
raise ValkeyError(
426+
f"Bailing out of transaction after {tries - 1} tries"
427+
)
413428
try:
414429
if watches:
415430
pipe.watch(*watches)

0 commit comments

Comments
 (0)