1+ from __future__ import annotations
2+
13import abc
24import asyncio
35import concurrent .futures as concurrent
1113import traceback
1214import warnings
1315from contextlib import suppress
16+ from typing import TYPE_CHECKING , Any , Callable
1417
1518import loky
1619
1720from adaptive .notebook_integration import in_ipynb , live_info , live_plot
1821
22+ if TYPE_CHECKING :
23+ from adaptive import BaseLearner
24+
1925try :
2026 import ipyparallel
2127
@@ -663,15 +669,26 @@ def elapsed_time(self):
663669 end_time = time .time ()
664670 return end_time - self .start_time
665671
666- def start_periodic_saving (self , save_kwargs , interval ):
672+ def start_periodic_saving (
673+ self ,
674+ save_kwargs : dict [str , Any ] | None = None ,
675+ interval : int = 30 ,
676+ method : Callable [[BaseLearner ], None ] | None = None ,
677+ ):
667678 """Periodically save the learner's data.
668679
669680 Parameters
670681 ----------
671682 save_kwargs : dict
672683 Key-word arguments for ``learner.save(**save_kwargs)``.
684+ Only used if ``method=None``.
673685 interval : int
674686 Number of seconds between saving the learner.
687+ method : callable
688+ The method to use for saving the learner. If None, the default
689+ saves the learner using "pickle" which calls
690+ ``learner.save(**save_kwargs)``. Otherwise provide a callable
691+ that takes the learner and saves the learner.
675692
676693 Example
677694 -------
@@ -681,11 +698,19 @@ def start_periodic_saving(self, save_kwargs, interval):
681698 ... interval=600)
682699 """
683700
684- async def _saver (save_kwargs = save_kwargs , interval = interval ):
701+ def default_save (learner ):
702+ learner .save (** save_kwargs )
703+
704+ if method is None :
705+ method = default_save
706+ if save_kwargs is None :
707+ raise ValueError ("Must provide `save_kwargs` if method=None." )
708+
709+ async def _saver ():
685710 while self .status () == "running" :
686- self .learner . save ( ** save_kwargs )
711+ method ( self .learner )
687712 await asyncio .sleep (interval )
688- self .learner . save ( ** save_kwargs ) # one last time
713+ method ( self .learner ) # one last time
689714
690715 self .saving_task = self .ioloop .create_task (_saver ())
691716 return self .saving_task
0 commit comments