99from urllib .parse import urljoin
1010
1111import httpx
12+ from oauthlib .oauth2 import OAuth2Error
1213from oauthlib .oauth2 import WebApplicationClient
13- from oauthlib .oauth2 .rfc6749 .errors import CustomOAuth2Error
1414from social_core .backends .oauth import BaseOAuth2
15+ from social_core .exceptions import AuthException
1516from social_core .strategy import BaseStrategy
1617from starlette .requests import Request
1718from starlette .responses import RedirectResponse
1819
1920from .claims import Claims
2021from .client import OAuth2Client
21- from .exceptions import OAuth2LoginError
22+ from .exceptions import OAuth2AuthenticationError
23+ from .exceptions import OAuth2BadCredentialsError
24+ from .exceptions import OAuth2InvalidRequestError
2225
2326
2427class OAuth2Strategy (BaseStrategy ):
@@ -92,11 +95,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:
9295
9396 async def token_data (self , request : Request , ** httpx_client_args ) -> dict :
9497 if not request .query_params .get ("code" ):
95- raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
98+ raise OAuth2InvalidRequestError (400 , "'code' parameter was not found in callback request" )
9699 if not request .query_params .get ("state" ):
97- raise OAuth2LoginError (400 , "'state' parameter was not found in callback request" )
100+ raise OAuth2InvalidRequestError (400 , "'state' parameter was not found in callback request" )
98101 if request .query_params .get ("state" ) != self ._state :
99- raise OAuth2LoginError (400 , "'state' parameter does not match" )
102+ raise OAuth2InvalidRequestError (400 , "'state' parameter does not match" )
100103
101104 redirect_uri = self .get_redirect_uri (request )
102105 scheme = "http" if request .auth .http else "https"
@@ -113,12 +116,16 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
113116 headers .update ({"Accept" : "application/json" })
114117 auth = httpx .BasicAuth (self .client_id , self .client_secret )
115118 async with httpx .AsyncClient (auth = auth , ** httpx_client_args ) as session :
116- response = await session .post (token_url , headers = headers , content = content )
117119 try :
120+ response = await session .post (token_url , headers = headers , content = content )
118121 self ._oauth_client .parse_request_body_response (json .dumps (response .json ()))
119122 return self .standardize (self .backend .user_data (self .access_token ))
120- except (CustomOAuth2Error , Exception ) as e :
121- raise OAuth2LoginError (400 , str (e ))
123+ except OAuth2Error as e :
124+ raise OAuth2InvalidRequestError (400 , str (e ))
125+ except httpx .HTTPError as e :
126+ raise OAuth2BadCredentialsError (400 , str (e ))
127+ except (AuthException , Exception ) as e :
128+ raise OAuth2AuthenticationError (401 , str (e ))
122129
123130 async def token_redirect (self , request : Request , ** kwargs ) -> RedirectResponse :
124131 access_token = request .auth .jwt_create (await self .token_data (request , ** kwargs ))
0 commit comments