|
1 | 1 | # Copyright (c) NiceBots |
2 | 2 | # SPDX-License-Identifier: MIT |
| 3 | +from collections.abc import Callable, Sequence |
| 4 | +from functools import partial |
| 5 | +from typing import Literal, overload |
3 | 6 |
|
4 | 7 | from discord.ext import commands |
5 | 8 |
|
6 | 9 | from src import custom |
7 | 10 | from src.database.models import Guild, User |
8 | 11 |
|
9 | 12 |
|
10 | | -async def _preload_user(ctx: custom.Context) -> bool: |
| 13 | +async def _preload_user(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]: |
11 | 14 | """Preload the user object into the context object. |
12 | 15 |
|
13 | 16 | Args: |
14 | 17 | ---- |
15 | 18 | ctx: The context object to preload the user object into. |
| 19 | + prefetch_related: List of related fields to prefetch. |
16 | 20 |
|
17 | 21 | Returns: |
18 | 22 | ------- |
19 | 23 | bool: (True) always. |
20 | 24 |
|
21 | 25 | """ |
22 | 26 | if isinstance(ctx, custom.ExtContext): |
23 | | - ctx.user_obj = await User.get_or_none(id=ctx.author.id) if ctx.author else None |
| 27 | + ctx.user_obj = ( |
| 28 | + await User.get_or_none(id=ctx.author.id).prefetch_related(*prefetch_related) if ctx.author else None |
| 29 | + ) |
24 | 30 | else: |
25 | | - ctx.user_obj = await User.get_or_none(id=ctx.user.id) if ctx.user else None |
| 31 | + ctx.user_obj = await User.get_or_none(id=ctx.user.id).prefetch_related(*prefetch_related) if ctx.user else None |
26 | 32 | return True |
27 | 33 |
|
28 | 34 |
|
29 | | -preload_user = commands.check(_preload_user) # pyright: ignore [reportArgumentType] |
30 | | - |
31 | | - |
32 | | -async def _preload_guild(ctx: custom.Context) -> bool: |
| 35 | +async def _preload_guild(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]: |
33 | 36 | """Preload the guild object into the context object. |
34 | 37 |
|
35 | 38 | Args: |
36 | 39 | ---- |
37 | 40 | ctx: The context object to preload the guild object into. |
| 41 | + prefetch_related: List of related fields to prefetch. |
38 | 42 |
|
39 | 43 | Returns: |
40 | 44 | ------- |
41 | 45 | bool: (True) always. |
42 | 46 |
|
43 | 47 | """ |
44 | | - ctx.guild_obj = await Guild.get_or_none(id=ctx.guild.id) if ctx.guild else None |
| 48 | + ctx.guild_obj = await Guild.get_or_none(id=ctx.guild.id).prefetch_related(*prefetch_related) if ctx.guild else None |
45 | 49 | return True |
46 | 50 |
|
47 | 51 |
|
48 | | -preload_guild = commands.check(_preload_guild) # pyright: ignore [reportArgumentType] |
49 | | - |
50 | | - |
51 | | -async def _preload_or_create_user(ctx: custom.Context) -> bool: |
| 52 | +async def _preload_or_create_user(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]: |
52 | 53 | """Preload or create the user object into the context object. If the user object does not exist, create it. |
53 | 54 |
|
54 | 55 | Args: |
55 | 56 | ---- |
56 | 57 | ctx: The context object to preload or create the user object into. |
| 58 | + prefetch_related: List of related fields to prefetch. |
57 | 59 |
|
58 | 60 | Returns: |
59 | 61 | ------- |
60 | 62 | bool: (True) always. |
61 | 63 |
|
62 | 64 | """ |
63 | | - ctx.user_obj, _ = await User.get_or_create(id=ctx.author.id) if ctx.author else (None, None) |
| 65 | + user: User | None |
| 66 | + user, _ = await User.get_or_create(id=ctx.author.id) if ctx.author else (None, None) |
| 67 | + if user is not None: |
| 68 | + await user.fetch_related(*prefetch_related) |
| 69 | + ctx.user_obj = user |
64 | 70 | return True |
65 | 71 |
|
66 | 72 |
|
67 | | -preload_or_create_user = commands.check(_preload_or_create_user) |
68 | | - |
69 | | - |
70 | | -async def _preload_or_create_guild(ctx: custom.Context) -> bool: |
| 73 | +async def _preload_or_create_guild(ctx: custom.Context, prefetch_related: Sequence[str]) -> Literal[True]: |
71 | 74 | """Preload or create the guild object into the context object. If the guild object does not exist, create it. |
72 | 75 |
|
73 | 76 | Args: |
74 | 77 | ---- |
75 | 78 | ctx: The context object to preload or create the guild object into. |
| 79 | + prefetch_related: List of related fields to prefetch. |
76 | 80 |
|
77 | 81 | Returns: |
78 | 82 | ------- |
79 | 83 | bool: (True) always. |
80 | 84 |
|
81 | 85 | """ |
82 | | - ctx.guild_obj, _ = await Guild.get_or_create(id=ctx.guild.id) if ctx.guild else (None, None) |
| 86 | + guild: Guild | None |
| 87 | + guild, _ = await Guild.get_or_create(id=ctx.guild.id) if ctx.guild else (None, None) |
| 88 | + if guild is not None: |
| 89 | + await guild.fetch_related(*prefetch_related) |
| 90 | + ctx.guild_obj = guild |
83 | 91 | return True |
84 | 92 |
|
85 | 93 |
|
86 | | -preload_or_create_guild = commands.check(_preload_or_create_guild) # pyright: ignore [reportArgumentType] |
| 94 | +type PreloadFunction = Callable[[custom.Context, Sequence[str]], Literal[True]] |
| 95 | + |
| 96 | + |
| 97 | +@overload |
| 98 | +def preload_x[T](f: T, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None) -> T: ... |
| 99 | + |
| 100 | + |
| 101 | +@overload |
| 102 | +def preload_x[T]( |
| 103 | + f: None = None, *, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None |
| 104 | +) -> Callable[[T], T]: ... |
| 105 | + |
| 106 | + |
| 107 | +def preload_x[T]( |
| 108 | + f: Callable[[T], T] | None = None, *, preloader: PreloadFunction, prefetch_related: Sequence[str] | None = None |
| 109 | +): |
| 110 | + if prefetch_related is None: |
| 111 | + prefetch_related = [] |
| 112 | + |
| 113 | + func = partial(preloader, prefetch_related=prefetch_related) |
| 114 | + |
| 115 | + check_decorator = commands.check(func) |
| 116 | + |
| 117 | + return check_decorator(f) if f is not None else check_decorator |
| 118 | + |
| 119 | + |
| 120 | +preload_guild = partial(preload_x, preloader=_preload_guild) |
| 121 | +preload_user = partial(preload_x, preloader=_preload_user) |
| 122 | +preload_or_create_guild = partial(preload_x, preloader=_preload_or_create_guild) |
| 123 | +preload_or_create_user = partial(preload_x, preloader=_preload_or_create_user) |
0 commit comments