Skip to content

Commit 2a4c856

Browse files
authored
Merge pull request #92 from stealthrocket/poll-min-results
Use min_results when polling
2 parents 5aa45e0 + 864ee7a commit 2a4c856

File tree

7 files changed

+97
-28
lines changed

7 files changed

+97
-28
lines changed

src/dispatch/proto.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def poll(
202202
cls,
203203
state: Any,
204204
calls: None | list[Call] = None,
205-
max_results: int = 1,
205+
min_results: int = 1,
206+
max_results: int = 10,
206207
max_wait_seconds: int | None = None,
207208
) -> Output:
208209
"""Suspend the function with a set of Calls, instructing the
@@ -216,6 +217,7 @@ def poll(
216217
)
217218
poll = poll_pb.Poll(
218219
coroutine_state=state_bytes,
220+
min_results=min_results,
219221
max_results=max_results,
220222
max_wait=max_wait,
221223
)

src/dispatch/scheduler.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
logger = logging.getLogger(__name__)
1414

15-
1615
CallID: TypeAlias = int
1716
CoroutineID: TypeAlias = int
1817
CorrelationID: TypeAlias = int
@@ -38,9 +37,13 @@ class CallResult:
3837

3938
class Future(Protocol):
4039
def add_result(self, result: CallResult | CoroutineResult): ...
40+
4141
def add_error(self, error: Exception): ...
42+
4243
def ready(self) -> bool: ...
44+
4345
def error(self) -> Exception | None: ...
46+
4447
def value(self) -> Any: ...
4548

4649

@@ -147,7 +150,9 @@ class State:
147150
next_coroutine_id: int
148151
next_call_id: int
149152

150-
prev_calls: list[Coroutine]
153+
prev_callers: list[Coroutine]
154+
155+
outstanding_calls: int
151156

152157

153158
class OneShotScheduler:
@@ -158,13 +163,46 @@ class OneShotScheduler:
158163
take over scheduling asynchronous calls.
159164
"""
160165

161-
__slots__ = ("entry_point", "version", "poll_max_wait_seconds")
166+
__slots__ = (
167+
"entry_point",
168+
"version",
169+
"poll_min_results",
170+
"poll_max_results",
171+
"poll_max_wait_seconds",
172+
)
162173

163174
def __init__(
164-
self, entry_point: Callable, version=sys.version, poll_max_wait_seconds=5
175+
self,
176+
entry_point: Callable,
177+
version: str = sys.version,
178+
poll_min_results: int = 1,
179+
poll_max_results: int = 10,
180+
poll_max_wait_seconds: int | None = None,
165181
):
182+
"""Initialize the scheduler.
183+
184+
Args:
185+
entry_point: Entry point for the main coroutine.
186+
187+
version: Version string to attach to scheduler/coroutine state.
188+
If the scheduler sees a version mismatch, it will respond to
189+
Dispatch with an INCOMPATIBLE_STATE status code.
190+
191+
poll_min_results: Minimum number of call results to wait for before
192+
coroutine execution should continue. Dispatch waits until this
193+
many results are available, or the poll_max_wait_seconds
194+
timeout is reached, whichever comes first.
195+
196+
poll_max_results: Maximum number of calls to receive from Dispatch
197+
per request.
198+
199+
poll_max_wait_seconds: Maximum amount of time to suspend coroutines
200+
while waiting for call results. Optional.
201+
"""
166202
self.entry_point = entry_point
167203
self.version = version
204+
self.poll_min_results = poll_min_results
205+
self.poll_max_results = poll_max_results
168206
self.poll_max_wait_seconds = poll_max_wait_seconds
169207
logger.debug(
170208
"booting coroutine scheduler with entry point '%s' version '%s'",
@@ -198,7 +236,8 @@ def _init_state(self, input: Input) -> State:
198236
ready=[Coroutine(id=0, parent_id=None, coroutine=main)],
199237
next_coroutine_id=1,
200238
next_call_id=1,
201-
prev_calls=[],
239+
prev_callers=[],
240+
outstanding_calls=0,
202241
)
203242

