@@ -297,26 +297,43 @@ async def main():
297297 return await call_concurrently ("a" , "b" , "c" )
298298
299299 output = self .start (main , poll_min_results = 1 , poll_max_results = 10 )
300- self .assert_poll_call_functions (output , ["a" , "b" , "c" ], min_results = 1 , max_results = 3 )
300+ self .assert_poll_call_functions (
301+ output , ["a" , "b" , "c" ], min_results = 1 , max_results = 3
302+ )
301303
302304 output = self .start (main , poll_min_results = 1 , poll_max_results = 2 )
303- self .assert_poll_call_functions (output , ["a" , "b" , "c" ], min_results = 1 , max_results = 2 )
305+ self .assert_poll_call_functions (
306+ output , ["a" , "b" , "c" ], min_results = 1 , max_results = 2
307+ )
304308
305309 output = self .start (main , poll_min_results = 10 , poll_max_results = 10 )
306- self .assert_poll_call_functions (output , ["a" , "b" , "c" ], min_results = 3 , max_results = 3 )
310+ self .assert_poll_call_functions (
311+ output , ["a" , "b" , "c" ], min_results = 3 , max_results = 3
312+ )
307313
308- def start (self , main : Callable , * args : Any , poll_min_results = 1 , poll_max_results = 10 , poll_max_wait_seconds = None ,
309- ** kwargs : Any ) -> Output :
314+ def start (
315+ self ,
316+ main : Callable ,
317+ * args : Any ,
318+ poll_min_results = 1 ,
319+ poll_max_results = 10 ,
320+ poll_max_wait_seconds = None ,
321+ ** kwargs : Any ,
322+ ) -> Output :
310323 input = Input .from_input_arguments (main .__qualname__ , * args , ** kwargs )
311- return OneShotScheduler (main , poll_min_results = poll_min_results , poll_max_results = poll_max_results ,
312- poll_max_wait_seconds = poll_max_wait_seconds ).run (input )
324+ return OneShotScheduler (
325+ main ,
326+ poll_min_results = poll_min_results ,
327+ poll_max_results = poll_max_results ,
328+ poll_max_wait_seconds = poll_max_wait_seconds ,
329+ ).run (input )
313330
314331 def resume (
315- self ,
316- main : Callable ,
317- prev_output : Output ,
318- call_results : list [CallResult ],
319- poll_error : Exception | None = None ,
332+ self ,
333+ main : Callable ,
334+ prev_output : Output ,
335+ call_results : list [CallResult ],
336+ poll_error : Exception | None = None ,
320337 ):
321338 poll = self .assert_poll (prev_output )
322339 input = Input .from_poll_results (
@@ -346,7 +363,7 @@ def assert_exit_result_value(self, output: Output, expect: Any):
346363 self .assertEqual (expect , any_unpickle (result .output ))
347364
348365 def assert_exit_result_error (
349- self , output : Output , expect : type [Exception ], message : str | None = None
366+ self , output : Output , expect : type [Exception ], message : str | None = None
350367 ):
351368 result = self .assert_exit_result (output )
352369 self .assertFalse (result .HasField ("output" ))
@@ -373,7 +390,7 @@ def assert_empty_poll(self, output: Output):
373390 self .assertEqual (len (poll .calls ), 0 )
374391
375392 def assert_poll_call_functions (
376- self , output : Output , expect : list [str ], min_results = None , max_results = None
393+ self , output : Output , expect : list [str ], min_results = None , max_results = None
377394 ):
378395 poll = self .assert_poll (output )
379396 # Note: we're not testing endpoint/input here.
0 commit comments