Skip to content

Commit c00b857

Browse files
committed
♻️ Fix optional fields in BotConfig and Config models; improve module validation
1 parent 47d852e commit c00b857

File tree

4 files changed

+17
-26
lines changed

4 files changed

+17
-26
lines changed

src/config/bot_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,11 @@ def load_json_recursive(data: dict[str, Any]) -> dict[str, Any]:
6262
path = "config.yml"
6363

6464
_config: Any
65+
config: Config
6566
if path:
6667
with open(path, encoding="utf-8") as f:
6768
_config = yaml.safe_load(f)
6869
else:
6970
_config = load_from_env()
7071

71-
config: Config = Config(**_config)
72+
config = Config(**_config) if _config else Config()

src/config/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class BotConfig(BaseModel):
4545
public_key: str | None = None
4646
prefix: PrefixConfig | str = PrefixConfig(prefix="!", enabled=False)
4747
slash: SlashConfig = SlashConfig(enabled=False)
48-
cache: CacheConfig
48+
cache: CacheConfig = CacheConfig()
4949
rest: bool = False
5050

5151

@@ -65,7 +65,7 @@ class DbConfig(BaseModel):
6565

6666
class Config(BaseModel):
6767
db: DbConfig = DbConfig(url="", enabled=False)
68-
bot: BotConfig
68+
bot: BotConfig = BotConfig(token="")
6969
logging: LoggingConfig = LoggingConfig()
7070
use: UseConfig = UseConfig()
7171
extensions: dict[str, Extension] = {}

src/utils/extensions.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import discord
1414
from quart import Quart
15-
from schema import Schema, SchemaError
1615

1716
from src.log import logger
1817

@@ -40,8 +39,7 @@ def check_func(module: ModuleType, func: Callable, max_args: int, types: dict[st
4039
# check_typing(module, func, types) # temporarily disabled due to unwanted behavior # noqa: ERA001
4140

4241

43-
# noinspection DuplicatedCode
44-
def validate_module(module: ModuleType, config: dict[str, Any] | None = None) -> None:
42+
def validate_module(module: ModuleType, config: dict[str, Any] | None = None) -> None: # pyright: ignore [reportUnusedParameter] # noqa: ARG001
4543
"""Validate the module to ensure it has the required functions and attributes to be loaded as an extension.
4644
4745
:param module: The module to validate
@@ -77,26 +75,6 @@ def validate_module(module: ModuleType, config: dict[str, Any] | None = None) ->
7775
assert "enabled" in module.default, (
7876
f"Extension {module.__name__} does not have an enabled key in its default configuration"
7977
)
80-
if hasattr(module, "schema"):
81-
assert isinstance(
82-
module.schema,
83-
Schema | dict,
84-
), f"Extension {module.__name__} has a schema of type {type(module.schema)} instead of Schema or dict"
85-
86-
if isinstance(module.schema, dict):
87-
module.schema = Schema(module.schema)
88-
if config:
89-
module.schema.validate(config)
90-
else:
91-
try:
92-
module.schema.validate(module.default)
93-
except SchemaError as e:
94-
warnings.warn(
95-
f"Default configuration for extension {module.__name__} does not match schema: {e}",
96-
stacklevel=1,
97-
)
98-
else:
99-
warnings.warn(f"Extension {module.__name__} does not have a schema", stacklevel=1)
10078

10179

10280
def unzip_extensions() -> None:

src/utils/misc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) NiceBots.xyz
22
# SPDX-License-Identifier: MIT
3+
from collections.abc import Callable
34

45
import discord
56

@@ -10,3 +11,14 @@ def mention_command(*command: str, bot: discord.Bot) -> str:
1011
if isinstance(command, discord.SlashCommand):
1112
return command.mention
1213
raise ValueError("Command not found")
14+
15+
16+
class LazyProxy[T]:
17+
def __init__(self, func: Callable[..., T]) -> None:
18+
self._func: Callable[..., T] = func
19+
self._value: T | None = None
20+
21+
def __getattr__(self, name: str) -> T:
22+
if self._value is None:
23+
self._value = self._func()
24+
return getattr(self._value, name)

0 commit comments

Comments
 (0)