@@ -137,29 +137,47 @@ def __init__(
137137 app .mount ("/dispatch.sdk.v1.FunctionService" , function_service )
138138
139139
140- class _GRPCResponse (fastapi .Response ):
140+ class _ConnectResponse (fastapi .Response ):
141141 media_type = "application/grpc+proto"
142142
143143
144+ class _ConnectError (fastapi .HTTPException ):
145+ __slots__ = ("status" , "code" , "message" )
146+
147+ def __init__ (self , status , code , message ):
148+ super ().__init__ (status )
149+ self .status = status
150+ self .code = code
151+ self .message = message
152+
153+
144154def _new_app (function_registry : Dispatch , verification_key : Ed25519PublicKey | None ):
145155 app = fastapi .FastAPI ()
146156
157+ @app .exception_handler (_ConnectError )
158+ async def on_error (request : fastapi .Request , exc : _ConnectError ):
159+ # https://connectrpc.com/docs/protocol/#error-end-stream
160+ return fastapi .responses .JSONResponse (
161+ status_code = exc .status , content = {"code" : exc .code , "message" : exc .message }
162+ )
163+
147164 @app .post (
148165 # The endpoint for execution is hardcoded at the moment. If the service
149166 # gains more endpoints, this should be turned into a dynamic dispatch
150167 # like the official gRPC server does.
151168 "/Run" ,
152- response_class = _GRPCResponse ,
169+ response_class = _ConnectResponse ,
153170 )
154171 async def execute (request : fastapi .Request ):
155172 # Raw request body bytes are only available through the underlying
156173 # starlette Request object's body method, which returns an awaitable,
157174 # forcing execute() to be async.
158175 data : bytes = await request .body ()
159-
160176 logger .debug ("handling run request with %d byte body" , len (data ))
161177
162- if verification_key is not None :
178+ if verification_key is None :
179+ logger .debug ("skipping request signature verification" )
180+ else :
163181 signed_request = Request (
164182 method = request .method ,
165183 url = str (request .url ),
@@ -169,29 +187,28 @@ async def execute(request: fastapi.Request):
169187 max_age = timedelta (minutes = 5 )
170188 try :
171189 verify_request (signed_request , verification_key , max_age )
172- except (InvalidSignature , ValueError ):
173- logger .error ("failed to verify request signature" , exc_info = True )
174- raise fastapi .HTTPException (
175- status_code = 403 , detail = "request signature is invalid"
176- )
177- else :
178- logger .debug ("skipping request signature verification" )
190+ except ValueError as e :
191+ raise _ConnectError (401 , "unauthenticated" , str (e ))
192+ except InvalidSignature as e :
193+ # The http_message_signatures package sometimes wraps does not
194+ # attach a message to the exception, so we set a default to
195+ # have some context about the reason for the error.
196+ message = str (e ) or "invalid signature"
197+ raise _ConnectError (403 , "permission_denied" , message )
179198
180199 req = function_pb .RunRequest .FromString (data )
181-
182200 if not req .function :
183- raise fastapi . HTTPException ( status_code = 400 , detail = "function is required" )
201+ raise _ConnectError ( 400 , "invalid_argument" , "function is required" )
184202
185203 try :
186204 func = function_registry ._functions [req .function ]
187205 except KeyError :
188206 logger .debug ("function '%s' not found" , req .function )
189- raise fastapi . HTTPException (
190- status_code = 404 , detail = f"Function '{ req .function } ' does not exist"
207+ raise _ConnectError (
208+ 404 , "not_found" , f"function '{ req .function } ' does not exist"
191209 )
192210
193211 input = Input (req )
194-
195212 logger .info ("running function '%s'" , req .function )
196213 try :
197214 output = func ._primitive_call (input )
@@ -203,8 +220,8 @@ async def execute(request: fastapi.Request):
203220 # so indicates a problem, and we return a 500 rather than attempt
204221 # to catch and categorize the error here.
205222 logger .error ("function '%s' fatal error" , req .function , exc_info = True )
206- raise fastapi . HTTPException (
207- status_code = 500 , detail = f"function '{ req .function } ' fatal error"
223+ raise _ConnectError (
224+ 500 , "internal" , f"function '{ req .function } ' fatal error"
208225 )
209226 else :
210227 response = output ._message
@@ -241,7 +258,6 @@ async def execute(request: fastapi.Request):
241258 )
242259
243260 logger .debug ("finished handling run request with status %s" , status .name )
244-
245261 return fastapi .Response (content = response .SerializeToString ())
246262
247263 return app
0 commit comments