Skip to content

Commit 50377d7

Browse files
fix tests
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
1 parent bd7868b commit 50377d7

File tree

4 files changed

+148
-81
lines changed

4 files changed

+148
-81
lines changed

src/dispatch/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,27 @@ def run(init: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
9191

9292

9393
@contextmanager
94-
def serve(address: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
94+
def serve(
95+
address: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000"),
96+
poll_interval: float = 0.5,
97+
):
9598
"""Returns a context manager managing the operation of a Disaptch server
9699
running on the given address. The server is initialized before the context
97100
manager yields, then runs forever until the the program is interrupted.
98101
99102
Args:
100103
address: The address to bind the server to. Defaults to the value of the
101-
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if it
102-
wasn't set.
104+
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if
105+
it wasn't set.
106+
107+
poll_interval: Poll for shutdown every poll_interval seconds.
108+
Defaults to 0.5 seconds.
103109
"""
104110
parsed_url = urlsplit("//" + address)
105111
server_address = (parsed_url.hostname or "", parsed_url.port or 0)
106112
server = ThreadingHTTPServer(server_address, Dispatch(default_registry()))
107113
try:
108114
yield server
109-
server.serve_forever()
115+
server.serve_forever(poll_interval=poll_interval)
110116
finally:
111117
server.server_close()

tests/dispatch/signature/test_signature.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import base64
2+
import os
13
import unittest
24
from datetime import datetime, timedelta
5+
from unittest import mock
36

47
from http_message_signatures import HTTPMessageSigner
58
from http_message_signatures._algorithms import ED25519
69

710
from dispatch.signature import (
811
CaseInsensitiveDict,
12+
Ed25519PublicKey,
913
InvalidSignature,
1014
Request,
15+
parse_verification_key,
1116
sign_request,
1217
verify_request,
1318
)
@@ -33,6 +38,18 @@
3338
"""
3439
)
3540

41+
public_key2_pem = """-----BEGIN PUBLIC KEY-----
42+
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
43+
-----END PUBLIC KEY-----
44+
"""
45+
public_key2_pem2 = """-----BEGIN PUBLIC KEY-----
46+
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
47+
-----END PUBLIC KEY-----
48+
"""
49+
public_key2 = public_key_from_pem(public_key2_pem)
50+
public_key2_bytes = public_key2.public_bytes_raw()
51+
public_key2_b64 = base64.b64encode(public_key2_bytes)
52+
3653

3754
class TestSignature(unittest.TestCase):
3855
def setUp(self):
@@ -125,3 +142,70 @@ def test_known_signature(self):
125142
ValueError, "public key 'test-key-ed25519' not available"
126143
):
127144
verify_request(request, public_key, max_age=timedelta(weeks=9000))
145+
146+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem})
147+
def test_parse_verification_key_env_pem_str(self):
148+
verification_key = parse_verification_key(None)
149+
self.assertIsInstance(verification_key, Ed25519PublicKey)
150+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
151+
152+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_pem2})
153+
def test_parse_verification_key_env_pem_escaped_newline_str(self):
154+
verification_key = parse_verification_key(None)
155+
self.assertIsInstance(verification_key, Ed25519PublicKey)
156+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
157+
158+
@mock.patch.dict(
159+
os.environ, {"DISPATCH_VERIFICATION_KEY": public_key2_b64.decode()}
160+
)
161+
def test_parse_verification_key_env_b64_str(self):
162+
verification_key = parse_verification_key(None)
163+
self.assertIsInstance(verification_key, Ed25519PublicKey)
164+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
165+
166+
def test_parse_verification_key_none(self):
167+
# The verification key is optional. Both Dispatch(verification_key=...) and
168+
# DISPATCH_VERIFICATION_KEY may be omitted/None.
169+
verification_key = parse_verification_key(None)
170+
self.assertIsNone(verification_key)
171+
172+
def test_parse_verification_key_ed25519publickey(self):
173+
verification_key = parse_verification_key(public_key2)
174+
self.assertIsInstance(verification_key, Ed25519PublicKey)
175+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
176+
177+
def test_parse_verification_key_pem_str(self):
178+
verification_key = parse_verification_key(public_key2_pem)
179+
self.assertIsInstance(verification_key, Ed25519PublicKey)
180+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
181+
182+
def test_parse_verification_key_pem_escaped_newline_str(self):
183+
verification_key = parse_verification_key(public_key2_pem2)
184+
self.assertIsInstance(verification_key, Ed25519PublicKey)
185+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
186+
187+
def test_parse_verification_key_pem_bytes(self):
188+
verification_key = parse_verification_key(public_key2_pem.encode())
189+
self.assertIsInstance(verification_key, Ed25519PublicKey)
190+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
191+
192+
def test_parse_verification_key_b64_str(self):
193+
verification_key = parse_verification_key(public_key2_b64.decode())
194+
self.assertIsInstance(verification_key, Ed25519PublicKey)
195+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
196+
197+
def test_parse_verification_key_b64_bytes(self):
198+
verification_key = parse_verification_key(public_key2_b64)
199+
self.assertIsInstance(verification_key, Ed25519PublicKey)
200+
self.assertEqual(verification_key.public_bytes_raw(), public_key2_bytes)
201+
202+
def test_parse_verification_key_invalid(self):
203+
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
204+
parse_verification_key("foo")
205+
206+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"})
207+
def test_parse_verification_key_invalid_env(self):
208+
with self.assertRaisesRegex(
209+
ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'"
210+
):
211+
parse_verification_key(None)

tests/test_fastapi.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,6 @@
2323
from dispatch.status import Status
2424
from dispatch.test import EndpointClient
2525

26-
public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
27-
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
28-
public_key = public_key_from_pem(public_key_pem)
29-
public_key_bytes = public_key.public_bytes_raw()
30-
public_key_b64 = base64.b64encode(public_key_bytes)
31-
3226

3327
def create_dispatch_instance(app, endpoint):
3428
return Dispatch(
@@ -107,71 +101,6 @@ def my_function(input: Input) -> Output:
107101

108102
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
109103

110-
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem})
111-
def test_parse_verification_key_env_pem_str(self):
112-
verification_key = parse_verification_key(None)
113-
self.assertIsInstance(verification_key, Ed25519PublicKey)
114-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
115-
116-
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem2})
117-
def test_parse_verification_key_env_pem_escaped_newline_str(self):
118-
verification_key = parse_verification_key(None)
119-
self.assertIsInstance(verification_key, Ed25519PublicKey)
120-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
121-
122-
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_b64.decode()})
123-
def test_parse_verification_key_env_b64_str(self):
124-
verification_key = parse_verification_key(None)
125-
self.assertIsInstance(verification_key, Ed25519PublicKey)
126-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
127-
128-
def test_parse_verification_key_none(self):
129-
# The verification key is optional. Both Dispatch(verification_key=...) and
130-
# DISPATCH_VERIFICATION_KEY may be omitted/None.
131-
verification_key = parse_verification_key(None)
132-
self.assertIsNone(verification_key)
133-
134-
def test_parse_verification_key_ed25519publickey(self):
135-
verification_key = parse_verification_key(public_key)
136-
self.assertIsInstance(verification_key, Ed25519PublicKey)
137-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
138-
139-
def test_parse_verification_key_pem_str(self):
140-
verification_key = parse_verification_key(public_key_pem)
141-
self.assertIsInstance(verification_key, Ed25519PublicKey)
142-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
143-
144-
def test_parse_verification_key_pem_escaped_newline_str(self):
145-
verification_key = parse_verification_key(public_key_pem2)
146-
self.assertIsInstance(verification_key, Ed25519PublicKey)
147-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
148-
149-
def test_parse_verification_key_pem_bytes(self):
150-
verification_key = parse_verification_key(public_key_pem.encode())
151-
self.assertIsInstance(verification_key, Ed25519PublicKey)
152-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
153-
154-
def test_parse_verification_key_b64_str(self):
155-
verification_key = parse_verification_key(public_key_b64.decode())
156-
self.assertIsInstance(verification_key, Ed25519PublicKey)
157-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
158-
159-
def test_parse_verification_key_b64_bytes(self):
160-
verification_key = parse_verification_key(public_key_b64)
161-
self.assertIsInstance(verification_key, Ed25519PublicKey)
162-
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
163-
164-
def test_parse_verification_key_invalid(self):
165-
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
166-
parse_verification_key("foo")
167-
168-
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"})
169-
def test_parse_verification_key_invalid_env(self):
170-
with self.assertRaisesRegex(
171-
ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'"
172-
):
173-
parse_verification_key(None)
174-
175104

176105
def response_output(resp: function_pb.RunResponse) -> Any:
177106
return any_unpickle(resp.exit.result.output)

tests/test_http.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
import struct
55
import threading
66
import unittest
7+
from http.server import HTTPServer
78
from typing import Any
89
from unittest import mock
910

1011
import fastapi
1112
import google.protobuf.any_pb2
1213
import google.protobuf.wrappers_pb2
1314
import httpx
14-
from http.server import HTTPServer
1515
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
1616

1717
from dispatch.experimental.durable.registry import clear_functions
18-
from dispatch.http import Dispatch
1918
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
19+
from dispatch.http import Dispatch
2020
from dispatch.proto import _any_unpickle as any_unpickle
2121
from dispatch.sdk.v1 import call_pb2 as call_pb
2222
from dispatch.sdk.v1 import function_pb2 as function_pb
@@ -30,6 +30,8 @@
3030
public_key_bytes = public_key.public_bytes_raw()
3131
public_key_b64 = base64.b64encode(public_key_bytes)
3232

33+
from datetime import datetime
34+
3335

3436
def create_dispatch_instance(endpoint: str):
3537
return Dispatch(
@@ -43,11 +45,14 @@ def create_dispatch_instance(endpoint: str):
4345

4446
class TestHTTP(unittest.TestCase):
4547
def setUp(self):
46-
self.server_address = ('127.0.0.1', 9999)
48+
self.server_address = ("127.0.0.1", 9999)
4749
self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}"
50+
self.dispatch = create_dispatch_instance(self.endpoint)
4851
self.client = httpx.Client(timeout=1.0)
49-
self.server = HTTPServer(self.server_address, create_dispatch_instance(self.endpoint))
50-
self.thread = threading.Thread(target=self.server.serve_forever)
52+
self.server = HTTPServer(self.server_address, self.dispatch)
53+
self.thread = threading.Thread(
54+
target=lambda: self.server.serve_forever(poll_interval=0.05)
55+
)
5156
self.thread.start()
5257

5358
def tearDown(self):
@@ -56,7 +61,50 @@ def tearDown(self):
5661
self.client.close()
5762
self.server.server_close()
5863

59-
def test_Dispatch_defaults(self):
64+
def test_content_length_missing(self):
6065
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
6166
body = resp.read()
6267
self.assertEqual(resp.status_code, 400)
68+
self.assertEqual(
69+
body, b'{"code":"invalid_argument","message":"content length is required"}'
70+
)
71+
72+
def test_content_length_too_large(self):
73+
resp = self.client.post(
74+
f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run",
75+
data=b"a" * 16_000_001,
76+
)
77+
body = resp.read()
78+
self.assertEqual(resp.status_code, 400)
79+
self.assertEqual(
80+
body, b'{"code":"invalid_argument","message":"content length is too large"}'
81+
)
82+
83+
def test_simple_request(self):
84+
@self.dispatch.registry.primitive_function
85+
def my_function(input: Input) -> Output:
86+
return Output.value(
87+
f"You told me: '{input.input}' ({len(input.input)} characters)"
88+
)
89+
90+
client = EndpointClient.from_url(self.endpoint)
91+
92+
pickled = pickle.dumps("Hello World!")
93+
input_any = google.protobuf.any_pb2.Any()
94+
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
95+
96+
req = function_pb.RunRequest(
97+
function=my_function.name,
98+
input=input_any,
99+
)
100+
101+
resp = client.run(req)
102+
103+
self.assertIsInstance(resp, function_pb.RunResponse)
104+
105+
resp.exit.result.output.Unpack(
106+
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
107+
)
108+
output = pickle.loads(output_bytes.value)
109+
110+
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

0 commit comments

Comments
 (0)