204243
def _rebuild_state(self, input: Input):
@@ -229,16 +268,17 @@ def _run(self, input: Input) -> Output:
229268
if poll_error is not None:
230269
error = poll_error.to_exception()
231270
logger.debug("dispatching poll error: %s", error)
232-
for coroutine in state.prev_calls:
271+
for coroutine in state.prev_callers:
233272
future = coroutine.result
234273
assert future is not None
235274
future.add_error(error)
236275
if future.ready() and coroutine.id in state.suspended:
237276
state.ready.append(coroutine)
238277
del state.suspended[coroutine.id]
239278
logger.debug("coroutine %s is now ready", coroutine)
279+
state.outstanding_calls -= 1
240280

241-
state.prev_calls = []
281+
state.prev_callers = []
242282

243283
logger.debug("dispatching %d call result(s)", len(input.call_results))
244284
for cr in input.call_results:
@@ -265,6 +305,7 @@ def _run(self, input: Input) -> Output:
265305
state.ready.append(owner)
266306
del state.suspended[owner.id]
267307
logger.debug("owner %s is now ready", owner)
308+
state.outstanding_calls -= 1
268309

269310
logger.debug(
270311
"%d/%d coroutines are ready",
@@ -342,7 +383,8 @@ def _run(self, input: Input) -> Output:
342383
pending_calls.append(call)
343384
coroutine.result = CallFuture()
344385
state.suspended[coroutine.id] = coroutine
345-
state.prev_calls.append(coroutine)
386+
state.prev_callers.append(coroutine)
387+
state.outstanding_calls += 1
346388

347389
case Gather():
348390
gather = coroutine_yield
@@ -398,9 +440,8 @@ def _run(self, input: Input) -> Output:
398440
return Output.poll(
399441
state=serialized_state,
400442
calls=pending_calls,
401-
max_results=1,
402-
# FIXME: use min_results + max_results + max_wait to balance latency/throughput
403-
# max_results=len(max_results),
443+
min_results=max(1, self.poll_min_results),
444+
max_results=max(1, min(state.outstanding_calls, self.poll_max_results)),
404445
max_wait_seconds=self.poll_max_wait_seconds,
405446
)
406447

src/dispatch/sdk/v1/poll_pb2.py

Lines changed: 8 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dispatch/sdk/v1/poll_pb2.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@ from dispatch.sdk.v1 import error_pb2 as _error_pb2
1616
DESCRIPTOR: _descriptor.FileDescriptor
1717

1818
class Poll(_message.Message):
19-
__slots__ = ("coroutine_state", "calls", "max_wait", "max_results")
19+
__slots__ = ("coroutine_state", "calls", "max_wait", "max_results", "min_results")
2020
COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int]
2121
CALLS_FIELD_NUMBER: _ClassVar[int]
2222
MAX_WAIT_FIELD_NUMBER: _ClassVar[int]
2323
MAX_RESULTS_FIELD_NUMBER: _ClassVar[int]
24+
MIN_RESULTS_FIELD_NUMBER: _ClassVar[int]
2425
coroutine_state: bytes
2526
calls: _containers.RepeatedCompositeFieldContainer[_call_pb2.Call]
2627
max_wait: _duration_pb2.Duration
2728
max_results: int
29+
min_results: int
2830
def __init__(
2931
self,
3032
coroutine_state: _Optional[bytes] = ...,
3133
calls: _Optional[_Iterable[_Union[_call_pb2.Call, _Mapping]]] = ...,
3234
max_wait: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ...,
3335
max_results: _Optional[int] = ...,
36+
min_results: _Optional[int] = ...,
3437
) -> None: ...
3538

3639
class PollResult(_message.Message):

src/dispatch/test/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class DispatchServer:
1919
def __init__(
2020
self,
2121
service: dispatch_grpc.DispatchServiceServicer,
22-
hostname="127.0.0.1",
23-
port=0,
22+
hostname: str = "127.0.0.1",
23+
port: int = 0,
2424
):
2525
self._thread_pool = concurrent.futures.thread.ThreadPoolExecutor()
2626
self._server = grpc.server(self._thread_pool)

src/dispatch/test/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ class Poller:
275275
function: str
276276

277277
coroutine_state: bytes
278-
# TODO: support max_wait/max_results
278+
# TODO: support max_wait/min_results/max_results
279279

