11#!/usr/bin/env python3
22# -*- coding: utf-8 -*-
33from datetime import datetime , timedelta
4- from typing import Any
54
6- from fastapi import Depends
5+ from fastapi import Depends , Request
76from fastapi .security import OAuth2PasswordBearer
7+ from fastapi .security .utils import get_authorization_scheme_param
88from jose import jwt
99from passlib .context import CryptContext
1010from pydantic import ValidationError
@@ -43,7 +43,7 @@ def password_verify(plain_password: str, hashed_password: str) -> bool:
4343 return pwd_context .verify (plain_password , hashed_password )
4444
4545
46- async def create_access_token (sub : int | Any , expires_delta : timedelta | None = None , ** kwargs ) -> str :
46+ async def create_access_token (sub : str , expires_delta : timedelta | None = None , ** kwargs ) -> tuple [ str , datetime ] :
4747 """
4848 Generate encryption token
4949
@@ -52,41 +52,86 @@ async def create_access_token(sub: int | Any, expires_delta: timedelta | None =
5252 :return:
5353 """
5454 if expires_delta :
55- expires = datetime .utcnow () + expires_delta
56- expire_seconds = expires_delta .total_seconds ()
55+ expire = datetime .utcnow () + expires_delta
56+ expire_seconds = int ( expires_delta .total_seconds () )
5757 else :
58- expires = datetime .utcnow () + timedelta (seconds = settings .TOKEN_EXPIRE_SECONDS )
58+ expire = datetime .utcnow () + timedelta (seconds = settings .TOKEN_EXPIRE_SECONDS )
5959 expire_seconds = settings .TOKEN_EXPIRE_SECONDS
60- to_encode = {'exp' : expires , 'sub' : str ( sub ) , ** kwargs }
60+ to_encode = {'exp' : expire , 'sub' : sub , ** kwargs }
6161 token = jwt .encode (to_encode , settings .TOKEN_SECRET_KEY , settings .TOKEN_ALGORITHM )
6262 if sub not in settings .TOKEN_WHITE_LIST :
63- await redis_client .delete (f'token :{ sub } :* ' )
64- key = f'token :{ sub } :{ token } '
63+ await redis_client .delete_prefix (f'{ settings . TOKEN_REDIS_PREFIX } :{ sub } :' )
64+ key = f'{ settings . TOKEN_REDIS_PREFIX } :{ sub } :{ token } '
6565 await redis_client .setex (key , expire_seconds , token )
66- return token
66+ return token , expire
6767
6868
69- async def jwt_authentication ( token : str = Depends ( oauth2_schema )) :
69+ async def create_refresh_token ( sub : str , expire_time : datetime | None = None , ** kwargs ) -> tuple [ str , datetime ] :
7070 """
71- JWT authentication
71+ Generate encryption refresh token
72+
73+ :param sub: The subject/userid of the JWT
74+ :param expire_time: expiry time
75+ :return:
76+ """
77+ if expire_time :
78+ expires = expire_time + timedelta (seconds = settings .TOKEN_EXPIRE_SECONDS )
79+ expire_seconds = int ((expires - datetime .utcnow ()).total_seconds ())
80+ else :
81+ expires = datetime .utcnow () + timedelta (seconds = settings .TOKEN_EXPIRE_SECONDS )
82+ expire_seconds = settings .TOKEN_EXPIRE_SECONDS
83+ to_encode = {'exp' : expires , 'sub' : sub , ** kwargs }
84+ token = jwt .encode (to_encode , settings .TOKEN_SECRET_KEY , settings .TOKEN_ALGORITHM )
85+ # 刷新 token 时,保持旧 token 有效,不执行删除操作
86+ key = f'{ settings .TOKEN_REDIS_PREFIX } :{ sub } :{ token } '
87+ await redis_client .setex (key , expire_seconds , token )
88+ return token , expires
89+
90+
91+ def get_token (request : Request ) -> str :
92+ """
93+ Get token for request header
94+
95+ :return:
96+ """
97+ authorization = request .headers .get ('Authorization' )
98+ scheme , param = get_authorization_scheme_param (authorization )
99+ if not authorization or scheme .lower () != 'bearer' :
100+ raise TokenError
101+ return param
102+
103+
104+ def jwt_decode (token : str ) -> tuple [int , list [int ]]:
105+ """
106+ Decode token
72107
73108 :param token:
74109 :return:
75110 """
76111 try :
77112 payload = jwt .decode (token , settings .TOKEN_SECRET_KEY , algorithms = [settings .TOKEN_ALGORITHM ])
78- user_id = payload .get ('sub' )
79- user_role = payload .get ('role_ids' )
80- if not user_id or not user_role :
81- raise TokenError
82- # 验证token是否有效
83- key = f'token:{ user_id } :{ token } '
84- valid_token = await redis_client .get (key )
85- if not valid_token :
113+ user_id = int (payload .get ('sub' ))
114+ user_roles = list (payload .get ('role_ids' ))
115+ if not user_id or not user_roles :
86116 raise TokenError
87- return {'payload' : payload , 'token' : token }
88- except (jwt .JWTError , ValidationError ):
117+ except (jwt .JWTError , ValidationError , Exception ):
118+ raise TokenError
119+ return user_id , user_roles
120+
121+
122+ async def jwt_authentication (token : str = Depends (oauth2_schema )) -> dict [str , int ]:
123+ """
124+ JWT authentication
125+
126+ :param token:
127+ :return:
128+ """
129+ user_id , _ = jwt_decode (token )
130+ key = f'{ settings .TOKEN_REDIS_PREFIX } :{ user_id } :{ token } '
131+ token_verify = await redis_client .get (key )
132+ if not token_verify :
89133 raise TokenError
134+ return {'sub' : user_id }
90135
91136
92137async def get_current_user (db : CurrentSession , data : dict = Depends (jwt_authentication )) -> User :
@@ -97,7 +142,7 @@ async def get_current_user(db: CurrentSession, data: dict = Depends(jwt_authenti
97142 :param data:
98143 :return:
99144 """
100- user_id = data .get ('payload' ). get ( ' sub' )
145+ user_id = data .get ('sub' )
101146 user = await UserDao .get_user_with_relation (db , user_id = user_id )
102147 if not user :
103148 raise TokenError
@@ -121,7 +166,7 @@ async def get_current_is_superuser(user: User = Depends(get_current_user)):
121166CurrentUser = Annotated [User , Depends (get_current_user )]
122167CurrentSuperUser = Annotated [bool , Depends (get_current_is_superuser )]
123168# Token dependency injection
124- JwtAuthentication = Annotated [dict , Depends (jwt_authentication )]
169+ CurrentJwtAuth = Annotated [dict , Depends (jwt_authentication )]
125170# Permission dependency injection
126171DependsUser = Depends (get_current_user )
127172DependsSuperUser = Depends (get_current_is_superuser )
0 commit comments