11import json
2- import os
32import re
4- from typing import Any , Dict , List , Optional
3+ from typing import Any
4+ from typing import Dict
5+ from typing import List
6+ from typing import Optional
7+ from urllib .parse import urljoin
58
69import httpx
710from oauthlib .oauth2 import WebApplicationClient
11+ from starlette .exceptions import HTTPException
812from starlette .requests import Request
913from starlette .responses import RedirectResponse
1014
11- from .config import JWT_EXPIRES
12- from .exceptions import OAuth2LoginError
13- from .utils import jwt_create
15+ from .client import OAuth2Client
16+
17+
18+ class OAuth2LoginError (HTTPException ):
19+ """Raised when any login-related error occurs
20+ (such as when user is not verified or if there was an attempt for fake login)
21+ """
1422
1523
1624class OAuth2Core :
@@ -19,7 +27,7 @@ class OAuth2Core:
1927 client_id : str = None
2028 client_secret : str = None
2129 callback_url : Optional [str ] = None
22- allow_insecure_http : bool = False
30+ allow_http : bool = False
2331 scope : Optional [List [str ]] = None
2432 state : Optional [str ] = None
2533 _oauth_client : Optional [WebApplicationClient ] = None
@@ -29,55 +37,47 @@ class OAuth2Core:
2937 token_endpoint : str = None
3038 userinfo_endpoint : str = None
3139
32- def __init__ (
33- self ,
34- client_id : str ,
35- client_secret : str ,
36- callback_url : Optional [str ] = None ,
37- allow_insecure_http : bool = False ,
38- scope : Optional [List [str ]] = None ,
39- ):
40- self .client_id = client_id
41- self .client_secret = client_secret
42- self .callback_url = callback_url
43- self .allow_insecure_http = allow_insecure_http
44- if allow_insecure_http :
45- os .environ ["OAUTHLIB_INSECURE_TRANSPORT" ] = "1"
46- self .scope = scope or self .scope
40+ def __init__ (self , client : OAuth2Client ) -> None :
41+ self .client_id = client .client_id
42+ self .client_secret = client .client_secret
43+ self .scope = client .scope or self .scope
44+ self .provider = client .backend .name
45+ self .authorization_endpoint = client .backend .AUTHORIZATION_URL
46+ self .token_endpoint = client .backend .ACCESS_TOKEN_URL
47+ self .userinfo_endpoint = "https://api.github.com/user"
48+ self .additional_headers = {"Content-Type" : "application/x-www-form-urlencoded" , "Accept" : "application/json" }
4749
4850 @property
4951 def oauth_client (self ) -> WebApplicationClient :
5052 if self ._oauth_client is None :
5153 self ._oauth_client = WebApplicationClient (self .client_id )
5254 return self ._oauth_client
5355
54- @property
55- def access_token (self ) -> Optional [str ]:
56- return self .oauth_client .access_token
57-
58- @property
59- def refresh_token (self ) -> Optional [str ]:
60- return self .oauth_client .refresh_token
56+ def get_redirect_uri (self , request : Request ) -> str :
57+ return urljoin (str (request .base_url ), "/oauth2/%s/token" % self .provider )
6158
6259 async def get_login_url (
6360 self ,
61+ request : Request ,
6462 * ,
6563 params : Optional [Dict [str , Any ]] = None ,
6664 state : Optional [str ] = None ,
6765 ) -> Any :
6866 self .state = state
6967 params = params or {}
68+ redirect_uri = self .get_redirect_uri (request )
7069 return self .oauth_client .prepare_request_uri (
71- self .authorization_endpoint , redirect_uri = self . callback_url , state = state , scope = self .scope , ** params
70+ self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope , ** params
7271 )
7372
7473 async def login_redirect (
7574 self ,
75+ request : Request ,
7676 * ,
7777 params : Optional [Dict [str , Any ]] = None ,
7878 state : Optional [str ] = None ,
7979 ) -> RedirectResponse :
80- login_uri = await self .get_login_url (params = params , state = state )
80+ login_uri = await self .get_login_url (request , params = params , state = state )
8181 return RedirectResponse (login_uri , 303 )
8282
8383 async def get_token_data (
@@ -96,15 +96,14 @@ async def get_token_data(
9696 raise OAuth2LoginError (400 , "'state' parameter does not match" )
9797
9898 url = request .url
99- scheme = "http" if self .allow_insecure_http else "https"
100- current_path = f"{ scheme } ://{ url .netloc } { url .path } "
101- current_path = re .sub (r"^https?" , scheme , current_path )
99+ scheme = "http" if self .allow_http else "https"
102100 current_url = re .sub (r"^https?" , scheme , str (url ))
101+ redirect_uri = self .get_redirect_uri (request )
103102
104103 token_url , headers , content = self .oauth_client .prepare_token_request (
105104 self .token_endpoint ,
105+ redirect_url = redirect_uri ,
106106 authorization_response = current_url ,
107- redirect_url = self .callback_url or current_path ,
108107 code = request .query_params .get ("code" ),
109108 ** params ,
110109 )
@@ -129,13 +128,13 @@ async def token_redirect(
129128 headers : Optional [Dict [str , Any ]] = None ,
130129 ) -> RedirectResponse :
131130 token_data = await self .get_token_data (request , params = params , headers = headers )
132- access_token = jwt_create (token_data )
131+ access_token = request . auth . jwt_create (token_data )
133132 response = RedirectResponse (request .base_url )
134133 response .set_cookie (
135134 "Authorization" ,
136135 value = f"Bearer { access_token } " ,
137- httponly = self .allow_insecure_http ,
138- max_age = JWT_EXPIRES * 60 ,
139- expires = JWT_EXPIRES * 60 ,
136+ httponly = self .allow_http ,
137+ max_age = request . auth . expires ,
138+ expires = request . auth . expires ,
140139 )
141140 return response
0 commit comments