280280
waiting: dict[DispatchID, call_pb.Call]
281281
results: dict[DispatchID, call_pb.CallResult]

tests/dispatch/test_scheduler.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ async def main():
102102
output = self.start(main)
103103

104104
self.assert_poll_call_functions(
105-
output, ["a", "b", "c", "d", "e", "f", "g", "h"]
105+
output,
106+
["a", "b", "c", "d", "e", "f", "g", "h"],
107+
min_results=1,
108+
max_results=8,
106109
)
107110

108111
def test_resume_after_call(self):
@@ -175,31 +178,39 @@ async def main():
175178

176179
output = self.start(main)
177180
# a, b, c, d are called first. e is not because it depends on a.
178-
calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
181+
calls = self.assert_poll_call_functions(
182+
output, ["a", "b", "c", "d"], min_results=1, max_results=4
183+
)
179184
correlation_ids.update(call.correlation_id for call in calls)
180185
results = [
181186
CallResult.from_value(i, correlation_id=call.correlation_id)
182187
for i, call in enumerate(calls)
183188
]
184189
output = self.resume(main, output, results)
185190
# e is called next
186-
calls = self.assert_poll_call_functions(output, ["e"])
191+
calls = self.assert_poll_call_functions(
192+
output, ["e"], min_results=1, max_results=1
193+
)
187194
correlation_ids.update(call.correlation_id for call in calls)
188195
output = self.resume(
189196
main,
190197
output,
191198
[CallResult.from_value(4, correlation_id=calls[0].correlation_id)],
192199
)
193200
# f is called next
194-
calls = self.assert_poll_call_functions(output, ["f"])
201+
calls = self.assert_poll_call_functions(
202+
output, ["f"], min_results=1, max_results=1
203+
)
195204
correlation_ids.update(call.correlation_id for call in calls)
196205
output = self.resume(
197206
main,
198207
output,
199208
[CallResult.from_value(5, correlation_id=calls[0].correlation_id)],
200209
)
201210
# g, h are called next
202-
calls = self.assert_poll_call_functions(output, ["g", "h"])
211+
calls = self.assert_poll_call_functions(
212+
output, ["g", "h"], min_results=1, max_results=2
213+
)
203214
correlation_ids.update(call.correlation_id for call in calls)
204215
output = self.resume(
205216
main,
@@ -244,7 +255,9 @@ async def main(c_then_d):
244255
)
245256

246257
output = self.start(main, c_then_d)
247-
calls = self.assert_poll_call_functions(output, ["a", "b", "c"])
258+
calls = self.assert_poll_call_functions(
259+
output, ["a", "b", "c"], min_results=1, max_results=3
260+
)
248261

249262
call_a, call_b, call_c = calls
250263
a_result, b_result, c_result = 10, 20, 30
@@ -253,7 +266,7 @@ async def main(c_then_d):
253266
output,
254267
[CallResult.from_value(c_result, correlation_id=call_c.correlation_id)],
255268
)
256-
self.assert_poll_call_functions(output, ["d"])
269+
self.assert_poll_call_functions(output, ["d"], min_results=1, max_results=3)
257270

258271
output = self.resume(
259272
main, output, [], poll_error=RuntimeError("too many calls")
@@ -343,7 +356,9 @@ def assert_empty_poll(self, output: Output):
343356
poll = self.assert_poll(output)
344357
self.assertEqual(len(poll.calls), 0)
345358

346-
def assert_poll_call_functions(self, output: Output, expect: list[str]):
359+
def assert_poll_call_functions(
360+
self, output: Output, expect: list[str], min_results=None, max_results=None
361+
):
347362
poll = self.assert_poll(output)
348363
# Note: we're not testing endpoint/input here.
349364
# Check function names match:
@@ -355,4 +370,8 @@ def assert_poll_call_functions(self, output: Output, expect: list[str]):
355370
len(set(correlation_ids)),
356371
"correlation IDs were not unique",
357372
)
373+
if min_results is not None:
374+
self.assertEqual(min_results, poll.min_results)
375+
if max_results is not None:
376+
self.assertEqual(max_results, poll.max_results)
358377
return poll.calls

0 commit comments

Comments
 (0)