11from datetime import datetime
22from datetime import timedelta
3+ from typing import Awaitable
4+ from typing import Callable
35from typing import Dict
46from typing import List
57from typing import Optional
@@ -87,13 +89,20 @@ def identity(self) -> str:
8789
8890
8991class OAuth2Backend (AuthenticationBackend ):
90- def __init__ (self , config : OAuth2Config ) -> None :
92+ """Authentication backend for AuthenticationMiddleware."""
93+
94+ def __init__ (
95+ self ,
96+ config : OAuth2Config ,
97+ callback : Callable [[User ], Union [Awaitable [None ], None ]] = None ,
98+ ) -> None :
9199 Auth .set_http (config .allow_http )
92100 Auth .set_secret (config .jwt_secret )
93101 Auth .set_expires (config .jwt_expires )
94102 Auth .set_algorithm (config .jwt_algorithm )
95103 for client in config .clients :
96104 Auth .register_client (client )
105+ self .callback = callback
97106
98107 async def authenticate (self , request : Request ) -> Optional [Tuple ["Auth" , "User" ]]:
99108 authorization = request .headers .get (
@@ -106,18 +115,39 @@ async def authenticate(self, request: Request) -> Optional[Tuple["Auth", "User"]
106115 return Auth (), User ()
107116
108117 user = Auth .jwt_decode (param )
109- return Auth (user .pop ("scope" , [])), User (user )
118+ auth , user = Auth (user .pop ("scope" , [])), User (user )
119+
120+ # Call the callback function on authentication
121+ if callable (self .callback ):
122+ coroutine = self .callback (user )
123+ if issubclass (type (coroutine ), Awaitable ):
124+ await coroutine
125+ return auth , user
110126
111127
112128class OAuth2Middleware :
129+ """Wrapper for the Starlette AuthenticationMiddleware."""
130+
113131 auth_middleware : AuthenticationMiddleware = None
114132
115- def __init__ (self , app : ASGIApp , config : Union [OAuth2Config , dict ]) -> None :
133+ def __init__ (
134+ self ,
135+ app : ASGIApp ,
136+ config : Union [OAuth2Config , dict ],
137+ callback : Callable [[User ], Union [Awaitable [None ], None ]] = None ,
138+ ** kwargs , # AuthenticationMiddleware kwargs
139+ ) -> None :
140+ """Initiates the middleware with the given configuration.
141+
142+ :param app: FastAPI application instance
143+ :param config: middleware configuration
144+ :param callback: callback function to be called after authentication
145+ """
116146 if isinstance (config , dict ):
117147 config = OAuth2Config (** config )
118148 elif not isinstance (config , OAuth2Config ):
119149 raise TypeError ("config is not a valid type" )
120- self .auth_middleware = AuthenticationMiddleware (app , OAuth2Backend (config ) )
150+ self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), ** kwargs )
121151
122152 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
123153 await self .auth_middleware (scope , receive , send )
0 commit comments