Skip to content

Commit 1b21e50

Browse files
committed
Clamp min_results if greater than outstanding calls
1 parent 2a4c856 commit 1b21e50

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

src/dispatch/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def _run(self, input: Input) -> Output:
440440
return Output.poll(
441441
state=serialized_state,
442442
calls=pending_calls,
443-
min_results=max(1, self.poll_min_results),
443+
min_results=max(1, min(state.outstanding_calls, self.poll_min_results)),
444444
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
445445
max_wait_seconds=self.poll_max_wait_seconds,
446446
)

tests/dispatch/test_scheduler.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,16 +291,32 @@ async def main():
291291
output = self.start(main)
292292
self.assert_exit_result_error(output, ValueError, "oops")
293293

294-
def start(self, main: Callable, *args: Any, **kwargs: Any) -> Output:
294+
def test_min_max_results_clamping(self):
295+
@durable
296+
async def main():
297+
return await call_concurrently("a", "b", "c")
298+
299+
output = self.start(main, poll_min_results=1, poll_max_results=10)
300+
self.assert_poll_call_functions(output, ["a", "b", "c"], min_results=1, max_results=3)
301+
302+
output = self.start(main, poll_min_results=1, poll_max_results=2)
303+
self.assert_poll_call_functions(output, ["a", "b", "c"], min_results=1, max_results=2)
304+
305+
output = self.start(main, poll_min_results=10, poll_max_results=10)
306+
self.assert_poll_call_functions(output, ["a", "b", "c"], min_results=3, max_results=3)
307+
308+
def start(self, main: Callable, *args: Any, poll_min_results=1, poll_max_results=10, poll_max_wait_seconds=None,
309+
**kwargs: Any) -> Output:
295310
input = Input.from_input_arguments(main.__qualname__, *args, **kwargs)
296-
return OneShotScheduler(main).run(input)
311+
return OneShotScheduler(main, poll_min_results=poll_min_results, poll_max_results=poll_max_results,
312+
poll_max_wait_seconds=poll_max_wait_seconds).run(input)
297313

298314
def resume(
299-
self,
300-
main: Callable,
301-
prev_output: Output,
302-
call_results: list[CallResult],
303-
poll_error: Exception | None = None,
315+
self,
316+
main: Callable,
317+
prev_output: Output,
318+
call_results: list[CallResult],
319+
poll_error: Exception | None = None,
304320
):
305321
poll = self.assert_poll(prev_output)
306322
input = Input.from_poll_results(
@@ -330,7 +346,7 @@ def assert_exit_result_value(self, output: Output, expect: Any):
330346
self.assertEqual(expect, any_unpickle(result.output))
331347

332348
def assert_exit_result_error(
333-
self, output: Output, expect: type[Exception], message: str | None = None
349+
self, output: Output, expect: type[Exception], message: str | None = None
334350
):
335351
result = self.assert_exit_result(output)
336352
self.assertFalse(result.HasField("output"))
@@ -357,7 +373,7 @@ def assert_empty_poll(self, output: Output):
357373
self.assertEqual(len(poll.calls), 0)
358374

359375
def assert_poll_call_functions(
360-
self, output: Output, expect: list[str], min_results=None, max_results=None
376+
self, output: Output, expect: list[str], min_results=None, max_results=None
361377
):
362378
poll = self.assert_poll(output)
363379
# Note: we're not testing endpoint/input here.

0 commit comments

Comments
 (0)