@@ -816,59 +816,83 @@ async def command_error(ctx, error):
816816 message = f"Your command needs at least { 'three parameters to return self, context, and the' if self .extension else 'two parameter to return context and' } error." ,
817817 )
818818
819- self .error_callback = self .__wrap_coro (coro )
819+ self .error_callback = self .__wrap_coro (coro , error_callback = True )
820820 return coro
821821
822822 async def __call (
823823 self ,
824824 coro : Callable [..., Awaitable ],
825825 ctx : "CommandContext" ,
826- * args ,
826+ * args , # empty for now since all parameters are dispatched as kwargs
827827 _name : Optional [str ] = None ,
828828 _res : Optional [Union [BaseResult , GroupResult ]] = None ,
829829 ** kwargs ,
830830 ) -> Optional [Any ]:
831831 """Handles calling the coroutine based on parameter count."""
832- param_len = len (signature (coro ).parameters )
833- opt_len = self .num_options .get (_name , len (args ) + len (kwargs ))
832+ params = signature (coro ).parameters
833+ param_len = len (params )
834+ opt_len = self .num_options .get (_name , len (args ) + len (kwargs )) # options of slash command
835+ last = params [list (params )[- 1 ]] # last parameter
836+ has_args = any (param .kind == param .VAR_POSITIONAL for param in params .values ()) # any *args
837+ index_of_var_pos = next (
838+ (i for i , param in enumerate (params .values ()) if param .kind == param .VAR_POSITIONAL ),
839+ param_len ,
840+ ) # index of *args
841+ par_opts = list (params .keys ())[
842+ (num := 2 if self .extension else 1 ) : (
843+ - 1 if last .kind in (last .VAR_POSITIONAL , last .VAR_KEYWORD ) else index_of_var_pos
844+ )
845+ ] # parameters that are before *args and **kwargs
846+ keyword_only_args = list (params .keys ())[index_of_var_pos :] # parameters after *args
834847
835848 try :
836849 _coro = coro if hasattr (coro , "_wrapped" ) else self .__wrap_coro (coro )
837850
838- if param_len < (2 if self .extension else 1 ):
851+ if last .kind == last .VAR_KEYWORD : # foo(ctx, ..., **kwargs)
852+ return await _coro (ctx , * args , ** kwargs )
853+ if last .kind == last .VAR_POSITIONAL : # foo(ctx, ..., *args)
854+ return await _coro (
855+ ctx ,
856+ * (kwargs [opt ] for opt in par_opts if opt in kwargs ),
857+ * args ,
858+ )
859+ if has_args : # foo(ctx, ..., *args, ..., **kwargs) OR foo(ctx, *args, ...)
860+ return await _coro (
861+ ctx ,
862+ * (kwargs [opt ] for opt in par_opts if opt in kwargs ), # pos before *args
863+ * args ,
864+ * (
865+ kwargs [opt ]
866+ for opt in kwargs
867+ if opt not in par_opts and opt not in keyword_only_args
868+ ), # additional args
869+ ** {
870+ opt : kwargs [opt ]
871+ for opt in kwargs
872+ if opt not in par_opts and opt in keyword_only_args
873+ }, # kwargs after *args
874+ )
875+
876+ if param_len < num :
877+ inner_msg : str = f"{ num } parameter{ 's' if num > 1 else '' } to return" + (
878+ " self and" if self .extension else ""
879+ )
839880 raise LibraryException (
840- code = 11 ,
841- message = f"Your command needs at least { 'two parameters to return self and' if self .extension else 'one parameter to return' } context." ,
881+ code = 11 , message = f"Your command needs at least { inner_msg } context."
842882 )
843883
844- if param_len == ( 2 if self . extension else 1 ) :
884+ if param_len == num :
845885 return await _coro (ctx )
846886
847887 if _res :
848- if param_len - opt_len == ( 2 if self . extension else 1 ) :
888+ if param_len - opt_len == num :
849889 return await _coro (ctx , * args , ** kwargs )
850- elif param_len - opt_len == ( 3 if self . extension else 2 ) :
890+ elif param_len - opt_len == num + 1 :
851891 return await _coro (ctx , _res , * args , ** kwargs )
852892
853893 return await _coro (ctx , * args , ** kwargs )
854894 except CancelledError :
855895 pass
856- except Exception as e :
857- if self .error_callback :
858- num_params = len (signature (self .error_callback ).parameters )
859-
860- if num_params == (3 if self .extension else 2 ):
861- await self .error_callback (ctx , e )
862- elif num_params == (4 if self .extension else 3 ):
863- await self .error_callback (ctx , e , _res )
864- else :
865- await self .error_callback (ctx , e , _res , * args , ** kwargs )
866- elif self .listener and "on_command_error" in self .listener .events :
867- self .listener .dispatch ("on_command_error" , ctx , e )
868- else :
869- raise e
870-
871- return StopCommand
872896
873897 def __check_command (self , command_type : str ) -> None :
874898 """Checks if subcommands, groups, or autocompletions are created on context menus."""
@@ -895,7 +919,9 @@ async def __no_group(self, *args, **kwargs) -> None:
895919 """This is the coroutine used when no group coroutine is provided."""
896920 pass
897921
898- def __wrap_coro (self , coro : Callable [..., Awaitable ]) -> Callable [..., Awaitable ]:
922+ def __wrap_coro (
923+ self , coro : Callable [..., Awaitable ], / , * , error_callback : bool = False
924+ ) -> Callable [..., Awaitable ]:
899925 """Wraps a coroutine to make sure the :class:`interactions.client.bot.Extension` is passed to the coroutine, if any."""
900926
901927 @wraps (coro )
@@ -907,11 +933,28 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs):
907933 except CancelledError :
908934 pass
909935 except Exception as e :
936+ if error_callback :
937+ raise e
910938 if self .error_callback :
911- num_params = len (signature (self .error_callback ).parameters )
912-
913- if num_params == (3 if self .extension else 2 ):
939+ params = signature (self .error_callback ).parameters
940+ num_params = len (params )
941+ last = params [list (params )[- 1 ]]
942+ num = 2 if self .extension else 1
943+
944+ if num_params == num :
945+ await self .error_callback (ctx )
946+ elif num_params == num + 1 :
914947 await self .error_callback (ctx , e )
948+ elif last .kind == last .VAR_KEYWORD :
949+ if num_params == num + 2 :
950+ await self .error_callback (ctx , e , ** kwargs )
951+ elif num_params >= num + 3 :
952+ await self .error_callback (ctx , e , * args , ** kwargs )
953+ elif last .kind == last .VAR_POSITIONAL :
954+ if num_params == num + 2 :
955+ await self .error_callback (ctx , e , * args )
956+ elif num_params >= num + 3 :
957+ await self .error_callback (ctx , e , * args , ** kwargs )
915958 else :
916959 await self .error_callback (ctx , e , * args , ** kwargs )
917960 elif self .listener and "on_command_error" in self .listener .events :
0 commit comments