Skip to content

Commit 84c5a34

Browse files
committed
Use the new container for pickled values
1 parent 5221d30 commit 84c5a34

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

src/dispatch/proto.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
import tblib # type: ignore[import-untyped]
1313
from google.protobuf import descriptor_pool, duration_pb2, message_factory
1414

15-
from dispatch.error import IncompatibleStateError
15+
from dispatch.error import IncompatibleStateError, InvalidArgumentError
1616
from dispatch.id import DispatchID
17+
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
1718
from dispatch.sdk.v1 import call_pb2 as call_pb
1819
from dispatch.sdk.v1 import error_pb2 as error_pb
1920
from dispatch.sdk.v1 import exit_pb2 as exit_pb
@@ -78,16 +79,7 @@ def __init__(self, req: function_pb.RunRequest):
7879

7980
self._has_input = req.HasField("input")
8081
if self._has_input:
81-
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
82-
input_pb = google.protobuf.wrappers_pb2.BytesValue()
83-
req.input.Unpack(input_pb)
84-
input_bytes = input_pb.value
85-
try:
86-
self._input = pickle.loads(input_bytes)
87-
except Exception as e:
88-
self._input = input_bytes
89-
else:
90-
self._input = _pb_any_unpack(req.input)
82+
self._input = _pb_any_unpack(req.input)
9183
else:
9284
if req.poll_result.coroutine_state:
9385
raise IncompatibleStateError # coroutine_state is deprecated
@@ -450,21 +442,44 @@ def _as_proto(self) -> error_pb.Error:
450442

451443

452444
def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
453-
any.Unpack(value_bytes := google.protobuf.wrappers_pb2.BytesValue())
454-
return pickle.loads(value_bytes.value)
455-
456-
457-
def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
458-
value_bytes = pickle.dumps(x)
459-
pb_bytes = google.protobuf.wrappers_pb2.BytesValue(value=value_bytes)
460-
pb_any = google.protobuf.any_pb2.Any()
461-
pb_any.Pack(pb_bytes)
462-
return pb_any
463-
445+
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
446+
p = pickled_pb.Pickled()
447+
any.Unpack(p)
448+
return pickle.loads(p.pickled_value)
449+
450+
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): # legacy container
451+
b = google.protobuf.wrappers_pb2.BytesValue()
452+
any.Unpack(b)
453+
return pickle.loads(b.value)
454+
455+
raise InvalidArgumentError("unsupported pickled value container")
456+
457+
458+
def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
459+
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
460+
any = google.protobuf.any_pb2.Any()
461+
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
462+
return any
463+
464+
465+
def _pb_any_unpack(any: google.protobuf.any_pb2.Any) -> Any:
466+
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
467+
p = pickled_pb.Pickled()
468+
any.Unpack(p)
469+
return pickle.loads(p.pickled_value)
470+
471+
elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
472+
b = google.protobuf.wrappers_pb2.BytesValue()
473+
any.Unpack(b)
474+
try:
475+
# Assume it's the legacy container for pickled values.
476+
return pickle.loads(b.value)
477+
except Exception as e:
478+
# Otherwise, return the literal bytes.
479+
return b.value
464480

465-
def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any:
466481
pool = descriptor_pool.Default()
467-
msg_descriptor = pool.FindMessageTypeByName(x.TypeName())
482+
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
468483
proto = message_factory.GetMessageClass(msg_descriptor)()
469-
x.Unpack(proto)
484+
any.Unpack(proto)
470485
return proto

0 commit comments

Comments
 (0)