@@ -56,6 +56,7 @@ class OAuth2Core:
5656 _oauth_client : Optional [WebApplicationClient ] = None
5757 _authorization_endpoint : str = None
5858 _token_endpoint : str = None
59+ _state : str = None
5960
6061 def __init__ (self , client : OAuth2Client ) -> None :
6162 self .client_id = client .client_id
@@ -83,6 +84,8 @@ def authorization_url(self, request: Request) -> str:
8384 oauth2_query_params = dict (state = state , scope = self .scope , redirect_uri = redirect_uri )
8485 oauth2_query_params .update (request .query_params )
8586
87+ self ._state = oauth2_query_params .get ("state" )
88+
8689 return str (self ._oauth_client .prepare_request_uri (
8790 self ._authorization_endpoint ,
8891 ** oauth2_query_params ,
@@ -96,6 +99,8 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
9699 raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
97100 if not request .query_params .get ("state" ):
98101 raise OAuth2LoginError (400 , "'state' parameter was not found in callback request" )
102+ if request .query_params .get ("state" ) != self ._state :
103+ raise OAuth2LoginError (400 , "'state' parameter does not match" )
99104
100105 redirect_uri = self .get_redirect_uri (request )
101106 scheme = "http" if request .auth .http else "https"
0 commit comments