@@ -649,6 +649,7 @@ def __init__(
649649
650650 self .task = self .ioloop .create_task (self ._run ())
651651 self .saving_task = None
652+ self .callbacks = []
652653 if in_ipynb () and not self .ioloop .is_running ():
653654 warnings .warn (
654655 "The runner has been scheduled, but the asyncio "
@@ -753,6 +754,31 @@ def elapsed_time(self):
753754 end_time = time .time ()
754755 return end_time - self .start_time
755756
757+ def add_periodic_callback (
758+ self ,
759+ method : Callable [[AsyncRunner ]],
760+ interval : int = 30 ,
761+ ):
762+ """Start a periodic callback that calls the given method on the runner.
763+
764+ Parameters
765+ ----------
766+ method : callable
767+ The method to call periodically.
768+ interval : int
769+ The interval in seconds between the calls.
770+ """
771+
772+ async def _callback ():
773+ while self .status () == "running" :
774+ method (self )
775+ await asyncio .sleep (interval )
776+ method (self ) # one last time
777+
778+ task = self .ioloop .create_task (_callback ())
779+ self .callbacks .append (task )
780+ return task
781+
756782 def start_periodic_saving (
757783 self ,
758784 save_kwargs : dict [str , Any ] | None = None ,
@@ -781,6 +807,8 @@ def start_periodic_saving(
781807 ... save_kwargs=dict(fname='data/test.pickle'),
782808 ... interval=600)
783809 """
810+ if self .saving_task is not None :
811+ raise RuntimeError ("Already saving." )
784812
785813 def default_save (learner ):
786814 learner .save (** save_kwargs )
@@ -790,13 +818,7 @@ def default_save(learner):
790818 if save_kwargs is None :
791819 raise ValueError ("Must provide `save_kwargs` if method=None." )
792820
793- async def _saver ():
794- while self .status () == "running" :
795- method (self .learner )
796- await asyncio .sleep (interval )
797- method (self .learner ) # one last time
798-
799- self .saving_task = self .ioloop .create_task (_saver ())
821+ self .saving_task = self .add_periodic_callback (method , interval = interval )
800822 return self .saving_task
801823
802824
0 commit comments