Skip to content

Commit 876b84c

Browse files
authored
feat: add global autocomplete (#1273)
* feat: add global autocomplete * refactor: linter pass * fix: remove cheeky breakpoint
1 parent 4faf842 commit 876b84c

File tree

7 files changed

+120
-24
lines changed

7 files changed

+120
-24
lines changed

interactions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@
133133
FlatUIColors,
134134
FlatUIColours,
135135
get_components_ids,
136+
global_autocomplete,
137+
GlobalAutoComplete,
136138
Greedy,
137139
Guild,
138140
guild_only,
@@ -444,7 +446,9 @@
444446
"FlatUIColours",
445447
"get_components_ids",
446448
"get_logger",
449+
"global_autocomplete",
447450
"GLOBAL_SCOPE",
451+
"GlobalAutoComplete",
448452
"GlobalScope",
449453
"Greedy",
450454
"Guild",

interactions/client/client.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import interactions.api.events as events
2828
import interactions.client.const as constants
29+
from interactions.models.internal.callback import CallbackObject
2930
from interactions.api.events import BaseEvent, RawGatewayEvent, processors
3031
from interactions.api.events.internal import CallbackAdded
3132
from interactions.api.gateway.gateway import GatewayClient
@@ -90,7 +91,7 @@
9091
from interactions.models.discord.file import UPLOADABLE_TYPE
9192
from interactions.models.discord.snowflake import Snowflake, to_snowflake_list
9293
from interactions.models.internal.active_voice_state import ActiveVoiceState
93-
from interactions.models.internal.application_commands import ContextMenu, ModalCommand
94+
from interactions.models.internal.application_commands import ContextMenu, ModalCommand, GlobalAutoComplete
9495
from interactions.models.internal.auto_defer import AutoDefer
9596
from interactions.models.internal.command import BaseCommand
9697
from interactions.models.internal.context import (
@@ -378,6 +379,7 @@ def __init__(
378379
"""A dictionary of registered application commands in a tree"""
379380
self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {}
380381
self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {}
382+
self._global_autocompletes: Dict[str, GlobalAutoComplete] = {}
381383
self.processors: Dict[str, Callable[..., Coroutine]] = {}
382384
self.__modules = {}
383385
self.ext: Dict[str, Extension] = {}
@@ -1256,6 +1258,15 @@ def add_modal_callback(self, command: ModalCommand) -> None:
12561258
self._modal_callbacks[listener] = command
12571259
continue
12581260

1261+
def add_global_autocomplete(self, callback: GlobalAutoComplete) -> None:
1262+
"""
1263+
Add a global autocomplete to the client.
1264+
1265+
Args:
1266+
callback: The autocomplete to add
1267+
"""
1268+
self._global_autocompletes[callback.option_name] = callback
1269+
12591270
def add_command(self, func: Callable) -> None:
12601271
"""
12611272
Add a command to the client.
@@ -1271,6 +1282,8 @@ def add_command(self, func: Callable) -> None:
12711282
self.add_interaction(func)
12721283
elif isinstance(func, Listener):
12731284
self.add_listener(func)
1285+
elif isinstance(func, GlobalAutoComplete):
1286+
self.add_global_autocomplete(func)
12741287
elif not isinstance(func, BaseCommand):
12751288
raise TypeError("Invalid command type")
12761289

@@ -1302,12 +1315,10 @@ def process(callables, location: str) -> None:
13021315
self.logger.debug(f"{added} callbacks have been loaded from {location}.")
13031316

13041317
main_commands = [
1305-
obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, (BaseCommand, Listener))
1318+
obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, CallbackObject)
13061319
]
13071320
client_commands = [
1308-
obj.copy_with_binding(self)
1309-
for _, obj in inspect.getmembers(self)
1310-
if isinstance(obj, (BaseCommand, Listener))
1321+
obj.copy_with_binding(self) for _, obj in inspect.getmembers(self) if isinstance(obj, CallbackObject)
13111322
]
13121323
process(main_commands, "__main__")
13131324
process(client_commands, self.__class__.__name__)
@@ -1597,7 +1608,6 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None:
15971608
elif autocomplete := self._global_autocompletes.get(str(auto_opt.name)):
15981609
callback = autocomplete
15991610
else:
1600-
breakpoint()
16011611
raise ValueError(f"Autocomplete callback for {str(auto_opt.name)} not found")
16021612

16031613
await self.__dispatch_interaction(

interactions/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@
219219
DMConverter,
220220
DMGroupConverter,
221221
Extension,
222+
global_autocomplete,
223+
GlobalAutoComplete,
222224
Greedy,
223225
guild_only,
224226
GuildCategoryConverter,
@@ -382,6 +384,8 @@
382384
"FlatUIColors",
383385
"FlatUIColours",
384386
"get_components_ids",
387+
"global_autocomplete",
388+
"GlobalAutoComplete",
385389
"Greedy",
386390
"Guild",
387391
"guild_only",

interactions/models/internal/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
ComponentCommand,
2323
context_menu,
2424
ContextMenu,
25+
global_autocomplete,
26+
GlobalAutoComplete,
2527
InteractionCommand,
2628
LocalisedDesc,
2729
LocalisedName,
@@ -126,6 +128,8 @@
126128
"DMConverter",
127129
"DMGroupConverter",
128130
"Extension",
131+
"global_autocomplete",
132+
"GlobalAutoComplete",
129133
"Greedy",
130134
"guild_only",
131135
"GuildCategoryConverter",

interactions/models/internal/application_commands.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from interactions.models.discord.snowflake import to_snowflake_list, to_snowflake
4040
from interactions.models.discord.user import BaseUser
4141
from interactions.models.internal.auto_defer import AutoDefer
42+
from interactions.models.internal.callback import CallbackObject
4243
from interactions.models.internal.command import BaseCommand
4344
from interactions.models.internal.localisation import LocalisedField
4445

@@ -48,28 +49,30 @@
4849
from interactions import Client
4950

5051
__all__ = (
51-
"OptionType",
52+
"application_commands_to_dict",
53+
"auto_defer",
5254
"CallbackType",
53-
"InteractionCommand",
54-
"ContextMenu",
55-
"SlashCommandChoice",
56-
"SlashCommandOption",
57-
"SlashCommand",
55+
"component_callback",
5856
"ComponentCommand",
57+
"context_menu",
58+
"ContextMenu",
59+
"global_autocomplete",
60+
"GlobalAutoComplete",
61+
"InteractionCommand",
62+
"LocalisedDesc",
63+
"LocalisedName",
64+
"LocalizedDesc",
65+
"LocalizedName",
5966
"ModalCommand",
67+
"OptionType",
6068
"slash_command",
61-
"subcommand",
62-
"context_menu",
63-
"component_callback",
64-
"slash_option",
6569
"slash_default_member_permission",
66-
"auto_defer",
67-
"application_commands_to_dict",
70+
"slash_option",
71+
"SlashCommand",
72+
"SlashCommandChoice",
73+
"SlashCommandOption",
74+
"subcommand",
6875
"sync_needed",
69-
"LocalisedName",
70-
"LocalizedName",
71-
"LocalizedDesc",
72-
"LocalisedDesc",
7376
)
7477

7578

@@ -674,11 +677,36 @@ def _unpack_helper(iterable: typing.Iterable[str]) -> list[str]:
674677
return unpack
675678

676679

680+
class GlobalAutoComplete(CallbackObject):
681+
def __init__(self, option_name: str, callback: Callable) -> None:
682+
self.callback = callback
683+
self.option_name = option_name
684+
685+
677686
##############
678687
# Decorators #
679688
##############
680689

681690

691+
def global_autocomplete(option_name: str) -> Callable[[AsyncCallable], GlobalAutoComplete]:
692+
"""
693+
Decorator for global autocomplete functions
694+
695+
Args:
696+
option_name: The name of the option to register the autocomplete function for
697+
698+
Returns:
699+
The decorator
700+
"""
701+
702+
def decorator(func: Callable) -> GlobalAutoComplete:
703+
if not asyncio.iscoroutinefunction(func):
704+
raise TypeError("Autocomplete functions must be coroutines")
705+
return GlobalAutoComplete(option_name, func)
706+
707+
return decorator
708+
709+
682710
def slash_command(
683711
name: str | LocalisedName,
684712
*,

interactions/models/internal/extension.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import interactions.models.internal as models
77
import interactions.api.events as events
8+
from interactions.models.internal.callback import CallbackObject
89
from interactions.client.const import MISSING
910
from interactions.client.utils.misc_utils import wrap_partial
1011
from interactions.models.internal.tasks import Task
@@ -94,7 +95,7 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension":
9495
instance._listeners = []
9596

9697
callables: list[tuple[str, typing.Callable]] = inspect.getmembers(
97-
instance, predicate=lambda x: isinstance(x, (models.BaseCommand, models.Listener, Task))
98+
instance, predicate=lambda x: isinstance(x, (CallbackObject, Task))
9899
)
99100

100101
for _name, val in callables:
@@ -112,6 +113,10 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension":
112113
val = wrap_partial(val, instance)
113114
bot.add_listener(val) # type: ignore
114115
instance._listeners.append(val)
116+
elif isinstance(val, models.GlobalAutoComplete):
117+
val.extension = instance
118+
val = wrap_partial(val, instance)
119+
bot.add_global_autocomplete(val)
115120
bot.dispatch(events.ExtensionCommandParse(extension=instance, callables=callables))
116121

117122
instance.extension_name = inspect.getmodule(instance).__name__

main.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import os
33
import uuid
44

5+
from thefuzz import process
6+
57
import interactions
6-
from interactions import Client, listen, slash_command, BrandColours
8+
from interactions import Client, listen, slash_command, BrandColours, FlatUIColours, MaterialColours
9+
from interactions.models.internal.application_commands import global_autocomplete, slash_option
710

811
logging.basicConfig()
912
logging.getLogger("interactions").setLevel(logging.DEBUG)
@@ -104,4 +107,42 @@ async def multi_image_embed_test(ctx: interactions.SlashContext):
104107
await ctx.send(embeds=embed)
105108

106109

110+
def get_colour(colour: str):
111+
if colour in interactions.MaterialColors.__members__:
112+
return interactions.MaterialColors[colour]
113+
elif colour in interactions.BrandColors.__members__:
114+
return interactions.BrandColors[colour]
115+
elif colour in interactions.FlatUIColours.__members__:
116+
return interactions.FlatUIColours[colour]
117+
else:
118+
return interactions.BrandColors.BLURPLE
119+
120+
121+
@slash_command("test")
122+
@slash_option("colour", "The colour to use", autocomplete=True, opt_type=interactions.OptionType.STRING, required=True)
123+
@slash_option("text", "some text", autocomplete=True, opt_type=interactions.OptionType.STRING, required=True)
124+
async def test(ctx: interactions.SlashContext, colour: str, text: str):
125+
embed = interactions.Embed(f"{text} {colour.title()}", color=get_colour(colour))
126+
await ctx.send(embeds=embed)
127+
128+
129+
@global_autocomplete("colour")
130+
async def colour_autocomplete(ctx: interactions.AutocompleteContext):
131+
colours = list((BrandColours.__members__ | FlatUIColours.__members__ | MaterialColours.__members__).keys())
132+
133+
if not ctx.input_text:
134+
colours = colours[:25]
135+
else:
136+
results = process.extract(ctx.input_text, colours, limit=25)
137+
colour_match = sorted([result for result in results if result[1] > 50], key=lambda x: x[1], reverse=True)
138+
colours = [colour[0] for colour in colour_match]
139+
140+
await ctx.send([{"name": colour.title(), "value": colour} for colour in colours])
141+
142+
143+
@test.autocomplete("text")
144+
async def text_autocomplete(ctx: interactions.AutocompleteContext):
145+
await ctx.send([{"name": c, "value": c} for c in ["colour", "color", "shade", "hue"]])
146+
147+
107148
bot.start(os.environ["TOKEN"])

0 commit comments

Comments
 (0)