1616from jose .jwt import encode as jwt_encode
1717from starlette .authentication import AuthCredentials
1818from starlette .authentication import AuthenticationBackend
19+ from starlette .authentication import AuthenticationError
1920from starlette .authentication import BaseUser
2021from starlette .middleware .authentication import AuthenticationMiddleware
2122from starlette .requests import Request
23+ from starlette .requests import HTTPConnection
2224from starlette .responses import PlainTextResponse
25+ from starlette .responses import Response
2326from starlette .types import ASGIApp
2427from starlette .types import Receive
2528from starlette .types import Scope
@@ -111,9 +114,9 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
111114 try :
112115 token_data = Auth .jwt_decode (param )
113116 except JOSEError as e :
114- raise OAuth2AuthenticationError ( 401 , str (e ))
117+ raise AuthenticationError ( str (e ))
115118 if token_data ["exp" ] and token_data ["exp" ] < int (datetime .now (timezone .utc ).timestamp ()):
116- raise OAuth2AuthenticationError ( 401 , "Token expired" )
119+ raise AuthenticationError ( "Token expired" )
117120
118121 user = User (token_data )
119122 auth = Auth (user .pop ("scope" , []))
@@ -138,6 +141,7 @@ def __init__(
138141 app : ASGIApp ,
139142 config : Union [OAuth2Config , dict ],
140143 callback : Callable [[Auth , User ], Union [Awaitable [None ], None ]] = None ,
144+ on_error : Callable [[HTTPConnection , AuthenticationError ], Response ] | None = None ,
141145 ** kwargs , # AuthenticationMiddleware kwargs
142146 ) -> None :
143147 """Initiates the middleware with the given configuration.
@@ -151,9 +155,13 @@ def __init__(
151155 elif not isinstance (config , OAuth2Config ):
152156 raise TypeError ("config is not a valid type" )
153157 self .default_application_middleware = app
154- self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), ** kwargs )
158+ self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), on_error = on_error or self . on_error , ** kwargs )
155159
156160 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
157161 if scope ["type" ] == "http" :
158162 return await self .auth_middleware (scope , receive , send )
159163 await self .default_application_middleware (scope , receive , send )
164+
165+ @staticmethod
166+ def on_error (conn : HTTPConnection , exc : Exception ) -> Response :
167+ return PlainTextResponse (str (exc ), status_code = 401 )
0 commit comments