Skip to content

Commit c5a4461

Browse files
authored
Merge pull request #69 from stealthrocket/error-info
Preserve exception details
2 parents 7bcb89d + 2454090 commit c5a4461

File tree

3 files changed

+64
-20
lines changed

3 files changed

+64
-20
lines changed

src/dispatch/proto.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pickle
44
from dataclasses import dataclass
5+
from types import TracebackType
56
from typing import Any
67

78
import google.protobuf.any_pb2
@@ -280,13 +281,20 @@ class Error:
280281
Output.
281282
"""
282283

283-
def __init__(self, status: Status, type: str | None, message: str | None):
284+
def __init__(
285+
self,
286+
status: Status,
287+
type: str | None,
288+
message: str | None,
289+
value: Exception | None = None,
290+
):
284291
"""Create a new Error.
285292
286293
Args:
287294
status: categorization of the error.
288295
type: arbitrary string, used for humans. Optional.
289296
message: arbitrary message. Optional.
297+
value: arbitrary exception from which the error is derived. Optional.
290298
291299
Raises:
292300
ValueError: Neither type or message was provided or status is
@@ -300,6 +308,7 @@ def __init__(self, status: Status, type: str | None, message: str | None):
300308
self.type = type
301309
self.message = message
302310
self.status = status
311+
self.value = value
303312

304313
@classmethod
305314
def from_exception(cls, ex: Exception, status: Status | None = None) -> Error:
@@ -313,18 +322,31 @@ def from_exception(cls, ex: Exception, status: Status | None = None) -> Error:
313322
if status is None:
314323
status = status_for_error(ex)
315324

316-
return Error(status, ex.__class__.__qualname__, str(ex))
325+
return Error(status, ex.__class__.__qualname__, str(ex), ex)
317326

318327
def to_exception(self) -> Exception:
319-
# TODO: use correct error type
320-
return RuntimeError(self.message)
328+
"""Returns an equivalent exception."""
329+
if self.value is not None:
330+
return self.value
331+
332+
g = globals()
333+
try:
334+
assert isinstance(self.type, str)
335+
cls = g[self.type]
336+
assert issubclass(cls, Exception)
337+
except (KeyError, AssertionError):
338+
return RuntimeError(self.message)
339+
else:
340+
return cls(self.message)
321341

322342
@classmethod
323343
def _from_proto(cls, proto: error_pb.Error) -> Error:
324-
return cls(Status.UNSPECIFIED, proto.type, proto.message)
344+
value = pickle.loads(proto.value) if proto.value else None
345+
return cls(Status.UNSPECIFIED, proto.type, proto.message, value)
325346

326347
def _as_proto(self) -> error_pb.Error:
327-
return error_pb.Error(type=self.type, message=self.message)
348+
value = pickle.dumps(self.value) if self.value else None
349+
return error_pb.Error(type=self.type, message=self.message, value=value)
328350

329351

330352
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:

src/dispatch/scheduler.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _run(self, input: Input) -> Output:
222222
future = owner.result
223223
assert future is not None
224224
except (KeyError, AssertionError):
225-
logger.warning("skipping unexpected call result %s", cr)
225+
logger.warning("discarding unexpected call result %s", cr)
226226
continue
227227

228228
logger.debug("dispatching %s to %s", call_result, owner)
@@ -254,7 +254,6 @@ def _run(self, input: Input) -> Output:
254254
coroutine_id=coroutine.id, value=e.value
255255
)
256256
except Exception as e:
257-
raise
258257
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
259258

260259
# Handle coroutines that return or raise.
@@ -266,23 +265,27 @@ def _run(self, input: Input) -> Output:
266265

267266
# If this is the main coroutine, we're done.
268267
if coroutine.parent_id is None:
269-
assert len(state.suspended) == 0
268+
for suspended in state.suspended.values():
269+
suspended.coroutine.close()
270270
if coroutine_result.error is not None:
271271
return Output.error(
272272
Error.from_exception(coroutine_result.error)
273273
)
274274
return Output.value(coroutine_result.value)
275275

276276
# Otherwise, notify the parent of the result.
277-
assert coroutine.parent_id in state.suspended
278-
parent = state.suspended[coroutine.parent_id]
279-
assert parent.result is not None
280-
future = parent.result
281-
future.add(coroutine_result)
282-
if future.ready():
283-
state.ready.insert(0, parent)
284-
del state.suspended[parent.id]
285-
logger.debug("parent %s is now ready", parent)
277+
try:
278+
parent = state.suspended[coroutine.parent_id]
279+
future = parent.result
280+
assert future is not None
281+
except (KeyError, AssertionError):
282+
logger.warning("discarding %s", coroutine_result)
283+
else:
284+
future.add(coroutine_result)
285+
if future.ready():
286+
state.ready.insert(0, parent)
287+
del state.suspended[parent.id]
288+
logger.debug("parent %s is now ready", parent)
286289
continue
287290

288291
# Handle coroutines that yield.
@@ -343,6 +346,11 @@ def _run(self, input: Input) -> Output:
343346
logger.exception("state could not be serialized")
344347
return Output.error(Error.from_exception(e, status=Status.PERMANENT_ERROR))
345348

349+
# Close coroutines before yielding.
350+
for suspended in state.suspended.values():
351+
suspended.coroutine.close()
352+
state.suspended = {}
353+
346354
# Yield to Dispatch.
347355
logger.debug(
348356
"yielding to Dispatch with %d call(s) and %d bytes of state",

tests/dispatch/test_scheduler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ async def call_sequentially(*functions):
3131
return results
3232

3333

34+
@durable
35+
async def raises_error():
36+
raise ValueError("oops")
37+
38+
3439
class TestOneShotScheduler(unittest.TestCase):
3540
def test_main_return(self):
3641
@durable
@@ -43,10 +48,10 @@ async def main():
4348
def test_main_raise(self):
4449
@durable
4550
async def main():
46-
raise RuntimeError("oops")
51+
raise ValueError("oops")
4752

4853
output = self.start(main)
49-
self.assert_exit_result_error(output, RuntimeError, "oops")
54+
self.assert_exit_result_error(output, ValueError, "oops")
5055

5156
def test_main_args(self):
5257
@durable
@@ -215,6 +220,14 @@ async def main():
215220

216221
self.assertEqual(len(correlation_ids), 8)
217222

223+
def test_raise_indirect(self):
224+
@durable
225+
async def main():
226+
return await gather(call_one("a"), raises_error())
227+
228+
output = self.start(main)
229+
self.assert_exit_result_error(output, ValueError, "oops")
230+
218231
def start(self, main: Callable, *args: Any, **kwargs: Any) -> Output:
219232
input = Input.from_input_arguments(main.__qualname__, *args, **kwargs)
220233
return OneShotScheduler(main).run(input)
@@ -258,6 +271,7 @@ def assert_exit_result_error(
258271
self.assertEqual(error.__class__, expect)
259272
if message is not None:
260273
self.assertEqual(str(error), message)
274+
return error
261275

262276
def assert_poll(self, output: Output) -> poll_pb.Poll:
263277
response = output._message

0 commit comments

Comments
 (0)