1+ from datetime import datetime
2+ from datetime import timedelta
3+ from typing import Dict
14from typing import List
25from typing import Optional
36from typing import Tuple
47from typing import Union
58
69from fastapi .security .utils import get_authorization_scheme_param
10+ from jose .jwt import decode as jwt_decode
11+ from jose .jwt import encode as jwt_encode
712from starlette .authentication import AuthenticationBackend
813from starlette .middleware .authentication import AuthenticationMiddleware
914from starlette .requests import Request
1217from starlette .types import Scope
1318from starlette .types import Send
1419
20+ from .client import OAuth2Client
1521from .config import OAuth2Config
16- from .utils import jwt_decode
22+ from .core import OAuth2Core
1723
1824
1925class Auth :
26+ secret : str
27+ expires : int
28+ algorithm : str
2029 scopes : List [str ]
30+ clients : Dict [str , OAuth2Core ] = {}
2131
2232 def __init__ (self , scopes : Optional [List [str ]] = None ) -> None :
2333 self .scopes = scopes or []
2434
35+ @classmethod
36+ def set_secret (cls , secret : str ) -> None :
37+ cls .secret = secret
38+
39+ @classmethod
40+ def set_expires (cls , expires : int ) -> None :
41+ cls .expires = expires
42+
43+ @classmethod
44+ def set_algorithm (cls , algorithm : str ) -> None :
45+ cls .algorithm = algorithm
46+
47+ @classmethod
48+ def register_client (cls , client : OAuth2Client ) -> None :
49+ cls .clients [client .backend .name ] = OAuth2Core (client )
50+
51+ @classmethod
52+ def jwt_encode (cls , data : dict ) -> str :
53+ return jwt_encode (data , cls .secret , algorithm = cls .algorithm )
54+
55+ @classmethod
56+ def jwt_decode (cls , token : str ) -> dict :
57+ return jwt_decode (token , cls .secret , algorithms = [cls .algorithm ])
58+
59+ @classmethod
60+ def jwt_create (cls , token_data : dict ) -> str :
61+ expire = datetime .utcnow () + timedelta (minutes = cls .expires )
62+ return cls .jwt_encode ({** token_data , "exp" : expire })
63+
2564
2665class User (dict ):
2766 is_authenticated : bool
@@ -32,30 +71,34 @@ def __init__(self, seq: Optional[dict] = None, **kwargs) -> None:
3271
3372
3473class OAuth2Backend (AuthenticationBackend ):
74+ def __init__ (self , config : OAuth2Config ) -> None :
75+ Auth .set_secret (config .jwt_secret )
76+ Auth .set_expires (config .jwt_expires )
77+ Auth .set_algorithm (config .jwt_algorithm )
78+ OAuth2Core .allow_http = config .allow_http
79+ for client in config .clients :
80+ Auth .register_client (client )
81+
3582 async def authenticate (self , request : Request ) -> Optional [Tuple ["Auth" , "User" ]]:
3683 authorization = request .cookies .get ("Authorization" )
3784 scheme , param = get_authorization_scheme_param (authorization )
3885
3986 if not scheme or not param :
4087 return Auth (), User ()
4188
42- user = jwt_decode (param )
43- scopes = user .pop ("scope" )
44- return Auth (scopes ), User (user )
89+ user = Auth .jwt_decode (param )
90+ return Auth (user .pop ("scope" )), User (user )
4591
4692
4793class OAuth2Middleware :
48- config : OAuth2Config
49- auth_middleware : AuthenticationMiddleware
94+ auth_middleware : AuthenticationMiddleware = None
5095
5196 def __init__ (self , app : ASGIApp , config : Union [OAuth2Config , dict ]) -> None :
52- if isinstance (config , OAuth2Config ):
53- self .config = config
54- elif isinstance (config , dict ):
55- self .config = OAuth2Config (** config )
56- else :
97+ if isinstance (config , dict ):
98+ config = OAuth2Config (** config )
99+ elif not isinstance (config , OAuth2Config ):
57100 raise TypeError ("config is not a valid type" )
58- self .auth_middleware = AuthenticationMiddleware (app , OAuth2Backend ())
101+ self .auth_middleware = AuthenticationMiddleware (app , OAuth2Backend (config ))
59102
60103 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
61104 await self .auth_middleware (scope , receive , send )
0 commit comments