7272 # and https://github.com/python-adaptive/adaptive/issues/301
7373 _default_executor = loky .get_reusable_executor
7474
75- GoalTypes : TypeAlias = Union [
75+ _GoalTypes : TypeAlias = Union [
7676 Callable [[BaseLearner ], bool ], int , float , datetime , timedelta , None
7777]
7878
@@ -86,15 +86,22 @@ class BaseRunner(metaclass=abc.ABCMeta):
8686 goal : callable, optional
8787 The end condition for the calculation. This function must take
8888 the learner as its sole argument, and return True when we should
89- stop requesting more points. (Advanced use) Instead of providing a
90- function, see `auto_goal` for other types that are accepted here.
89+ stop requesting more points.
9190 loss_goal : float, optional
9291 Convenience argument, use instead of ``goal``. The end condition for the
9392 calculation. Stop when the loss is smaller than this value.
9493 npoints_goal : int, optional
9594 Convenience argument, use instead of ``goal``. The end condition for the
9695 calculation. Stop when the number of points is larger or
9796 equal than this value.
97+ datetime_goal : datetime, optional
98+ Convenience argument, use instead of ``goal``. The end condition for the
99+ calculation. Stop when the current time is larger or equal than this
100+ value.
101+ timedelta_goal : timedelta, optional
102+ Convenience argument, use instead of ``goal``. The end condition for the
103+ calculation. Stop when the current time is larger or equal than
104+ ``start_time + timedelta_goal``.
98105 executor : `concurrent.futures.Executor`, `distributed.Client`,\
99106 `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
100107 `loky.get_reusable_executor`, optional
@@ -144,10 +151,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
144151 def __init__ (
145152 self ,
146153 learner ,
147- goal : GoalTypes = None ,
154+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
148155 * ,
149156 loss_goal : float | None = None ,
150157 npoints_goal : int | None = None ,
158+ datetime_goal : datetime | None = None ,
159+ timedelta_goal : timedelta | None = None ,
151160 executor = None ,
152161 ntasks = None ,
153162 log = False ,
@@ -158,7 +167,15 @@ def __init__(
158167 ):
159168
160169 self .executor = _ensure_executor (executor )
161- self .goal = _goal (learner , goal , loss_goal , npoints_goal , allow_running_forever )
170+ self .goal = _goal (
171+ learner ,
172+ goal ,
173+ loss_goal ,
174+ npoints_goal ,
175+ datetime_goal ,
176+ timedelta_goal ,
177+ allow_running_forever ,
178+ )
162179
163180 self ._max_tasks = ntasks
164181
@@ -348,8 +365,7 @@ class BlockingRunner(BaseRunner):
348365 goal : callable
349366 The end condition for the calculation. This function must take
350367 the learner as its sole argument, and return True when we should
351- stop requesting more points. (Advanced use) Instead of providing a
352- function, see `auto_goal` for other types that are accepted here.
368+ stop requesting more points.
353369 loss_goal : float
354370 Convenience argument, use instead of ``goal``. The end condition for the
355371 calculation. Stop when the loss is smaller than this value.
@@ -410,10 +426,12 @@ class BlockingRunner(BaseRunner):
410426 def __init__ (
411427 self ,
412428 learner ,
413- goal : GoalTypes = None ,
429+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
414430 * ,
415431 loss_goal : float | None = None ,
416432 npoints_goal : int | None = None ,
433+ datetime_goal : datetime | None = None ,
434+ timedelta_goal : timedelta | None = None ,
417435 executor = None ,
418436 ntasks = None ,
419437 log = False ,
@@ -481,8 +499,7 @@ class AsyncRunner(BaseRunner):
481499 goal : callable, optional
482500 The end condition for the calculation. This function must take
483501 the learner as its sole argument, and return True when we should
484- stop requesting more points. (Advanced use) Instead of providing a
485- function, see `auto_goal` for other types that are accepted here.
502+ stop requesting more points.
486503 If not provided, the runner will run forever (or stop when no more
487504 points can be added), or until ``self.task.cancel()`` is called.
488505 loss_goal : float, optional
@@ -492,6 +509,14 @@ class AsyncRunner(BaseRunner):
492509 Convenience argument, use instead of ``goal``. The end condition for the
493510 calculation. Stop when the number of points is larger or
494511 equal than this value.
512+ datetime_goal : datetime, optional
513+ Convenience argument, use instead of ``goal``. The end condition for the
514+ calculation. Stop when the current time is larger or equal than this
515+ value.
516+ timedelta_goal : timedelta, optional
517+ Convenience argument, use instead of ``goal``. The end condition for the
518+ calculation. Stop when the current time is larger or equal than
519+ ``start_time + timedelta_goal``.
495520 executor : `concurrent.futures.Executor`, `distributed.Client`,\
496521 `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
497522 `loky.get_reusable_executor`, optional
@@ -555,10 +580,12 @@ class AsyncRunner(BaseRunner):
555580 def __init__ (
556581 self ,
557582 learner ,
558- goal : GoalTypes = None ,
583+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
559584 * ,
560585 loss_goal : float | None = None ,
561586 npoints_goal : int | None = None ,
587+ datetime_goal : datetime | None = None ,
588+ timedelta_goal : timedelta | None = None ,
562589 executor = None ,
563590 ntasks = None ,
564591 log = False ,
@@ -770,10 +797,12 @@ async def _saver():
770797
771798def simple (
772799 learner ,
773- goal : GoalTypes = None ,
800+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
774801 * ,
775802 loss_goal : float | None = None ,
776803 npoints_goal : int | None = None ,
804+ datetime_goal : datetime | None = None ,
805+ timedelta_goal : timedelta | None = None ,
777806):
778807 """Run the learner until the goal is reached.
779808
@@ -800,8 +829,24 @@ def simple(
800829 Convenience argument, use instead of ``goal``. The end condition for the
801830 calculation. Stop when the number of points is larger or
802831 equal than this value.
832+ datetime_goal : datetime, optional
833+ Convenience argument, use instead of ``goal``. The end condition for the
834+ calculation. Stop when the current time is larger or equal than this
835+ value.
836+ timedelta_goal : timedelta, optional
837+ Convenience argument, use instead of ``goal``. The end condition for the
838+ calculation. Stop when the current time is larger or equal than
839+ ``start_time + timedelta_goal``.
803840 """
804- goal = _goal (learner , goal , loss_goal , npoints_goal , allow_running_forever = False )
841+ goal = _goal (
842+ learner ,
843+ goal ,
844+ loss_goal ,
845+ npoints_goal ,
846+ datetime_goal ,
847+ timedelta_goal ,
848+ allow_running_forever = False ,
849+ )
805850 while not goal (learner ):
806851 xs , _ = learner .ask (1 )
807852 for x in xs :
@@ -942,8 +987,12 @@ def __call__(self, _):
942987
943988
944989def auto_goal (
945- goal : GoalTypes ,
946- learner : BaseLearner ,
990+ * ,
991+ loss : float | None = None ,
992+ npoints : int | None = None ,
993+ datetime : datetime | None = None ,
994+ timedelta : timedelta | None = None ,
995+ learner : BaseLearner | None = None ,
947996 allow_running_forever : bool = True ,
948997) -> Callable [[BaseLearner ], bool ]:
949998 """Extract a goal from the learners.
@@ -954,7 +1003,6 @@ def auto_goal(
9541003 The goal to extract. Can be a callable, an integer, a float, a datetime,
9551004 a timedelta or None.
9561005 If the type of `goal` is:
957-
9581006 * ``callable``, it is returned as is.
9591007 * ``int``, the goal is reached after that many points have been added.
9601008 * ``float``, the goal is reached when the learner has reached a loss
@@ -980,23 +1028,36 @@ def auto_goal(
9801028 -------
9811029 Callable[[adaptive.BaseLearner], bool]
9821030 """
983- if callable (goal ):
984- return goal
985- if isinstance (goal , float ):
986- return lambda learner : learner .loss () <= goal
1031+ kw = dict (
1032+ loss = loss ,
1033+ npoints = npoints ,
1034+ datetime = datetime ,
1035+ timedelta = timedelta ,
1036+ allow_running_forever = allow_running_forever ,
1037+ )
1038+ opts = (loss , npoints , datetime , timedelta ) # all are mutually exclusive
1039+ if sum (v is not None for v in opts ) > 1 :
1040+ raise ValueError (
1041+ "Only one of loss, npoints, datetime, timedelta can be specified."
1042+ )
1043+
1044+ if loss is not None :
1045+ return lambda learner : learner .loss () <= loss
9871046 if isinstance (learner , BalancingLearner ):
9881047 # Note that the float loss goal is more efficiently implemented in the
9891048 # BalancingLearner itself. That is why the previous if statement is
9901049 # above this one.
991- goals = [auto_goal (goal , l , allow_running_forever ) for l in learner .learners ]
1050+ goals = [auto_goal (learner = l , ** kw ) for l in learner .learners ]
9921051 return lambda learner : all (goal (l ) for l , goal in zip (learner .learners , goals ))
993- if isinstance (goal , int ):
994- return lambda learner : learner .npoints >= goal
995- if isinstance (goal , (timedelta , datetime )):
996- return _TimeGoal (goal )
1052+ if npoints is not None :
1053+ return lambda learner : learner .npoints >= npoints
1054+ if datetime is not None :
1055+ return _TimeGoal (datetime )
1056+ if timedelta is not None :
1057+ return _TimeGoal (timedelta )
9971058 if isinstance (learner , DataSaver ):
998- return auto_goal (goal , learner .learner , allow_running_forever )
999- if goal is None :
1059+ return auto_goal (** kw , learner = learner .learner )
1060+ if all ( v is None for v in opts ) :
10001061 if isinstance (learner , SequenceLearner ):
10011062 return SequenceLearner .done
10021063 if isinstance (learner , IntegratorLearner ):
@@ -1012,17 +1073,32 @@ def auto_goal(
10121073
10131074
10141075def _goal (
1015- learner : BaseLearner ,
1016- goal : GoalTypes ,
1076+ learner : BaseLearner | None ,
1077+ goal : Callable [[ BaseLearner ], bool ] | None ,
10171078 loss_goal : float | None ,
10181079 npoints_goal : int | None ,
1080+ datetime_goal : datetime | None ,
1081+ timedelta_goal : timedelta | None ,
10191082 allow_running_forever : bool ,
10201083):
1021- # goal, loss_goal, npoints_goal are mutually exclusive, only one can be not None
1022- if goal is not None and (loss_goal is not None or npoints_goal is not None ):
1023- raise ValueError ("Either goal, loss_goal, or npoints_goal can be specified." )
1024- if loss_goal is not None :
1025- goal = float (loss_goal )
1026- if npoints_goal is not None :
1027- goal = int (npoints_goal )
1028- return auto_goal (goal , learner , allow_running_forever )
1084+ if callable (goal ):
1085+ return goal
1086+
1087+ if goal is not None and (
1088+ loss_goal is not None
1089+ or npoints_goal is not None
1090+ or datetime_goal is not None
1091+ or timedelta_goal is not None
1092+ ):
1093+ raise ValueError (
1094+ "Either goal, loss_goal, npoints_goal, datetime_goal or"
1095+ " timedelta_goal can be specified, not multiple."
1096+ )
1097+ return auto_goal (
1098+ learner = learner ,
1099+ loss = loss_goal ,
1100+ npoints = npoints_goal ,
1101+ datetime = datetime_goal ,
1102+ timedelta = timedelta_goal ,
1103+ allow_running_forever = allow_running_forever ,
1104+ )
0 commit comments