Skip to content

Commit 49a1c2f

Browse files
authored
Merge pull request #120 from stealthrocket/verification-key-str
Improve UX around verification keys
2 parents 5139a05 + 76ce71e commit 49a1c2f

File tree

2 files changed

+112
-18
lines changed

2 files changed

+112
-18
lines changed

src/dispatch/fastapi.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
self,
5454
app: fastapi.FastAPI,
5555
endpoint: str | None = None,
56-
verification_key: Ed25519PublicKey | None = None,
56+
verification_key: Ed25519PublicKey | str | bytes | None = None,
5757
api_key: str | None = None,
5858
api_url: str | None = None,
5959
):
@@ -70,7 +70,7 @@ def __init__(
7070
7171
verification_key: Key to use when verifying signed requests. Uses
7272
the value of the DISPATCH_VERIFICATION_KEY environment variable
73-
by default. The environment variable is expected to carry an
73+
if omitted. The environment variable is expected to carry an
7474
Ed25519 public key in base64 or PEM format.
7575
If not set, request signature verification is disabled (a warning
7676
will be logged by the constructor).
@@ -99,21 +99,6 @@ def __init__(
9999
"missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable"
100100
)
101101

102-
if not verification_key:
103-
try:
104-
verification_key_raw = os.environ["DISPATCH_VERIFICATION_KEY"]
105-
except KeyError:
106-
pass
107-
else:
108-
# Be forgiving when accepting keys in PEM format.
109-
verification_key_raw = verification_key_raw.replace("\\n", "\n")
110-
try:
111-
verification_key = public_key_from_pem(verification_key_raw)
112-
except ValueError:
113-
verification_key = public_key_from_bytes(
114-
base64.b64decode(verification_key_raw)
115-
)
116-
117102
logger.info("configuring Dispatch endpoint %s", endpoint)
118103

119104
parsed_url = urlparse(endpoint)
@@ -122,6 +107,7 @@ def __init__(
122107
f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)"
123108
)
124109

