Skip to content

Commit 6e3fee6

Browse files
committed
✨ Refactor prefetch decorators to handle related values prefetch
1 parent 7595305 commit 6e3fee6

File tree

1 file changed

+56
-19
lines changed

1 file changed

+56
-19
lines changed

src/database/utils/preload.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,123 @@
11
# Copyright (c) NiceBots
22
# SPDX-License-Identifier: MIT
3+
from collections.abc import Callable, Sequence
4+
from functools import partial
5+
from typing import Literal, overload
36

47
from discord.ext import commands
58

69
from src import custom
710
from src.database.models import Guild, User
811

912

10-
async def _preload_user(ctx: custom.Context) -> bool:
13+
async def _preload_user(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
1114
"""Preload the user object into the context object.
1215
1316
Args:
1417
----
1518
ctx: The context object to preload the user object into.
19+
prefetch_related: List of related fields to prefetch.
1620
1721
Returns:
1822
-------
1923
bool: (True) always.
2024
2125
"""
2226
if isinstance(ctx, custom.ExtContext):
23-
ctx.user_obj = await User.get_or_none(id=ctx.author.id) if ctx.author else None
27+
ctx.user_obj = (
28+
await User.get_or_none(id=ctx.author.id).prefetch_related(*prefetch_related) if ctx.author else None
29+
)
2430
else:
25-
ctx.user_obj = await User.get_or_none(id=ctx.user.id) if ctx.user else None
31+
ctx.user_obj = await User.get_or_none(id=ctx.user.id).prefetch_related(*prefetch_related) if ctx.user else None
2632
return True
2733

2834

29-
preload_user = commands.check(_preload_user) # pyright: ignore [reportArgumentType]
30-
31-
32-
async def _preload_guild(ctx: custom.Context) -> bool:
35+
async def _preload_guild(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
3336
"""Preload the guild object into the context object.
3437
3538
Args:
3639
----
3740
ctx: The context object to preload the guild object into.
41+
prefetch_related: List of related fields to prefetch.
3842
3943
Returns:
4044
-------
4145
bool: (True) always.
4246
4347
"""
44-
ctx.guild_obj = await Guild.get_or_none(id=ctx.guild.id) if ctx.guild else None
48+
ctx.guild_obj = await Guild.get_or_none(id=ctx.guild.id).prefetch_related(*prefetch_related) if ctx.guild else None
4549
return True
4650

4751

48-
preload_guild = commands.check(_preload_guild) # pyright: ignore [reportArgumentType]
49-
50-
51-
async def _preload_or_create_user(ctx: custom.Context) -> bool:
52+
async def _preload_or_create_user(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
5253
"""Preload or create the user object into the context object. If the user object does not exist, create it.
5354
5455
Args:
5556
----
5657
ctx: The context object to preload or create the user object into.
58+
prefetch_related: List of related fields to prefetch.
5759
5860
Returns:
5961
-------
6062
bool: (True) always.
6163
6264
"""
63-
ctx.user_obj, _ = await User.get_or_create(id=ctx.author.id) if ctx.author else (None, None)
65+
user: User | None
66+
user, _ = await User.get_or_create(id=ctx.author.id) if ctx.author else (None, None)
67+
if user is not None:
68+
await user.fetch_related(*prefetch_related)
69+
ctx.user_obj = user
6470
return True
6571

6672

67-
preload_or_create_user = commands.check(_preload_or_create_user)
68-
69-
70-
async def _preload_or_create_guild(ctx: custom.Context) -> bool:
73+
async def _preload_or_create_guild(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]:
7174
"""Preload or create the guild object into the context object. If the guild object does not exist, create it.
7275
7376
Args:
7477
----
7578
ctx: The context object to preload or create the guild object into.
79+
prefetch_related: List of related fields to prefetch.
7680
7781
Returns:
7882
-------
7983
bool: (True) always.
8084
8185
"""
82-
ctx.guild_obj, _ = await Guild.get_or_create(id=ctx.guild.id) if ctx.guild else (None, None)
86+
guild: Guild | None
87+
guild, _ = await Guild.get_or_create(id=ctx.guild.id) if ctx.guild else (None, None)
88+
if guild is not None:
89+
await guild.fetch_related(*prefetch_related)
90+
ctx.guild_obj = guild
8391
return True
8492

8593

86-
preload_or_create_guild = commands.check(_preload_or_create_guild) # pyright: ignore [reportArgumentType]
94+
type PreloadFunction = Callable[[custom.Context, Sequence[str]], Literal[True]]
95+
96+
97+
@overload
98+
def preload_x[T](f: T, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None) -> T: ...
99+
100+
101+
@overload
102+
def preload_x[T](
103+
f: None = None, *, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None
104+
) -> Callable[[T], T]: ...
105+
106+
107+
def preload_x[T](
108+
f: Callable[[T], T] | None = None, *, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None
109+
):
110+
if prefetch_related is None:
111+
prefetch_related = []
112+
113+
func = partial(preloader, prefetch_related=prefetch_related)
114+
115+
check_decorator = commands.check(func)
116+
117+
return check_decorator(f) if f is not None else check_decorator
118+
119+
120+
preload_guild = partial(preload_x, preloader=_preload_guild)
121+
preload_user = partial(preload_x, preloader=_preload_user)
122+
preload_or_create_guild = partial(preload_x, preloader=_preload_or_create_guild)
123+
preload_or_create_user = partial(preload_x, preloader=_preload_or_create_user)

0 commit comments

Comments
 (0)