Skip to content

Commit 4624d5f

Browse files
Merge pull request #77 from stealthrocket/fix-fastapi-error-responses
FastAPI: return error responses formatted for the connectrpc protocol
2 parents 129293e + f9a8300 commit 4624d5f

File tree

4 files changed

+68
-24
lines changed

4 files changed

+68
-24
lines changed

src/dispatch/fastapi.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -137,29 +137,47 @@ def __init__(
137137
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
138138

139139

140-
class _GRPCResponse(fastapi.Response):
140+
class _ConnectResponse(fastapi.Response):
141141
media_type = "application/grpc+proto"
142142

143143

144+
class _ConnectError(fastapi.HTTPException):
145+
__slots__ = ("status", "code", "message")
146+
147+
def __init__(self, status, code, message):
148+
super().__init__(status)
149+
self.status = status
150+
self.code = code
151+
self.message = message
152+
153+
144154
def _new_app(function_registry: Dispatch, verification_key: Ed25519PublicKey | None):
145155
app = fastapi.FastAPI()
146156

157+
@app.exception_handler(_ConnectError)
158+
async def on_error(request: fastapi.Request, exc: _ConnectError):
159+
# https://connectrpc.com/docs/protocol/#error-end-stream
160+
return fastapi.responses.JSONResponse(
161+
status_code=exc.status, content={"code": exc.code, "message": exc.message}
162+
)
163+
147164
@app.post(
148165
# The endpoint for execution is hardcoded at the moment. If the service
149166
# gains more endpoints, this should be turned into a dynamic dispatch
150167
# like the official gRPC server does.
151168
"/Run",
152-
response_class=_GRPCResponse,
169+
response_class=_ConnectResponse,
153170
)
154171
async def execute(request: fastapi.Request):
155172
# Raw request body bytes are only available through the underlying
156173
# starlette Request object's body method, which returns an awaitable,
157174
# forcing execute() to be async.
158175
data: bytes = await request.body()
159-
160176
logger.debug("handling run request with %d byte body", len(data))
161177

162-
if verification_key is not None:
178+
if verification_key is None:
179+
logger.debug("skipping request signature verification")
180+
else:
163181
signed_request = Request(
164182
method=request.method,
165183
url=str(request.url),
@@ -169,29 +187,28 @@ async def execute(request: fastapi.Request):
169187
max_age = timedelta(minutes=5)
170188
try:
171189
verify_request(signed_request, verification_key, max_age)
172-
except (InvalidSignature, ValueError):
173-
logger.error("failed to verify request signature", exc_info=True)
174-
raise fastapi.HTTPException(
175-
status_code=403, detail="request signature is invalid"
176-
)
177-
else:
178-
logger.debug("skipping request signature verification")
190+
except ValueError as e:
191+
raise _ConnectError(401, "unauthenticated", str(e))
192+
except InvalidSignature as e:
193+
# The http_message_signatures package sometimes wraps does not
194+
# attach a message to the exception, so we set a default to
195+
# have some context about the reason for the error.
196+
message = str(e) or "invalid signature"
197+
raise _ConnectError(403, "permission_denied", message)
179198

180199
req = function_pb.RunRequest.FromString(data)
181-
182200
if not req.function:
183-
raise fastapi.HTTPException(status_code=400, detail="function is required")
201+
raise _ConnectError(400, "invalid_argument", "function is required")
184202

185203
try:
186204
func = function_registry._functions[req.function]
187205
except KeyError:
188206
logger.debug("function '%s' not found", req.function)
189-
raise fastapi.HTTPException(
190-
status_code=404, detail=f"Function '{req.function}' does not exist"
207+
raise _ConnectError(
208+
404, "not_found", f"function '{req.function}' does not exist"
191209
)
192210

193211
input = Input(req)
194-
195212
logger.info("running function '%s'", req.function)
196213
try:
197214
output = func._primitive_call(input)
@@ -203,8 +220,8 @@ async def execute(request: fastapi.Request):
203220
# so indicates a problem, and we return a 500 rather than attempt
204221
# to catch and categorize the error here.
205222
logger.error("function '%s' fatal error", req.function, exc_info=True)
206-
raise fastapi.HTTPException(
207-
status_code=500, detail=f"function '{req.function}' fatal error"
223+
raise _ConnectError(
224+
500, "internal", f"function '{req.function}' fatal error"
208225
)
209226
else:
210227
response = output._message
@@ -241,7 +258,6 @@ async def execute(request: fastapi.Request):
241258
)
242259

243260
logger.debug("finished handling run request with status %s", status.name)
244-
245261
return fastapi.Response(content=response.SerializeToString())
246262

247263
return app

src/dispatch/signature/digest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import hmac
33

44
import http_sfv
5+
from http_message_signatures import InvalidSignature
56

67

78
def generate_content_digest(body: str | bytes) -> str:
@@ -34,7 +35,9 @@ def verify_content_digest(digest_header: str | bytes, body: str | bytes):
3435
digest = parsed_header["sha-256"].value
3536
expect_digest = hashlib.sha256(body).digest()
3637
else:
37-
raise ValueError("missing content digest")
38+
raise ValueError("missing content digest in http request header")
3839

3940
if not hmac.compare_digest(digest, expect_digest):
40-
raise ValueError("unexpected content digest")
41+
raise InvalidSignature(
42+
"digest of the request body does not match the Content-Digest header"
43+
)

tests/dispatch/signature/test_signature.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ def test_signature_too_old(self):
7474
def test_content_digest_invalid(self):
7575
sign_request(self.request, private_key, datetime.now())
7676
self.request.body = "foo"
77-
with self.assertRaisesRegex(ValueError, "unexpected content digest"):
77+
with self.assertRaisesRegex(
78+
InvalidSignature,
79+
"digest of the request body does not match the Content-Digest header",
80+
):
7881
verify_request(self.request, public_key, max_age=timedelta(minutes=1))
7982

8083
def test_signature_coverage(self):

tests/test_full.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
import fastapi
4+
import httpx
45
from fastapi.testclient import TestClient
56

67
from dispatch import Call, Input, Output
@@ -39,8 +40,10 @@ def setUp(self):
3940
api_url="http://127.0.0.1:10000",
4041
)
4142

42-
http_client = TestClient(self.app, base_url="http://dispatch-service")
43-
self.app_client = function_service.client(http_client, signing_key=private_key)
43+
self.http_client = TestClient(self.app, base_url="http://dispatch-service")
44+
self.app_client = function_service.client(
45+
self.http_client, signing_key=private_key
46+
)
4447

4548
self.server = ServerTest()
4649
# shortcuts
@@ -68,3 +71,22 @@ def my_function(name: str) -> str:
6871
# Validate results.
6972
resp = self.servicer.response_for(dispatch_id)
7073
self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52")
74+
75+
def test_simple_missing_signature(self):
76+
@self.dispatch.function()
77+
def my_function(name: str) -> str:
78+
return f"Hello world: {name}"
79+
80+
[dispatch_id] = self.client.dispatch([my_function.build_call(52)])
81+
82+
self.app_client = function_service.client(self.http_client) # no signing key
83+
try:
84+
self.execute()
85+
except httpx.HTTPStatusError as e:
86+
assert e.response.status_code == 403
87+
assert e.response.json() == {
88+
"code": "permission_denied",
89+
"message": 'Expected "Signature-Input" header field to be present',
90+
}
91+
else:
92+
assert False, "Expected HTTPStatusError"

0 commit comments

Comments
 (0)