110+
verification_key = parse_verification_key(verification_key)
125111
if verification_key:
126112
base64_key = base64.b64encode(verification_key.public_bytes_raw()).decode()
127113
logger.info("verifying request signatures using key %s", base64_key)
@@ -137,6 +123,40 @@ def __init__(
137123
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
138124

139125

126+
def parse_verification_key(
127+
verification_key: Ed25519PublicKey | str | bytes | None,
128+
) -> Ed25519PublicKey | None:
129+
if isinstance(verification_key, Ed25519PublicKey):
130+
return verification_key
131+
132+
from_env = False
133+
if not verification_key:
134+
try:
135+
verification_key = os.environ["DISPATCH_VERIFICATION_KEY"]
136+
except KeyError:
137+
return None
138+
from_env = True
139+
140+
if isinstance(verification_key, bytes):
141+
verification_key = verification_key.decode()
142+
143+
# Be forgiving when accepting keys in PEM format, which may span
144+
# multiple lines. Users attempting to pass a PEM key via an environment
145+
# variable may accidentally include literal "\n" bytes rather than a
146+
# newline char (0xA).
147+
try:
148+
return public_key_from_pem(verification_key.replace("\\n", "\n"))
149+
except ValueError:
150+
pass
151+
152+
try:
153+
return public_key_from_bytes(base64.b64decode(verification_key.encode()))
154+
except ValueError:
155+
if from_env:
156+
raise ValueError(f"invalid DISPATCH_VERIFICATION_KEY '{verification_key}'")
157+
raise ValueError(f"invalid verification key '{verification_key}'")
158+
159+
140160
class _ConnectResponse(fastapi.Response):
141161
media_type = "application/grpc+proto"
142162

tests/test_fastapi.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import os
23
import pickle
34
import unittest
@@ -8,17 +9,25 @@
89
import google.protobuf.any_pb2
910
import google.protobuf.wrappers_pb2
1011
import httpx
12+
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
1113
from fastapi.testclient import TestClient
1214

1315
from dispatch.experimental.durable.registry import clear_functions
14-
from dispatch.fastapi import Dispatch
16+
from dispatch.fastapi import Dispatch, parse_verification_key
1517
from dispatch.function import Arguments, Error, Function, Input, Output
1618
from dispatch.proto import _any_unpickle as any_unpickle
1719
from dispatch.sdk.v1 import call_pb2 as call_pb
1820
from dispatch.sdk.v1 import function_pb2 as function_pb
21+
from dispatch.signature import public_key_from_pem
1922
from dispatch.status import Status
2023
from dispatch.test import EndpointClient
2124

25+
public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
26+
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
27+
public_key = public_key_from_pem(public_key_pem)
28+
public_key_bytes = public_key.public_bytes_raw()
29+
public_key_b64 = base64.b64encode(public_key_bytes)
30+
2231

2332
def create_dispatch_instance(app, endpoint):
2433
return Dispatch(
@@ -98,6 +107,71 @@ def my_function(input: Input) -> Output:
98107

99108
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
100109

110+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem})
111+
def test_parse_verification_key_env_pem_str(self):
112+
verification_key = parse_verification_key(None)
113+
self.assertIsInstance(verification_key, Ed25519PublicKey)
114+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
115+
116+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem2})
117+
def test_parse_verification_key_env_pem_escaped_newline_str(self):
118+
verification_key = parse_verification_key(None)
119+
self.assertIsInstance(verification_key, Ed25519PublicKey)
120+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
121+
122+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_b64.decode()})
123+
def test_parse_verification_key_env_b64_str(self):
124+
verification_key = parse_verification_key(None)
125+
self.assertIsInstance(verification_key, Ed25519PublicKey)
126+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
127+
128+
def test_parse_verification_key_none(self):
129+
# The verification key is optional. Both Dispatch(verification_key=...) and
130+
# DISPATCH_VERIFICATION_KEY may be omitted/None.
131+
verification_key = parse_verification_key(None)
132+
self.assertIsNone(verification_key)
133+
134+
def test_parse_verification_key_ed25519publickey(self):
135+
verification_key = parse_verification_key(public_key)
136+
self.assertIsInstance(verification_key, Ed25519PublicKey)
137+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
138+
139+
def test_parse_verification_key_pem_str(self):
140+
verification_key = parse_verification_key(public_key_pem)
141+
self.assertIsInstance(verification_key, Ed25519PublicKey)
142+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
143+
144+
def test_parse_verification_key_pem_escaped_newline_str(self):
145+
verification_key = parse_verification_key(public_key_pem2)
146+
self.assertIsInstance(verification_key, Ed25519PublicKey)
147+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
148+
149+
def test_parse_verification_key_pem_bytes(self):
150+
verification_key = parse_verification_key(public_key_pem.encode())
151+
self.assertIsInstance(verification_key, Ed25519PublicKey)
152+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
153+
154+
def test_parse_verification_key_b64_str(self):
155+
verification_key = parse_verification_key(public_key_b64.decode())
156+
self.assertIsInstance(verification_key, Ed25519PublicKey)
157+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
158+
159+
def test_parse_verification_key_b64_bytes(self):
160+
verification_key = parse_verification_key(public_key_b64)
161+
self.assertIsInstance(verification_key, Ed25519PublicKey)
162+
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)
163+
164+
def test_parse_verification_key_invalid(self):
165+
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
166+
parse_verification_key("foo")
167+
168+
@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"})
169+
def test_parse_verification_key_invalid_env(self):
170+
with self.assertRaisesRegex(
171+
ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'"
172+
):
173+
parse_verification_key(None)
174+
101175

102176
def response_output(resp: function_pb.RunResponse) -> Any:
103177
return any_unpickle(resp.exit.result.output)

0 commit comments

Comments
 (0)