33
44import time
55from collections .abc import Awaitable , Callable , Coroutine
6+ from enum import Enum
67from functools import wraps
78from inspect import isawaitable
89from typing import Any , Concatenate , cast
1516type CogCommandFunction [T : commands .Cog , ** P ] = Callable [Concatenate [T , custom .ApplicationContext , P ], Awaitable [None ]]
1617
1718
19+ class BucketType (Enum ):
20+ DEFAULT = "default" # Uses provided key as is
21+ USER = "user" # Per-user cooldown
22+ MEMBER = "member" # Per-member (user+guild) cooldown
23+ GUILD = "guild" # Per-guild cooldown
24+ CHANNEL = "channel" # Per-channel cooldown
25+ CATEGORY = "category" # Per-category cooldown
26+ ROLE = "role" # Per-role cooldown (uses highest role)
27+
28+
1829async def parse_reactive_setting [T ](value : ReactiveCooldownSetting [T ], bot : custom .Bot , ctx : custom .Context ) -> T :
1930 if isinstance (value , type ):
2031 return value # pyright: ignore [reportReturnType]
@@ -26,22 +37,58 @@ async def parse_reactive_setting[T](value: ReactiveCooldownSetting[T], bot: cust
2637
2738
2839class CooldownExceeded (commands .CheckFailure ):
29- def __init__ (self , retry_after : float ) -> None :
40+ def __init__ (self , retry_after : float , bucket_type : BucketType ) -> None :
3041 self .retry_after : float = retry_after
31- super ().__init__ ("You are on cooldown" )
42+ self .bucket_type : BucketType = bucket_type
43+ super ().__init__ (f"You are on { bucket_type .value } cooldown" )
3244
3345
34- # inspired by https://github.com/ItsDrike/code-jam-2024/blob/main/src/utils/ratelimit.py
46+ def get_bucket_key (ctx : custom .ApplicationContext , base_key : str , bucket_type : BucketType ) -> str : # noqa: PLR0911
47+ """Generate a cooldown key based on the bucket type."""
48+ match bucket_type :
49+ case BucketType .USER :
50+ return f"{ base_key } :user:{ ctx .author .id } "
51+ case BucketType .MEMBER :
52+ return (
53+ f"{ base_key } :member:{ ctx .guild_id } :{ ctx .author .id } " if ctx .guild else f"{ base_key } :user:{ ctx .author .id } "
54+ )
55+ case BucketType .GUILD :
56+ return f"{ base_key } :guild:{ ctx .guild_id } " if ctx .guild else base_key
57+ case BucketType .CHANNEL :
58+ return f"{ base_key } :channel:{ ctx .channel .id } "
59+ case BucketType .CATEGORY :
60+ category_id = ctx .channel .category_id if hasattr (ctx .channel , "category_id" ) else None
61+ return f"{ base_key } :category:{ category_id } " if category_id else f"{ base_key } :channel:{ ctx .channel .id } "
62+ case BucketType .ROLE :
63+ if ctx .guild and hasattr (ctx .author , "roles" ):
64+ top_role_id = max ((role .id for role in ctx .author .roles ), default = 0 )
65+ return f"{ base_key } :role:{ top_role_id } "
66+ return f"{ base_key } :user:{ ctx .author .id } "
67+ case _: # BucketType.DEFAULT
68+ return base_key
3569
3670
37- def cooldown [C : commands .Cog , ** P ](
71+ def cooldown [C : commands .Cog , ** P ]( # noqa: PLR0913
3872 key : ReactiveCooldownSetting [str ],
3973 * ,
4074 limit : ReactiveCooldownSetting [int ],
4175 per : ReactiveCooldownSetting [int ],
76+ bucket_type : ReactiveCooldownSetting [BucketType ] = BucketType .DEFAULT ,
4277 strong : ReactiveCooldownSetting [bool ] = False ,
4378 cls : ReactiveCooldownSetting [type [CooldownExceeded ]] = CooldownExceeded ,
4479) -> Callable [[CogCommandFunction [C , P ]], CogCommandFunction [C , P ]]:
80+ """Enhanced cooldown decorator that supports different bucket types.
81+
82+ Args:
83+ key: Base key for the cooldown
84+ limit: Number of uses allowed
85+ per: Time period in seconds
86+ bucket_type: Type of bucket to use for the cooldown
87+ strong: If True, adds current timestamp even if limit is reached
88+ cls: Custom exception class to raise
89+
90+ """
91+
4592 def inner (func : CogCommandFunction [C , P ]) -> CogCommandFunction [C , P ]:
4693 @wraps (func )
4794 async def wrapper (self : C , ctx : custom .ApplicationContext , * args : P .args , ** kwargs : P .kwargs ) -> None :
@@ -51,17 +98,24 @@ async def wrapper(self: C, ctx: custom.ApplicationContext, *args: P.args, **kwar
5198 per_value : int = await parse_reactive_setting (per , ctx .bot , ctx )
5299 strong_value : bool = await parse_reactive_setting (strong , ctx .bot , ctx )
53100 cls_value : type [CooldownExceeded ] = await parse_reactive_setting (cls , ctx .bot , ctx )
101+ bucket_type_value : BucketType = await parse_reactive_setting (bucket_type , ctx .bot , ctx )
102+
103+ # Generate the full cooldown key based on bucket type
104+ full_key = get_bucket_key (ctx , key_value , bucket_type_value )
105+
54106 now = time .time ()
55- time_stamps = cast (tuple [float , ...], await cache .get (key_value , default = (), namespace = "cooldown" ))
107+ time_stamps = cast (tuple [float , ...], await cache .get (full_key , default = (), namespace = "cooldown" ))
56108 time_stamps = tuple (filter (lambda x : x > now - per_value , time_stamps ))
57109 time_stamps = time_stamps [- limit_value :]
110+
58111 if len (time_stamps ) < limit_value or strong_value :
59112 time_stamps = (* time_stamps , now )
60- await cache .set (key_value , time_stamps , namespace = "cooldown" , ttl = per_value )
113+ await cache .set (full_key , time_stamps , namespace = "cooldown" , ttl = per_value )
61114 limit_value += 1 # to account for the current command
62115
63116 if len (time_stamps ) >= limit_value :
64- raise cls_value (min (time_stamps ) - now + per_value )
117+ raise cls_value (min (time_stamps ) - now + per_value , bucket_type_value )
118+
65119 await func (self , ctx , * args , ** kwargs )
66120
67121 return wrapper
0 commit comments