|
12 | 12 | import tblib # type: ignore[import-untyped] |
13 | 13 | from google.protobuf import descriptor_pool, duration_pb2, message_factory |
14 | 14 |
|
15 | | -from dispatch.error import IncompatibleStateError |
| 15 | +from dispatch.error import IncompatibleStateError, InvalidArgumentError |
16 | 16 | from dispatch.id import DispatchID |
| 17 | +from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb |
17 | 18 | from dispatch.sdk.v1 import call_pb2 as call_pb |
18 | 19 | from dispatch.sdk.v1 import error_pb2 as error_pb |
19 | 20 | from dispatch.sdk.v1 import exit_pb2 as exit_pb |
@@ -78,16 +79,7 @@ def __init__(self, req: function_pb.RunRequest): |
78 | 79 |
|
79 | 80 | self._has_input = req.HasField("input") |
80 | 81 | if self._has_input: |
81 | | - if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): |
82 | | - input_pb = google.protobuf.wrappers_pb2.BytesValue() |
83 | | - req.input.Unpack(input_pb) |
84 | | - input_bytes = input_pb.value |
85 | | - try: |
86 | | - self._input = pickle.loads(input_bytes) |
87 | | - except Exception as e: |
88 | | - self._input = input_bytes |
89 | | - else: |
90 | | - self._input = _pb_any_unpack(req.input) |
| 82 | + self._input = _pb_any_unpack(req.input) |
91 | 83 | else: |
92 | 84 | if req.poll_result.coroutine_state: |
93 | 85 | raise IncompatibleStateError # coroutine_state is deprecated |
@@ -450,21 +442,44 @@ def _as_proto(self) -> error_pb.Error: |
450 | 442 |
|
451 | 443 |
|
452 | 444 | def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any: |
453 | | - any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue()) |
454 | | - return pickle.loads(value_bytes.value) |
455 | | - |
456 | | - |
457 | | -def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any: |
458 | | - value_bytes = pickle.dumps(x) |
459 | | - pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes) |
460 | | - pb_any = google.protobuf.any_pb2.Any() |
461 | | - pb_any.Pack(pb_bytes) |
462 | | - return pb_any |
463 | | - |
| 445 | + if any.Is(pickled_pb.Pickled.DESCRIPTOR): |
| 446 | + p = pickled_pb.Pickled() |
| 447 | + any.Unpack(p) |
| 448 | + return pickle.loads(p.pickled_value) |
| 449 | + |
| 450 | + elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container |
| 451 | + b = google.protobuf.wrappers_pb2.BytesValue() |
| 452 | + any.Unpack(b) |
| 453 | + return pickle.loads(b.value) |
| 454 | + |
| 455 | + raise InvalidArgumentError("unsupported pickled value container") |
| 456 | + |
| 457 | + |
| 458 | +def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any: |
| 459 | + p = pickled_pb.Pickled(pickled_value=pickle.dumps(value)) |
| 460 | + any = google.protobuf.any_pb2.Any() |
| 461 | + any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/") |
| 462 | + return any |
| 463 | + |
| 464 | + |
| 465 | +def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any: |
| 466 | + if any.Is(pickled_pb.Pickled.DESCRIPTOR): |
| 467 | + p = pickled_pb.Pickled() |
| 468 | + any.Unpack(p) |
| 469 | + return pickle.loads(p.pickled_value) |
| 470 | + |
| 471 | + elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): |
| 472 | + b = google.protobuf.wrappers_pb2.BytesValue() |
| 473 | + any.Unpack(b) |
| 474 | + try: |
| 475 | + # Assume it's the legacy container for pickled values. |
| 476 | + return pickle.loads(b.value) |
| 477 | + except Exception as e: |
| 478 | + # Otherwise, return the literal bytes. |
| 479 | + return b.value |
464 | 480 |
|
465 | | -def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any: |
466 | 481 | pool = descriptor_pool.Default() |
467 | | - msg_descriptor = pool.FindMessageTypeByName(x.TypeName()) |
| 482 | + msg_descriptor = pool.FindMessageTypeByName(any.TypeName()) |
468 | 483 | proto = message_factory.GetMessageClass(msg_descriptor)() |
469 | | - x.Unpack(proto) |
| 484 | + any.Unpack(proto) |
470 | 485 | return proto |
0 commit comments