Skip to content

Commit 125de62

Browse files
Paillat-devopenhands-agentnicebots-xyz-bot
authored
✨ Enhanced cooldown system with bucket types (#57)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: nicebots-xyz-bot <hello@nicebots.xyz>
1 parent eecc0c3 commit 125de62

File tree

5 files changed

+110
-8
lines changed

5 files changed

+110
-8
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) NiceBots
2+
# SPDX-License-Identifier: MIT
3+
4+
from typing import Any, final, override
5+
6+
import discord
7+
8+
from src import custom
9+
from src.i18n.classes import RawTranslation, apply_locale
10+
from src.utils.cooldown import CooldownExceeded
11+
12+
from .base import BaseErrorHandler, ErrorHandlerRType
13+
14+
15+
@final
16+
class CooldownErrorHandler(BaseErrorHandler[CooldownExceeded]):
17+
def __init__(self, translations: dict[str, RawTranslation]) -> None:
18+
self.translations = translations
19+
super().__init__(CooldownExceeded)
20+
21+
@override
22+
async def __call__(
23+
self,
24+
error: CooldownExceeded,
25+
ctx: custom.Context | discord.Interaction,
26+
sendargs: dict[str, Any],
27+
message: str,
28+
report: bool,
29+
) -> ErrorHandlerRType:
30+
translations = apply_locale(self.translations, self._get_locale(ctx))
31+
32+
message = translations.error_cooldown_exceeded
33+
34+
sendargs["ephemeral"] = True
35+
36+
return False, False, message, sendargs

src/extensions/nice_errors/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from schema import Optional, Schema
99

1010
from src import custom
11+
from src.utils.cooldown import CooldownExceeded
1112

1213
from .handlers import error_handler
14+
from .handlers.cooldown import CooldownErrorHandler
1315
from .handlers.forbidden import ForbiddenErrorHandler
1416
from .handlers.generic import GenericErrorHandler
1517
from .handlers.not_found import NotFoundErrorHandler
@@ -64,3 +66,4 @@ def setup(bot: custom.Bot, config: dict[str, Any]) -> None:
6466
error_handler.add_error_handler(None, GenericErrorHandler(config["translations"]))
6567
error_handler.add_error_handler(commands.CommandNotFound, NotFoundErrorHandler(config["translations"]))
6668
error_handler.add_error_handler(discord.Forbidden, ForbiddenErrorHandler(config["translations"]))
69+
error_handler.add_error_handler(CooldownExceeded, CooldownErrorHandler(config["translations"]))

src/extensions/nice_errors/translations.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ strings:
2626
it: Ops! Non ho i permessi necessari per farlo.
2727
es-ES: ¡Ups! No tengo el permiso necesario para hacer eso.
2828
ru: Упс! У меня нет необходимых прав для выполнения этого действия.
29+
error_cooldown_exceeded:
30+
en-US: Whoops! You're doing that too fast. Please wait before trying again.
31+
de: Hoppla! Du machst das zu schnell. Bitte warte, bevor du es erneut versuchst.
32+
nl: Oeps! Je doet dat te snel. Wacht even voordat je het opnieuw probeert.
33+
fr: Oups ! Vous faites cela trop vite. Veuillez attendre avant de réessayer.
34+
it: Ops! Stai facendo troppo in fretta. Attendi prima di riprovare.
35+
es-ES: ¡Ups! Estás haciendo eso demasiado rápido. Por favor, espera antes de intentarlo de nuevo.
36+
ru: Упс! Вы делаете это слишком быстро. Пожалуйста, подождите, прежде чем попробовать снова.
2937
error_generic:
3038
en-US: Whoops! An error occurred while executing this command.
3139
de: Hoppla! Bei der Ausführung dieses Kommandos ist ein Fehler aufgetreten.

src/extensions/ping/ping.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from src import custom
1111
from src.log import logger
12-
from src.utils.cooldown import cooldown
12+
from src.utils.cooldown import BucketType, cooldown
1313

