@@ -676,12 +676,16 @@ def elapsed_time(self):
676676 return end_time - self .start_time
677677
678678 def cancel_point (
679- self , point : Any | None = None , future : asyncio . Future | None = None
679+ self , future : asyncio . Future | None = None , point : Any | None = None
680680 ):
681- """Cancel a point that is currently being evaluated.
681+ """Cancel a future or point that is currently being evaluated.
682+
683+ Either the ``future`` or the ``point`` must be provided.
682684
683685 Parameters
684686 ----------
687+ future : asyncio.Future
688+ The future that is currently being evaluated.
685689 point
686690 The point that should be cancelled.
687691 """
@@ -691,9 +695,9 @@ def cancel_point(
691695 future = next (fut for fut , p in self .pending_points if p == point )
692696 future .cancel ()
693697
694- def add_periodic_callback (
698+ def start_periodic_callback (
695699 self ,
696- method : Callable [[AsyncRunner ]],
700+ method : Callable [[AsyncRunner ], None ],
697701 interval : int = 30 ,
698702 ):
699703 """Start a periodic callback that calls the given method on the runner.
@@ -753,9 +757,11 @@ def default_save(learner):
753757 if method is None :
754758 method = default_save
755759 if save_kwargs is None :
756- raise ValueError ("Must provide `save_kwargs` if method=None." )
760+ raise ValueError ("Must provide `save_kwargs` if ` method=None` ." )
757761
758- self .saving_task = self .add_periodic_callback (method , interval = interval )
762+ self .saving_task = self .start_periodic_callback (
763+ lambda r : method (r .learner ), interval = interval
764+ )
759765 return self .saving_task
760766
761767
0 commit comments