|
| 1 | +import os |
| 2 | +from datetime import datetime, timedelta |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +from fastapi import APIRouter |
| 6 | +from fastapi import Depends, HTTPException |
| 7 | +from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel |
| 8 | +from fastapi.responses import RedirectResponse |
| 9 | +from fastapi.security import HTTPBearer |
| 10 | +from fastapi.security import OAuth2 |
| 11 | +from fastapi.security.base import SecurityBase |
| 12 | +from fastapi.security.utils import get_authorization_scheme_param |
| 13 | +from fastapi_sso.sso.github import GithubSSO |
| 14 | +from jose import jwt |
| 15 | +from pydantic import BaseModel |
| 16 | +from starlette.requests import Request |
| 17 | +from starlette.status import HTTP_403_FORBIDDEN |
| 18 | + |
| 19 | +from config import CLIENT_ID, CLIENT_SECRET, redirect_url, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES, \ |
| 20 | + redirect_url_main_page |
| 21 | + |
| 22 | +router = APIRouter() |
| 23 | + |
| 24 | +# config for github SSO |
| 25 | +os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" |
| 26 | + |
| 27 | +sso = GithubSSO( |
| 28 | + client_id=CLIENT_ID, |
| 29 | + client_secret=CLIENT_SECRET, |
| 30 | + redirect_uri=redirect_url, |
| 31 | + allow_insecure_http=True, |
| 32 | +) |
| 33 | + |
| 34 | +security = HTTPBearer() |
| 35 | + |
| 36 | + |
| 37 | +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): |
| 38 | + to_encode = data.copy() |
| 39 | + if expires_delta: |
| 40 | + expire = datetime.utcnow() + expires_delta |
| 41 | + else: |
| 42 | + expire = datetime.utcnow() + timedelta(minutes=15) |
| 43 | + to_encode.update({"exp": expire}) |
| 44 | + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) |
| 45 | + return encoded_jwt |
| 46 | + |
| 47 | + |
| 48 | +class Token(BaseModel): |
| 49 | + access_token: str |
| 50 | + token_type: str |
| 51 | + |
| 52 | + |
| 53 | +class TokenData(BaseModel): |
| 54 | + username: str = None |
| 55 | + |
| 56 | + |
| 57 | +class User(BaseModel): |
| 58 | + username: str |
| 59 | + email: str = None |
| 60 | + full_name: str = None |
| 61 | + disabled: bool = None |
| 62 | + |
| 63 | + |
| 64 | +class UserInDB(User): |
| 65 | + hashed_password: str |
| 66 | + |
| 67 | + |
| 68 | +class OAuth2PasswordBearerCookie(OAuth2): |
| 69 | + def __init__( |
| 70 | + self, |
| 71 | + tokenUrl: str, |
| 72 | + scheme_name: str = None, |
| 73 | + scopes: dict = None, |
| 74 | + auto_error: bool = True, |
| 75 | + ): |
| 76 | + if not scopes: |
| 77 | + scopes = {} |
| 78 | + flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) |
| 79 | + super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error) |
| 80 | + |
| 81 | + async def __call__(self, request: Request) -> Optional[str]: |
| 82 | + header_authorization: str = request.headers.get("Authorization") |
| 83 | + cookie_authorization: str = request.cookies.get("Authorization") |
| 84 | + |
| 85 | + header_scheme, header_param = get_authorization_scheme_param( |
| 86 | + header_authorization |
| 87 | + ) |
| 88 | + cookie_scheme, cookie_param = get_authorization_scheme_param( |
| 89 | + cookie_authorization |
| 90 | + ) |
| 91 | + |
| 92 | + if header_scheme.lower() == "bearer": |
| 93 | + authorization = True |
| 94 | + scheme = header_scheme |
| 95 | + param = header_param |
| 96 | + |
| 97 | + elif cookie_scheme.lower() == "bearer": |
| 98 | + authorization = True |
| 99 | + scheme = cookie_scheme |
| 100 | + param = cookie_param |
| 101 | + |
| 102 | + else: |
| 103 | + authorization = False |
| 104 | + |
| 105 | + if not authorization or scheme.lower() != "bearer": |
| 106 | + if self.auto_error: |
| 107 | + raise HTTPException( |
| 108 | + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
| 109 | + ) |
| 110 | + else: |
| 111 | + return None |
| 112 | + return param |
| 113 | + |
| 114 | + |
| 115 | +class BasicAuth(SecurityBase): |
| 116 | + def __init__(self, scheme_name: str = None, auto_error: bool = True): |
| 117 | + self.scheme_name = scheme_name or self.__class__.__name__ |
| 118 | + self.auto_error = auto_error |
| 119 | + |
| 120 | + async def __call__(self, request: Request) -> Optional[str]: |
| 121 | + authorization: str = request.headers.get("Authorization") |
| 122 | + scheme, param = get_authorization_scheme_param(authorization) |
| 123 | + if not authorization or scheme.lower() != "basic": |
| 124 | + if self.auto_error: |
| 125 | + raise HTTPException( |
| 126 | + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" |
| 127 | + ) |
| 128 | + else: |
| 129 | + return None |
| 130 | + return param |
| 131 | + |
| 132 | + |
| 133 | +basic_auth = BasicAuth(auto_error=False) |
| 134 | + |
| 135 | +oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token") |
| 136 | + |
| 137 | + |
| 138 | +async def get_current_user(token: str = Depends(oauth2_scheme)): |
| 139 | + credentials_exception = HTTPException( |
| 140 | + status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials" |
| 141 | + ) |
| 142 | + try: |
| 143 | + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) |
| 144 | + return payload |
| 145 | + except PyJWTError: |
| 146 | + raise credentials_exception |
| 147 | + |
| 148 | + |
| 149 | +@router.get("/auth/login") |
| 150 | +async def auth_init(): |
| 151 | + """Initialize auth and redirect""" |
| 152 | + return await sso.get_login_redirect() |
| 153 | + |
| 154 | + |
| 155 | +@router.get("/auth/callback") |
| 156 | +async def auth_callback(request: Request): |
| 157 | + """Verify login""" |
| 158 | + user = await sso.verify_and_process(request) |
| 159 | + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) |
| 160 | + access_token = create_access_token( |
| 161 | + data=dict(user), expires_delta=access_token_expires |
| 162 | + ) |
| 163 | + print(dict(user)) |
| 164 | + response = RedirectResponse(redirect_url_main_page) |
| 165 | + response.set_cookie( |
| 166 | + "Authorization", |
| 167 | + value=f"Bearer {access_token}", |
| 168 | + httponly=True, |
| 169 | + max_age=1800, |
| 170 | + expires=1800, |
| 171 | + ) |
| 172 | + return response |
| 173 | + |
| 174 | + |
| 175 | +@router.get("/auth/logout") |
| 176 | +async def auth_logout(): |
| 177 | + response = RedirectResponse(redirect_url_main_page) |
| 178 | + response.delete_cookie("Authorization") |
| 179 | + return response |
| 180 | + |
| 181 | + |
| 182 | +@router.get("/auth/status") |
| 183 | +async def auth_status(user=Depends(get_current_user)): |
| 184 | + return { |
| 185 | + "ok": True, |
| 186 | + "user": user, |
| 187 | + } |
0 commit comments