1414
default = {
1515
"enabled": True,
@@ -32,6 +32,7 @@ def __init__(self, bot: custom.Bot) -> None:
3232
limit=1,
3333
per=5,
3434
strong=True,
35+
bucket_type=BucketType.USER,
3536
)
3637
async def ping(
3738
self,

src/utils/cooldown.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import time
55
from collections.abc import Awaitable, Callable, Coroutine
6+
from enum import Enum
67
from functools import wraps
78
from inspect import isawaitable
89
from typing import Any, Concatenate, cast
@@ -15,6 +16,16 @@
1516
type CogCommandFunction[T: commands.Cog, **P] = Callable[Concatenate[T, custom.ApplicationContext, P], Awaitable[None]]
1617

1718

19+
class BucketType(Enum):
20+
DEFAULT = "default" # Uses provided key as is
21+
USER = "user" # Per-user cooldown
22+
MEMBER = "member" # Per-member (user+guild) cooldown
23+
GUILD = "guild" # Per-guild cooldown
24+
CHANNEL = "channel" # Per-channel cooldown
25+
CATEGORY = "category" # Per-category cooldown
26+
ROLE = "role" # Per-role cooldown (uses highest role)
27+
28+
1829
async def parse_reactive_setting[T](value: ReactiveCooldownSetting[T], bot: custom.Bot, ctx: custom.Context) -> T:
1930
if isinstance(value, type):
2031
return value # pyright: ignore [reportReturnType]
@@ -26,22 +37,58 @@ async def parse_reactive_setting[T](value: ReactiveCooldownSetting[T], bot: cust
2637

2738

2839
class CooldownExceeded(commands.CheckFailure):
29-
def __init__(self, retry_after: float) -> None:
40+
def __init__(self, retry_after: float, bucket_type: BucketType) -> None:
3041
self.retry_after: float = retry_after
31-
super().__init__("You are on cooldown")
42+
self.bucket_type: BucketType = bucket_type
43+
super().__init__(f"You are on {bucket_type.value} cooldown")
3244

3345

34-
# inspired by https://github.com/ItsDrike/code-jam-2024/blob/main/src/utils/ratelimit.py
46+
def get_bucket_key(ctx: custom.ApplicationContext, base_key: str, bucket_type: BucketType) -> str: # noqa: PLR0911
47+
"""Generate a cooldown key based on the bucket type."""
48+
match bucket_type:
49+
case BucketType.USER:
50+
return f"{base_key}:user:{ctx.author.id}"
51+
case BucketType.MEMBER:
52+
return (
53+
f"{base_key}:member:{ctx.guild_id}:{ctx.author.id}" if ctx.guild else f"{base_key}:user:{ctx.author.id}"
54+
)
55+
case BucketType.GUILD:
56+
return f"{base_key}:guild:{ctx.guild_id}" if ctx.guild else base_key
57+
case BucketType.CHANNEL:
58+
return f"{base_key}:channel:{ctx.channel.id}"
59+
case BucketType.CATEGORY:
60+
category_id = ctx.channel.category_id if hasattr(ctx.channel, "category_id") else None
61+
return f"{base_key}:category:{category_id}" if category_id else f"{base_key}:channel:{ctx.channel.id}"
62+
case BucketType.ROLE:
63+
if ctx.guild and hasattr(ctx.author, "roles"):
64+
top_role_id = max((role.id for role in ctx.author.roles), default=0)
65+
return f"{base_key}:role:{top_role_id}"
66+
return f"{base_key}:user:{ctx.author.id}"
67+
case _: # BucketType.DEFAULT
68+
return base_key
3569

3670

37-
def cooldown[C: commands.Cog, **P](
71+
def cooldown[C: commands.Cog, **P]( # noqa: PLR0913
3872
key: ReactiveCooldownSetting[str],
3973
*,
4074
limit: ReactiveCooldownSetting[int],
4175
per: ReactiveCooldownSetting[int],
76+
bucket_type: ReactiveCooldownSetting[BucketType] = BucketType.DEFAULT,
4277
strong: ReactiveCooldownSetting[bool] = False,
4378
cls: ReactiveCooldownSetting[type[CooldownExceeded]] = CooldownExceeded,
4479
) -> Callable[[CogCommandFunction[C, P]], CogCommandFunction[C, P]]:
80+
"""Enhanced cooldown decorator that supports different bucket types.
81+
82+
Args:
83+
key: Base key for the cooldown
84+
limit: Number of uses allowed
85+
per: Time period in seconds
86+
bucket_type: Type of bucket to use for the cooldown
87+
strong: If True, adds current timestamp even if limit is reached
88+
cls: Custom exception class to raise
89+
90+
"""
91+
4592
def inner(func: CogCommandFunction[C, P]) -> CogCommandFunction[C, P]:
4693
@wraps(func)
4794
async def wrapper(self: C, ctx: custom.ApplicationContext, *args: P.args, **kwargs: P.kwargs) -> None:
@@ -51,17 +98,24 @@ async def wrapper(self: C, ctx: custom.ApplicationContext, *args: P.args, **kwar
5198
per_value: int = await parse_reactive_setting(per, ctx.bot, ctx)
5299
strong_value: bool = await parse_reactive_setting(strong, ctx.bot, ctx)
53100
cls_value: type[CooldownExceeded] = await parse_reactive_setting(cls, ctx.bot, ctx)
101+
bucket_type_value: BucketType = await parse_reactive_setting(bucket_type, ctx.bot, ctx)
102+
103+
# Generate the full cooldown key based on bucket type
104+
full_key = get_bucket_key(ctx, key_value, bucket_type_value)
105+
54106
now = time.time()
55-
time_stamps = cast(tuple[float, ...], await cache.get(key_value, default=(), namespace="cooldown"))
107+
time_stamps = cast(tuple[float, ...], await cache.get(full_key, default=(), namespace="cooldown"))
56108
time_stamps = tuple(filter(lambda x: x > now - per_value, time_stamps))
57109
time_stamps = time_stamps[-limit_value:]
110+
58111
if len(time_stamps) < limit_value or strong_value:
59112
time_stamps = (*time_stamps, now)
60-
await cache.set(key_value, time_stamps, namespace="cooldown", ttl=per_value)
113+
await cache.set(full_key, time_stamps, namespace="cooldown", ttl=per_value)
61114
limit_value += 1 # to account for the current command
62115

63116
if len(time_stamps) >= limit_value:
64-
raise cls_value(min(time_stamps) - now + per_value)
117+
raise cls_value(min(time_stamps) - now + per_value, bucket_type_value)
118+
65119
await func(self, ctx, *args, **kwargs)
66120

67121
return wrapper

0 commit comments

Comments
 (0)