Skip to content

Commit b7acacf

Browse files
committed
Fix a few tests
1 parent 84c5a34 commit b7acacf

File tree

4 files changed

+15
-32
lines changed

4 files changed

+15
-32
lines changed

src/dispatch/proto.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,10 @@ def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
452452
any.Unpack(b)
453453
return pickle.loads(b.value)
454454

455-
raise InvalidArgumentError("unsupported pickled value container")
455+
elif not any.type_url and not any.value:
456+
return None
457+
458+
raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")
456459

457460

458461
def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:

tests/test_fastapi.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from dispatch.fastapi import Dispatch
2222
from dispatch.function import Arguments, Error, Function, Input, Output
2323
from dispatch.proto import _any_unpickle as any_unpickle
24+
from dispatch.proto import _pb_any_pickle as any_pickle
2425
from dispatch.sdk.v1 import call_pb2 as call_pb
2526
from dispatch.sdk.v1 import function_pb2 as function_pb
2627
from dispatch.signature import (
@@ -91,23 +92,17 @@ async def my_function(input: Input) -> Output:
9192
)
9293

9394
client = create_endpoint_client(app)
94-
pickled = pickle.dumps("Hello World!")
95-
input_any = google.protobuf.any_pb2.Any()
96-
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
9795

9896
req = function_pb.RunRequest(
9997
function=my_function.name,
100-
input=input_any,
98+
input=any_pickle("Hello World!"),
10199
)
102100

103101
resp = client.run(req)
104102

105103
self.assertIsInstance(resp, function_pb.RunResponse)
106104

107-
resp.exit.result.output.Unpack(
108-
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
109-
)
110-
output = pickle.loads(output_bytes.value)
105+
output = any_unpickle(resp.exit.result.output)
111106

112107
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
113108

@@ -229,10 +224,8 @@ def execute(
229224
req = function_pb.RunRequest(function=func.name)
230225

231226
if input is not None:
232-
input_bytes = pickle.dumps(input)
233-
input_any = google.protobuf.any_pb2.Any()
234-
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes))
235-
req.input.CopyFrom(input_any)
227+
any = any_pickle(input)
228+
req.input.CopyFrom(any)
236229
if state is not None:
237230
req.poll_result.coroutine_state = state
238231
if calls is not None:

tests/test_flask.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from dispatch.flask import Dispatch
2020
from dispatch.function import Arguments, Error, Function, Input, Output
2121
from dispatch.proto import _any_unpickle as any_unpickle
22+
from dispatch.proto import _pb_any_pickle as any_pickle
2223
from dispatch.sdk.v1 import call_pb2 as call_pb
2324
from dispatch.sdk.v1 import function_pb2 as function_pb
2425
from dispatch.signature import (
@@ -56,23 +57,16 @@ async def my_function(input: Input) -> Output:
5657
)
5758

5859
client = create_endpoint_client(app)
59-
pickled = pickle.dumps("Hello World!")
60-
input_any = google.protobuf.any_pb2.Any()
61-
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
6260

6361
req = function_pb.RunRequest(
64-
function=my_function.name,
65-
input=input_any,
62+
function=my_function.name, input=any_pickle("Hello World!")
6663
)
6764

6865
resp = client.run(req)
6966

7067
self.assertIsInstance(resp, function_pb.RunResponse)
7168

72-
resp.exit.result.output.Unpack(
73-
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
74-
)
75-
output = pickle.loads(output_bytes.value)
69+
output = any_unpickle(resp.exit.result.output)
7670

7771
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
7872

tests/test_http.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
2020
from dispatch.http import Dispatch
2121
from dispatch.proto import _any_unpickle as any_unpickle
22+
from dispatch.proto import _pb_any_pickle as any_pickle
2223
from dispatch.sdk.v1 import call_pb2 as call_pb
2324
from dispatch.sdk.v1 import function_pb2 as function_pb
2425
from dispatch.signature import parse_verification_key, public_key_from_pem
@@ -91,22 +92,14 @@ async def my_function(input: Input) -> Output:
9192
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
9293
client = EndpointClient(http_client)
9394

94-
pickled = pickle.dumps("Hello World!")
95-
input_any = google.protobuf.any_pb2.Any()
96-
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
97-
9895
req = function_pb.RunRequest(
99-
function=my_function.name,
100-
input=input_any,
96+
function=my_function.name, input=any_pickle("Hello World!")
10197
)
10298

10399
resp = client.run(req)
104100

105101
self.assertIsInstance(resp, function_pb.RunResponse)
106102

107-
resp.exit.result.output.Unpack(
108-
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
109-
)
110-
output = pickle.loads(output_bytes.value)
103+
output = any_unpickle(resp.exit.result.output)
111104

112105
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

0 commit comments

Comments
 (0)