Skip to content

Commit 084b45f

Browse files
refactor: rewrite Botstrap to be a class instead of one big-old if statement
1 parent 82fa0b1 commit 084b45f

File tree

1 file changed

+179
-103
lines changed

1 file changed

+179
-103
lines changed

botstrap.py

Lines changed: 179 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
import sys
66
from pathlib import Path
7+
from types import TracebackType
78
from typing import Any, Final, cast
89

910
from dotenv import load_dotenv
@@ -26,7 +27,7 @@
2627
# Silence noisy httpcore logger
2728
get_logger("httpcore").setLevel("INFO")
2829

29-
env_file_path = Path(".env.server")
30+
ENV_FILE = Path(".env.server")
3031
BOT_TOKEN = os.getenv("BOT_TOKEN", None)
3132
GUILD_ID = os.getenv("GUILD_ID", None)
3233

@@ -76,6 +77,10 @@ def __getitem__(self, item: str):
7677
sys.exit(-1)
7778

7879

80+
class BotstrapError(Exception):
81+
"""Raised when an error occurs during the botstrap process."""
82+
83+
7984
class DiscordClient(Client):
8085
"""An HTTP client to communicate with Discord's APIs."""
8186

@@ -168,7 +173,7 @@ def upgrade_server_to_community_if_necessary(
168173
payload["rules_channel_id"] = rules_channel_id_
169174
payload["public_updates_channel_id"] = announcements_channel_id_
170175
self._guild_info = self.patch(f"/guilds/{self.guild_id}", json=payload).json()
171-
log.info(f"Server {self.guild_id} has been successfully updated to a community.")
176+
log.info("Server %s has been successfully updated to a community.", self.guild_id)
172177

173178
def create_forum_channel(self, channel_name_: str, category_id_: int | str | None = None) -> str:
174179
"""Creates a new forum channel."""
@@ -182,7 +187,7 @@ def create_forum_channel(self, channel_name_: str, category_id_: int | str | Non
182187
headers={"X-Audit-Log-Reason": "Creating forum channel as part of PyDis botstrap"},
183188
)
184189
forum_channel_id = response.json()["id"]
185-
log.info(f"New forum channel: {channel_name_} has been successfully created.")
190+
log.info("New forum channel: %s has been successfully created.", channel_name_)
186191
return forum_channel_id
187192

188193
def is_forum_channel(self, channel_id: str) -> bool:
@@ -252,7 +257,7 @@ def list_emojis(self) -> list[dict[str, Any]]:
252257
def get_emoji_contents(self, id_: str | int) -> bytes | None:
253258
"""Fetches the image data for an emoji by ID."""
254259
# emojis are located at https://cdn.discordapp.com/emojis/{emoji_id}.{ext}
255-
response = self.get(f"{self.CDN_BASE_URL}/emojis/{emoji_id!s}.webp")
260+
response = self.get(f"{self.CDN_BASE_URL}/emojis/{id_!s}.webp")
256261
return response.content
257262

258263
def clone_emoji(self, *, new_name: str, original_emoji_id: str | int) -> str:
@@ -276,118 +281,189 @@ def clone_emoji(self, *, new_name: str, original_emoji_id: str | int) -> str:
276281
return new_emoji["id"]
277282

278283

279-
with DiscordClient(guild_id=GUILD_ID) as discord_client:
280-
if discord_client.upgrade_application_flags_if_necessary():
281-
log.info("Application flags upgraded successfully, and necessary intents are now enabled.")
284+
class BotStrapper:
285+
"""Bootstrap the bot configuration for a given guild."""
282286

283-
if not discord_client.check_if_in_guild():
284-
client_id = discord_client.app_info["id"]
285-
log.error("The bot is not a member of the configured guild with ID %s.", GUILD_ID)
286-
log.warning(
287-
"Please invite with the following URL and rerun this script: "
288-
"https://discord.com/oauth2/authorize?client_id=%s&guild_id=%s&scope=bot+applications.commands&permissions=8",
289-
client_id,
290-
GUILD_ID,
291-
)
292-
sys.exit(69)
287+
def __init__(self, guild_id: int | str, env_file: Path):
288+
self.client = DiscordClient(guild_id=guild_id)
289+
self.env_file = env_file
293290

294-
config_str = "#Roles\n"
291+
def __enter__(self):
292+
return self
295293

296-
all_roles = discord_client.get_all_roles()
294+
def __exit__(
295+
self,
296+
exc_type: type[BaseException] | None = None,
297+
exc_value: BaseException | None = None,
298+
traceback: TracebackType | None = None,
299+
) -> None:
300+
self.client.__exit__(exc_type, exc_value, traceback)
297301

298-
for role_name in _Roles.model_fields:
299-
role_id = all_roles.get(role_name, None)
300-
if not role_id:
301-
log.warning("Couldn't find the role %s in the guild, PyDis' default values will be used.", role_name)
302-
continue
302+
def upgrade_client(self) -> bool:
303+
"""Upgrade the application's flags if necessary."""
304+
if self.client.upgrade_application_flags_if_necessary():
305+
log.info("Application flags upgraded successfully, and necessary intents are now enabled.")
306+
return True
307+
return False
303308

304-
config_str += f"roles_{role_name}={role_id}\n"
309+
def check_guild_membership(self) -> None:
310+
"""Check the bot is in the required guild."""
311+
if not self.client.check_if_in_guild():
312+
client_id = self.client.app_info["id"]
313+
log.error("The bot is not a member of the configured guild with ID %s.", GUILD_ID)
314+
log.warning(
315+
"Please invite with the following URL and rerun this script: "
316+
"https://discord.com/oauth2/authorize?client_id=%s&guild_id=%s&scope=bot+applications.commands&permissions=8",
317+
client_id,
318+
GUILD_ID,
319+
)
320+
raise BotstrapError("Bot is not a member of the configured guild.")
305321

306-
all_channels, all_categories = discord_client.get_all_channels_and_categories()
322+
def get_roles(self) -> dict[str, Any]:
323+
"""Get a config map of all of the roles in the guild."""
324+
all_roles = self.client.get_all_roles()
307325

308-
config_str += "\n#Channels\n"
326+
data: dict[str, int] = {}
309327

310-
rules_channel_id = all_channels[RULES_CHANNEL_NAME]
311-
announcements_channel_id = all_channels[ANNOUNCEMENTS_CHANNEL_NAME]
328+
for role_name in _Roles.model_fields:
329+
role_id = all_roles.get(role_name, None)
330+
if not role_id:
331+
log.warning("Couldn't find the role %s in the guild, PyDis' default values will be used.", role_name)
332+
continue
312333

313-
discord_client.upgrade_server_to_community_if_necessary(rules_channel_id, announcements_channel_id)
334+
data[role_name] = role_id
314335

315-
if python_help_channel_id := all_channels.get(PYTHON_HELP_CHANNEL_NAME):
316-
if not discord_client.is_forum_channel(python_help_channel_id):
317-
discord_client.delete_channel(python_help_channel_id)
318-
python_help_channel_id = None
336+
return data
319337

320-
if not python_help_channel_id:
321-
python_help_channel_name = PYTHON_HELP_CHANNEL_NAME.replace("_", "-")
322-
python_help_category_id = all_categories[PYTHON_HELP_CATEGORY_NAME]
323-
python_help_channel_id = discord_client.create_forum_channel(python_help_channel_name, python_help_category_id)
324-
all_channels[PYTHON_HELP_CHANNEL_NAME] = python_help_channel_id
338+
def get_channels(self) -> dict[str, Any]:
339+
"""Get a config map of all of the channels in the guild."""
340+
all_channels, all_categories = self.client.get_all_channels_and_categories()
325341

326-
for channel_name in _Channels.model_fields:
327-
channel_id = all_channels.get(channel_name, None)
328-
if not channel_id:
329-
log.warning("Couldn't find the channel %s in the guild, PyDis' default values will be used.", channel_name)
330-
continue
342+
rules_channel_id = all_channels[RULES_CHANNEL_NAME]
343+
announcements_channel_id = all_channels[ANNOUNCEMENTS_CHANNEL_NAME]
331344

332-
config_str += f"channels_{channel_name}={channel_id}\n"
333-
config_str += f"channels_{PYTHON_HELP_CHANNEL_NAME}={python_help_channel_id}\n"
345+
self.client.upgrade_server_to_community_if_necessary(rules_channel_id, announcements_channel_id)
334346

335-
config_str += "\n#Categories\n"
347+
if python_help_channel_id := all_channels.get(PYTHON_HELP_CHANNEL_NAME):
348+
if not self.client.is_forum_channel(python_help_channel_id):
349+
self.client.delete_channel(python_help_channel_id)
350+
python_help_channel_id = None
336351

337-
for category_name in _Categories.model_fields:
338-
category_id = all_categories.get(category_name, None)
339-
if not category_id:
340-
log.warning(
341-
"Couldn't find the category %s in the guild, PyDis' default values will be used.", category_name
342-
)
343-
continue
344-
345-
config_str += f"categories_{category_name}={category_id}\n"
346-
347-
env_file_path.write_text(config_str)
348-
349-
config_str += "\n#Webhooks\n"
350-
existing_webhooks = discord_client.get_all_guild_webhooks()
351-
for webhook_name, webhook_model in Webhooks:
352-
formatted_webhook_name = webhook_name.replace("_", " ").title()
353-
for existing_hook in existing_webhooks:
354-
if (
355-
# check the existing ID matches the configured one
356-
existing_hook["id"] == str(webhook_model.id)
357-
or (
358-
# check if the name and the channel ID match the configured ones
359-
existing_hook["name"] == formatted_webhook_name
360-
and existing_hook["channel_id"] == str(all_channels[webhook_name])
352+
if not python_help_channel_id:
353+
python_help_channel_name = PYTHON_HELP_CHANNEL_NAME.replace("_", "-")
354+
python_help_category_id = all_categories[PYTHON_HELP_CATEGORY_NAME]
355+
python_help_channel_id = self.client.create_forum_channel(python_help_channel_name, python_help_category_id)
356+
all_channels[PYTHON_HELP_CHANNEL_NAME] = python_help_channel_id
357+
358+
data: dict[str, str] = {}
359+
for channel_name in _Channels.model_fields:
360+
channel_id = all_channels.get(channel_name, None)
361+
if not channel_id:
362+
log.warning(
363+
"Couldn't find the channel %s in the guild, PyDis' default values will be used.", channel_name
361364
)
362-
):
363-
webhook_id = existing_hook["id"]
364-
break
365-
else:
366-
webhook_channel_id = int(all_channels[webhook_name])
367-
webhook_id = discord_client.create_webhook(formatted_webhook_name, webhook_channel_id)
368-
config_str += f"webhooks_{webhook_name}__id={webhook_id}\n"
369-
370-
config_str += "\n#Emojis\n"
371-
372-
existing_emojis = discord_client.list_emojis()
373-
log.debug("Syncing emojis with bot configuration.")
374-
for emoji_config_name, emoji_config in _Emojis.model_fields.items():
375-
if not (match := EMOJI_REGEX.match(emoji_config.default)):
376-
continue
377-
emoji_name = match.group(1)
378-
emoji_id = match.group(2)
379-
380-
for emoji in existing_emojis:
381-
if emoji["name"] == emoji_name:
382-
emoji_id = emoji["id"]
383-
break
384-
else:
385-
log.info("Creating emoji %s", emoji_name)
386-
emoji_id = discord_client.clone_emoji(new_name=emoji_name, original_emoji_id=emoji_id)
387-
388-
config_str += f"emojis_{emoji_config_name}=<:{emoji_name}:{emoji_id}>\n"
389-
390-
with env_file_path.open("wb") as file:
391-
file.write(config_str.encode("utf-8"))
392-
393-
log.info("Botstrap completed successfully. Configuration has been written to %s", env_file_path)
365+
continue
366+
367+
data[channel_name] = channel_id
368+
369+
return data
370+
371+
def get_categories(self) -> dict[str, Any]:
372+
"""Get a config map of all of the categories in guild."""
373+
_channels, all_categories = self.client.get_all_channels_and_categories()
374+
375+
data: dict[str, str] = {}
376+
for category_name in _Categories.model_fields:
377+
category_id = all_categories.get(category_name, None)
378+
if not category_id:
379+
log.warning(
380+
"Couldn't find the category %s in the guild, PyDis' default values will be used.", category_name
381+
)
382+
continue
383+
384+
data[category_name] = category_id
385+
return data
386+
387+
def sync_webhooks(self) -> dict[str, Any]:
388+
"""Get webhook config. Will create all webhooks that cannot be found."""
389+
all_channels, _categories = self.client.get_all_channels_and_categories()
390+
391+
data: dict[str, Any] = {}
392+
393+
existing_webhooks = self.client.get_all_guild_webhooks()
394+
for webhook_name, webhook_model in Webhooks:
395+
formatted_webhook_name = webhook_name.replace("_", " ").title()
396+
for existing_hook in existing_webhooks:
397+
if (
398+
# check the existing ID matches the configured one
399+
existing_hook["id"] == str(webhook_model.id)
400+
or (
401+
# check if the name and the channel ID match the configured ones
402+
existing_hook["name"] == formatted_webhook_name
403+
and existing_hook["channel_id"] == str(all_channels[webhook_name])
404+
)
405+
):
406+
webhook_id = existing_hook["id"]
407+
break
408+
else:
409+
webhook_channel_id = int(all_channels[webhook_name])
410+
webhook_id = self.client.create_webhook(formatted_webhook_name, webhook_channel_id)
411+
412+
data[webhook_name + "__id"] = webhook_id
413+
414+
return data
415+
416+
def sync_emojis(self) -> dict[str, Any]:
417+
"""Get emoji config. Will create all emojis that cannot be found."""
418+
existing_emojis = self.client.list_emojis()
419+
log.debug("Syncing emojis with bot configuration.")
420+
data: dict[str, Any] = {}
421+
for emoji_config_name, emoji_config in _Emojis.model_fields.items():
422+
if not (match := EMOJI_REGEX.match(emoji_config.default)):
423+
continue
424+
emoji_name = match.group(1)
425+
emoji_id = match.group(2)
426+
427+
for emoji in existing_emojis:
428+
if emoji["name"] == emoji_name:
429+
emoji_id = emoji["id"]
430+
break
431+
else:
432+
log.info("Creating emoji %s", emoji_name)
433+
emoji_id = self.client.clone_emoji(new_name=emoji_name, original_emoji_id=emoji_id)
434+
435+
data[emoji_config_name] = f"<:{emoji_name}:{emoji_id}>"
436+
437+
return data
438+
439+
def write_config_env(self, config: dict[str, dict[str, Any]], env_file: Path) -> None:
440+
"""Write the configuration to the specified env_file."""
441+
# in order to support commented sections, we write the following
442+
with self.env_file.open("wb") as file:
443+
# format the dictionary into .env style
444+
for category, category_values in config.items():
445+
file.write(f"# {category.capitalize()}\n".encode())
446+
for key, value in category_values.items():
447+
file.write(f"{category}_{key}={value}\n".encode())
448+
file.write(b"\n")
449+
450+
def run(self) -> None:
451+
"""Runs the botstrap process."""
452+
config: dict[str, dict[str, Any]] = {}
453+
self.upgrade_client()
454+
self.check_guild_membership()
455+
config["categories"] = self.get_categories()
456+
config["channels"] = self.get_channels()
457+
config["roles"] = self.get_roles()
458+
459+
config["webhooks"] = self.sync_webhooks()
460+
config["emojis"] = self.sync_emojis()
461+
462+
self.write_config_env(config, self.env_file)
463+
464+
465+
if __name__ == "__main__":
466+
botstrap = BotStrapper(guild_id=GUILD_ID, env_file=ENV_FILE)
467+
with botstrap:
468+
botstrap.run()
469+
log.info("Botstrap completed successfully. Configuration has been written to %s", ENV_FILE)

0 commit comments

Comments
 (0)