11from datetime import datetime
22from datetime import timedelta
3+ from typing import Any
34from typing import Awaitable
45from typing import Callable
56from typing import Dict
67from typing import List
78from typing import Optional
9+ from typing import Sequence
810from typing import Tuple
911from typing import Union
1012
2123from starlette .types import Scope
2224from starlette .types import Send
2325
26+ from .claims import Claims
2427from .client import OAuth2Client
2528from .config import OAuth2Config
2629from .core import OAuth2Core
@@ -36,6 +39,17 @@ class Auth(AuthCredentials):
3639 scopes : List [str ]
3740 clients : Dict [str , OAuth2Core ] = {}
3841
42+ provider : str
43+ default_provider : str = "local"
44+
45+ def __init__ (
46+ self ,
47+ scopes : Optional [Sequence [str ]] = None ,
48+ provider : str = default_provider ,
49+ ) -> None :
50+ super ().__init__ (scopes )
51+ self .provider = provider
52+
3953 @classmethod
4054 def set_http (cls , http : bool ) -> None :
4155 cls .http = http
@@ -79,19 +93,29 @@ def is_authenticated(self) -> bool:
7993
8094 @property
8195 def display_name (self ) -> str :
82- return self .get ("display_name" , "" ) # name
96+ return self .__getprop__ ("display_name" )
8397
8498 @property
8599 def identity (self ) -> str :
86- return self .get ("identity" , "" ) # username
100+ return self .__getprop__ ("identity" )
87101
88102 @property
89103 def picture (self ) -> str :
90- return self .get ("picture" , "" ) # image
104+ return self .__getprop__ ("picture" )
91105
92106 @property
93107 def email (self ) -> str :
94- return self .get ("email" , "" ) # email
108+ return self .__getprop__ ("email" )
109+
110+ def use_claims (self , claims : Claims ) -> "User" :
111+ for attr , item in claims .items ():
112+ self [attr ] = self .__getprop__ (item )
113+ return self
114+
115+ def __getprop__ (self , item , default = "" ) -> Any :
116+ if callable (item ):
117+ return item (self )
118+ return self .get (item , default )
95119
96120
97121class OAuth2Backend (AuthenticationBackend ):
@@ -120,8 +144,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
120144 if not scheme or not param :
121145 return Auth (), User ()
122146
123- user = Auth .jwt_decode (param )
124- auth , user = Auth (user .pop ("scope" , [])), User (user )
147+ user = User (Auth .jwt_decode (param ))
148+ user .update (provider = user .get ("provider" , Auth .default_provider ))
149+ auth = Auth (user .pop ("scope" , []), user .get ("provider" ))
150+ client = Auth .clients .get (auth .provider )
151+ claims = client .claims if client else Claims ()
152+ user = user .use_claims (claims )
125153
126154 # Call the callback function on authentication
127155 if callable (self .callback ):
0 commit comments