11#!/usr/bin/env python3
22# -*- coding: utf-8 -*-
3- from starlette .authentication import AuthenticationBackend
4- from fastapi import Request
3+ from typing import Any
4+
5+ from fastapi import Request , Response
6+ from starlette .authentication import AuthenticationBackend , AuthenticationError
7+ from starlette .requests import HTTPConnection
8+ from starlette .responses import JSONResponse
59
610from backend .app .common import jwt
11+ from backend .app .common .exception .errors import TokenError
12+ from backend .app .core .conf import settings
713from backend .app .database .db_mysql import async_db_session
814
915
16+ class _AuthenticationError (AuthenticationError ):
17+ """重写内部认证错误类"""
18+
19+ def __init__ (self , * , code : int = None , msg : str = None , headers : dict [str , Any ] | None = None ):
20+ self .code = code
21+ self .msg = msg
22+ self .headers = headers
23+
24+
1025class JwtAuthMiddleware (AuthenticationBackend ):
1126 """JWT 认证中间件"""
1227
28+ @staticmethod
29+ def auth_exception_handler (conn : HTTPConnection , exc : Exception ) -> Response :
30+ """覆盖内部认证错误处理"""
31+ code = getattr (exc , 'code' , 500 )
32+ msg = getattr (exc , 'msg' , 'Internal Server Error' )
33+ return JSONResponse (content = {'code' : code , 'msg' : msg , 'data' : None }, status_code = code )
34+
1335 async def authenticate (self , request : Request ):
1436 auth = request .headers .get ('Authorization' )
1537 if not auth :
@@ -19,9 +41,15 @@ async def authenticate(self, request: Request):
1941 if scheme .lower () != 'bearer' :
2042 return
2143
22- sub = await jwt .jwt_authentication (token )
44+ try :
45+ sub = await jwt .jwt_authentication (token )
46+ async with async_db_session () as db :
47+ user = await jwt .get_current_user (db , data = sub )
48+ except TokenError as exc :
49+ raise _AuthenticationError (code = exc .code , msg = exc .detail , headers = exc .headers )
50+ except Exception :
51+ import traceback
2352
24- async with async_db_session () as db :
25- user = await jwt .get_current_user (db , data = sub )
53+ raise _AuthenticationError (msg = traceback .format_exc () if settings .ENVIRONMENT == 'dev' else None )
2654
2755 return auth , user
0 commit comments