1616from jose .jwt import encode as jwt_encode
1717from starlette .authentication import AuthCredentials
1818from starlette .authentication import AuthenticationBackend
19+ from starlette .authentication import AuthenticationError
1920from starlette .authentication import BaseUser
2021from starlette .middleware .authentication import AuthenticationMiddleware
22+ from starlette .requests import HTTPConnection
2123from starlette .requests import Request
22- from starlette .responses import PlainTextResponse
24+ from starlette .responses import Response
2325from starlette .types import ASGIApp
2426from starlette .types import Receive
2527from starlette .types import Scope
2830from .claims import Claims
2931from .config import OAuth2Config
3032from .core import OAuth2Core
31- from .exceptions import OAuth2AuthenticationError
3233
3334
3435class Auth (AuthCredentials ):
@@ -108,9 +109,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
108109 if not scheme or not param :
109110 return Auth (), User ()
110111
111- token_data = Auth .jwt_decode (param )
112+ try :
113+ token_data = Auth .jwt_decode (param )
114+ except JOSEError as e :
115+ raise AuthenticationError (str (e ))
112116 if token_data ["exp" ] and token_data ["exp" ] < int (datetime .now (timezone .utc ).timestamp ()):
113- raise OAuth2AuthenticationError ( 401 , "Token expired" )
117+ raise AuthenticationError ( "Token expired" )
114118
115119 user = User (token_data )
116120 auth = Auth (user .pop ("scope" , []))
@@ -135,7 +139,7 @@ def __init__(
135139 app : ASGIApp ,
136140 config : Union [OAuth2Config , dict ],
137141 callback : Callable [[Auth , User ], Union [Awaitable [None ], None ]] = None ,
138- ** kwargs , # AuthenticationMiddleware kwargs
142+ on_error : Optional [ Callable [[ HTTPConnection , AuthenticationError ], Response ]] = None ,
139143 ) -> None :
140144 """Initiates the middleware with the given configuration.
141145
@@ -148,13 +152,10 @@ def __init__(
148152 elif not isinstance (config , OAuth2Config ):
149153 raise TypeError ("config is not a valid type" )
150154 self .default_application_middleware = app
151- self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), ** kwargs )
155+ on_error = on_error or AuthenticationMiddleware .default_on_error
156+ self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), on_error = on_error )
152157
153158 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
154159 if scope ["type" ] == "http" :
155- try :
156- return await self .auth_middleware (scope , receive , send )
157- except (JOSEError , Exception ) as e :
158- middleware = PlainTextResponse (str (e ), status_code = 401 )
159- return await middleware (scope , receive , send )
160+ return await self .auth_middleware (scope , receive , send )
160161 await self .default_application_middleware (scope , receive , send )
0 commit comments