2828import asyncio
2929import datetime
3030import inspect
31+ import logging
3132import sys
3233import traceback
3334from collections .abc import Sequence
4344
4445T = TypeVar ("T" )
4546_func = Callable [..., Awaitable [Any ]]
47+ _log = logging .getLogger (__name__ )
4648LF = TypeVar ("LF" , bound = _func )
4749FT = TypeVar ("FT" , bound = _func )
4850ET = TypeVar ("ET" , bound = Callable [[Any , BaseException ], Awaitable [Any ]])
4951
5052
53+ def is_ambiguous (dt : datetime .datetime ) -> bool :
54+ if dt .tzinfo is None or isinstance (dt .tzinfo , datetime .timezone ):
55+ return False
56+
57+ before = dt .replace (fold = 0 )
58+ after = dt .replace (fold = 1 )
59+
60+ same_offset = before .utcoffset () == after .utcoffset ()
61+ same_dst = before .dst () == after .dst ()
62+ return not (same_offset and same_dst )
63+
64+
65+ def is_imaginary (dt : datetime .datetime ) -> bool :
66+ if dt .tzinfo is None or isinstance (dt .tzinfo , datetime .timezone ):
67+ return False
68+
69+ tz = dt .tzinfo
70+ dt = dt .replace (tzinfo = None )
71+ roundtrip = dt .replace (tzinfo = tz ).astimezone (datetime .timezone .utc ).astimezone (tz ).replace (tzinfo = None )
72+ return dt != roundtrip
73+
74+
5175class SleepHandle :
5276 __slots__ = ("future" , "loop" , "handle" )
5377
5478 def __init__ (
5579 self , dt : datetime .datetime , * , loop : asyncio .AbstractEventLoop
5680 ) -> None :
57- self .loop = loop
58- self .future = future = loop .create_future ()
81+ self .loop : asyncio . AbstractEventLoop = loop
82+ self .future : asyncio . Future [ None ] = loop .create_future ()
5983 relative_delta = discord .utils .compute_timedelta (dt )
60- self .handle = loop .call_later (relative_delta , future .set_result , True )
84+ self .handle = loop .call_later (relative_delta , self ._safe_result , self .future )
85+
86+ @staticmethod
87+ def _safe_result (future : asyncio .Future ) -> None :
88+ if not future .done ():
89+ future .set_result (None )
6190
6291 def recalculate (self , dt : datetime .datetime ) -> None :
6392 self .handle .cancel ()
6493 relative_delta = discord .utils .compute_timedelta (dt )
65- self .handle = self .loop .call_later (relative_delta , self .future . set_result , True )
94+ self .handle = self .loop .call_later (relative_delta , self ._safe_result , self . future )
6695
6796 def wait (self ) -> asyncio .Future [Any ]:
6897 return self .future
@@ -95,7 +124,15 @@ def __init__(
95124 ) -> None :
96125 self .coro : LF = coro
97126 self .reconnect : bool = reconnect
98- self .loop : asyncio .AbstractEventLoop | None = loop
127+
128+ if loop is None :
129+ try :
130+ loop = asyncio .get_running_loop ()
131+ except RuntimeError :
132+ loop = asyncio .new_event_loop ()
133+
134+ self .loop = loop
135+
99136 self .name : str = f'pycord-ext-task ({ id (self ):#x} ): { coro .__qualname__ } ' if name in (None , MISSING ) else name
100137 self .count : int | None = count
101138 self ._current_loop = 0
@@ -147,53 +184,67 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non
147184 if name .endswith ("_loop" ):
148185 setattr (self , f"_{ name } _running" , False )
149186
150- def _create_task (self , * args : Any , ** kwargs : Any ) -> asyncio .Task [None ]:
151- if self .loop is None :
152- meth = asyncio .create_task
153- else :
154- meth = self .loop .create_task
155- return meth (self ._loop (* args , ** kwargs ), name = self .name )
156-
157187 def _try_sleep_until (self , dt : datetime .datetime ):
158188 self ._handle = SleepHandle (dt = dt , loop = asyncio .get_running_loop ())
159189 return self ._handle .wait ()
160190
191+ def _rel_time (self ) -> bool :
192+ return self ._time is MISSING
193+
194+ def _expl_time (self ) -> bool :
195+ return self ._time is not MISSING
196+
161197 async def _loop (self , * args : Any , ** kwargs : Any ) -> None :
162198 backoff = ExponentialBackoff ()
163199 await self ._call_loop_function ("before_loop" )
164200 self ._last_iteration_failed = False
165- if self ._time is not MISSING :
166- # the time index should be prepared every time the internal loop is started
167- self ._prepare_time_index ()
201+ if self ._expl_time ():
168202 self ._next_iteration = self ._get_next_sleep_time ()
169203 else :
170204 self ._next_iteration = datetime .datetime .now (datetime .timezone .utc )
205+
171206 try :
172- await self ._try_sleep_until (self ._next_iteration )
207+ if self ._stop_next_iteration :
208+ return
209+
173210 while True :
211+ if self ._expl_time ():
212+ await self ._try_sleep_until (self ._next_iteration )
174213 if not self ._last_iteration_failed :
175214 self ._last_iteration = self ._next_iteration
176215 self ._next_iteration = self ._get_next_sleep_time ()
216+
217+ while self ._expl_time () and self ._next_iteration <= self ._last_iteration :
218+ _log .warning (
219+ 'Task %s woke up at %s, which was before expected (%s). Sleeping again to fix it...' ,
220+ self .coro .__name__ ,
221+ discord .utils .utcnow (),
222+ self ._next_iteration ,
223+ )
224+ await self ._try_sleep_until (self ._next_iteration )
225+ self ._next_iteration = self ._get_next_sleep_time ()
177226 try :
178227 await self .coro (* args , ** kwargs )
179228 self ._last_iteration_failed = False
180- backoff = ExponentialBackoff ()
181- except self ._valid_exception :
229+ except self ._valid_exception as exc :
182230 self ._last_iteration_failed = True
183231 if not self .reconnect :
184232 raise
185- await asyncio .sleep (backoff .delay ())
186- else :
187- await self ._try_sleep_until (self ._next_iteration )
188233
234+ delay = backoff .delay ()
235+ _log .warning (
236+ 'Received an exception which was in the valid exception set. Task will run again in %s.2f seconds' ,
237+ self .coro .__name__ ,
238+ delay ,
239+ exc_info = exc ,
240+ )
241+ await asyncio .sleep (delay )
242+ else :
189243 if self ._stop_next_iteration :
190244 return
191245
192- now = datetime .datetime .now (datetime .timezone .utc )
193- if now > self ._next_iteration :
194- self ._next_iteration = now
195- if self ._time is not MISSING :
196- self ._prepare_time_index (now )
246+ if self ._rel_time ():
247+ await self ._try_sleep_until (self ._next_iteration )
197248
198249 self ._current_loop += 1
199250 if self ._current_loop == self .count :
@@ -208,7 +259,8 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
208259 raise exc
209260 finally :
210261 await self ._call_loop_function ("after_loop" )
211- self ._handle .cancel ()
262+ if self ._handle :
263+ self ._handle .cancel ()
212264 self ._is_being_cancelled = False
213265 self ._current_loop = 0
214266 self ._stop_next_iteration = False
@@ -226,8 +278,8 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
226278 time = self ._time ,
227279 count = self .count ,
228280 reconnect = self .reconnect ,
229- loop = self .loop ,
230281 name = self .name ,
282+ loop = self .loop ,
231283 )
232284 copy ._injected = obj
233285 copy ._before_loop = self ._before_loop
@@ -340,7 +392,7 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
340392 if self ._injected is not None :
341393 args = (self ._injected , * args )
342394
343- self ._task = self ._create_task ( * args , ** kwargs )
395+ self ._task = self .loop . create_task ( self . _loop ( * args , ** kwargs ), name = self . name )
344396 return self ._task
345397
346398 def stop (self ) -> None :
@@ -574,66 +626,51 @@ def error(self, coro: ET) -> ET:
574626 self ._error = coro # type: ignore
575627 return coro
576628
577- def _get_next_sleep_time (self ) -> datetime .datetime :
629+ def _get_next_sleep_time (self , now : datetime . datetime = MISSING ) -> datetime .datetime :
578630 if self ._sleep is not MISSING :
579631 return self ._last_iteration + datetime .timedelta (seconds = self ._sleep )
580632
581- if self ._time_index >= len (self ._time ):
582- self ._time_index = 0
583- if self ._current_loop == 0 :
584- # if we're at the last index on the first iteration, we need to sleep until tomorrow
585- return datetime .datetime .combine (
586- datetime .datetime .now (self ._time [0 ].tzinfo or datetime .timezone .utc )
587- + datetime .timedelta (days = 1 ),
588- self ._time [0 ],
589- )
633+ if now is MISSING :
634+ now = datetime .datetime .now (datetime .timezone .utc )
590635
591- next_time = self ._time [self ._time_index ]
592-
593- if self ._current_loop == 0 :
594- self ._time_index += 1
595- if (
596- next_time
597- > datetime .datetime .now (
598- next_time .tzinfo or datetime .timezone .utc
599- ).timetz ()
600- ):
601- return datetime .datetime .combine (
602- datetime .datetime .now (next_time .tzinfo or datetime .timezone .utc ),
603- next_time ,
604- )
605- else :
606- return datetime .datetime .combine (
607- datetime .datetime .now (next_time .tzinfo or datetime .timezone .utc )
608- + datetime .timedelta (days = 1 ),
609- next_time ,
610- )
636+ index = self ._start_time_relative_to (now )
611637
612- next_date = cast (
613- datetime .datetime , self ._last_iteration .astimezone (next_time .tzinfo )
614- )
615- if next_time < next_date .timetz ():
616- next_date += datetime .timedelta (days = 1 )
638+ if index is None :
639+ time = self ._time [0 ]
640+ tomorrow = now .astimezone (time .tzinfo ) + datetime .timedelta (days = 1 )
641+ date = tomorrow .date ()
642+ else :
643+ time = self ._time [index ]
644+ date = now .astimezone (time .tzinfo ).date ()
645+
646+ dt = datetime .datetime .combine (date , time , tzinfo = time .tzinfo )
617647
618- self ._time_index += 1
619- return datetime .datetime .combine (next_date , next_time )
648+ if dt .tzinfo is None or isinstance (dt .tzinfo , datetime .timezone ):
649+ return dt
650+
651+ if is_imaginary (dt ):
652+ tomorrow = dt + datetime .timedelta (days = 1 )
653+ yesterday = dt - datetime .timedelta (days = 1 )
654+ return dt + (tomorrow .utcoffset () - yesterday .utcoffset ()) # type: ignore
655+ elif is_ambiguous (dt ):
656+ return dt .replace (fold = 1 )
657+ else :
658+ return dt
620659
621- def _prepare_time_index (self , now : datetime .datetime = MISSING ) -> None :
660+ def _start_time_relative_to (self , now : datetime .datetime ) -> int | None :
622661 # now kwarg should be a datetime.datetime representing the time "now"
623662 # to calculate the next time index from
624663
625664 # pre-condition: self._time is set
626- time_now = (
627- now
628- if now is not MISSING
629- else datetime .datetime .now (datetime .timezone .utc ).replace (microsecond = 0 )
630- )
631665 for idx , time in enumerate (self ._time ):
632- if time >= time_now .astimezone (time .tzinfo ).timetz ():
633- self ._time_index = idx
634- break
666+ # Convert the current time to the target timezone
667+ # e.g. 18:00 UTC -> 03:00 UTC+9
668+ # Then compare the time instances to see if they're the same
669+ start = now .astimezone (time .tzinfo )
670+ if time >= start .timetz ():
671+ return idx
635672 else :
636- self . _time_index = 0
673+ return None
637674
638675 def _get_time_parameter (
639676 self ,
@@ -780,9 +817,6 @@ def loop(
780817 one used in :meth:`discord.Client.connect`.
781818 loop: Optional[:class:`asyncio.AbstractEventLoop`]
782819 The loop to use to register the task, defaults to ``None``.
783-
784- .. versionchanged:: 2.7
785- This can now be ``None``
786820 name: Optional[:class:`str`]
787821 The name to create the task with, defaults to ``None``.
788822
@@ -806,8 +840,8 @@ def decorator(func: LF) -> Loop[LF]:
806840 count = count ,
807841 time = time ,
808842 reconnect = reconnect ,
809- loop = loop ,
810843 name = name ,
844+ loop = loop ,
811845 )
812846
813847 return decorator
0 commit comments