@@ -291,16 +291,32 @@ async def main():
291291 output = self .start (main )
292292 self .assert_exit_result_error (output , ValueError , "oops" )
293293
294- def start (self , main : Callable , * args : Any , ** kwargs : Any ) -> Output :
294+ def test_min_max_results_clamping (self ):
295+ @durable
296+ async def main ():
297+ return await call_concurrently ("a" , "b" , "c" )
298+
299+ output = self .start (main , poll_min_results = 1 , poll_max_results = 10 )
300+ self .assert_poll_call_functions (output , ["a" , "b" , "c" ], min_results = 1 , max_results = 3 )
301+
302+ output = self .start (main , poll_min_results = 1 , poll_max_results = 2 )
303+ self .assert_poll_call_functions (output , ["a" , "b" , "c" ], min_results = 1 , max_results = 2 )
304+
305+ output = self .start (main , poll_min_results = 10 , poll_max_results = 10 )
306+ self .assert_poll_call_functions (output , ["a" , "b" , "c" ], min_results = 3 , max_results = 3 )
307+
308+ def start (self , main : Callable , * args : Any , poll_min_results = 1 , poll_max_results = 10 , poll_max_wait_seconds = None ,
309+ ** kwargs : Any ) -> Output :
295310 input = Input .from_input_arguments (main .__qualname__ , * args , ** kwargs )
296- return OneShotScheduler (main ).run (input )
311+ return OneShotScheduler (main , poll_min_results = poll_min_results , poll_max_results = poll_max_results ,
312+ poll_max_wait_seconds = poll_max_wait_seconds ).run (input )
297313
298314 def resume (
299- self ,
300- main : Callable ,
301- prev_output : Output ,
302- call_results : list [CallResult ],
303- poll_error : Exception | None = None ,
315+ self ,
316+ main : Callable ,
317+ prev_output : Output ,
318+ call_results : list [CallResult ],
319+ poll_error : Exception | None = None ,
304320 ):
305321 poll = self .assert_poll (prev_output )
306322 input = Input .from_poll_results (
@@ -330,7 +346,7 @@ def assert_exit_result_value(self, output: Output, expect: Any):
330346 self .assertEqual (expect , any_unpickle (result .output ))
331347
332348 def assert_exit_result_error (
333- self , output : Output , expect : type [Exception ], message : str | None = None
349+ self , output : Output , expect : type [Exception ], message : str | None = None
334350 ):
335351 result = self .assert_exit_result (output )
336352 self .assertFalse (result .HasField ("output" ))
@@ -357,7 +373,7 @@ def assert_empty_poll(self, output: Output):
357373 self .assertEqual (len (poll .calls ), 0 )
358374
359375 def assert_poll_call_functions (
360- self , output : Output , expect : list [str ], min_results = None , max_results = None
376+ self , output : Output , expect : list [str ], min_results = None , max_results = None
361377 ):
362378 poll = self .assert_poll (output )
363379 # Note: we're not testing endpoint/input here.
0 commit comments