11"""FastAPI Users database adapter for SQLAlchemy."""
2- from typing import Optional , Type
2+ import uuid
3+ from typing import Any , Dict , Generic , Optional , Type , TypeVar
34
45from fastapi_users .db .base import BaseUserDatabase
5- from fastapi_users .models import UD
6- from pydantic import UUID4
7- from sqlalchemy import (
8- Boolean ,
9- Column ,
10- ForeignKey ,
11- Integer ,
12- String ,
13- delete ,
14- func ,
15- select ,
16- update ,
17- )
6+ from fastapi_users .models import ID , OAP , UP
7+ from sqlalchemy import Boolean , Column , ForeignKey , Integer , String , func , select
188from sqlalchemy .ext .asyncio import AsyncSession
199from sqlalchemy .ext .declarative import declared_attr
20- from sqlalchemy .orm import joinedload
2110from sqlalchemy .sql import Select
2211
23- from fastapi_users_db_sqlalchemy .guid import GUID
12+ from fastapi_users_db_sqlalchemy .generics import GUID
2413
2514__version__ = "3.0.1"
2615
16+ UUID_ID = uuid .UUID
2717
28- class SQLAlchemyBaseUserTable :
18+
19+ class SQLAlchemyBaseUserTable (Generic [ID ]):
2920 """Base SQLAlchemy users table definition."""
3021
3122 __tablename__ = "user"
3223
33- id = Column (GUID , primary_key = True )
34- email = Column (String (length = 320 ), unique = True , index = True , nullable = False )
35- hashed_password = Column (String (length = 1024 ), nullable = False )
36- is_active = Column (Boolean , default = True , nullable = False )
37- is_superuser = Column (Boolean , default = False , nullable = False )
38- is_verified = Column (Boolean , default = False , nullable = False )
24+ id : ID
25+ email : str = Column (String (length = 320 ), unique = True , index = True , nullable = False )
26+ hashed_password : str = Column (String (length = 1024 ), nullable = False )
27+ is_active : bool = Column (Boolean , default = True , nullable = False )
28+ is_superuser : bool = Column (Boolean , default = False , nullable = False )
29+ is_verified : bool = Column (Boolean , default = False , nullable = False )
30+
31+
32+ UP_SQLALCHEMY = TypeVar ("UP_SQLALCHEMY" , bound = SQLAlchemyBaseUserTable )
33+
3934
35+ class SQLAlchemyBaseUserTableUUID (SQLAlchemyBaseUserTable [UUID_ID ]):
36+ id : UUID_ID = Column (GUID , primary_key = True , default = uuid .uuid4 )
4037
41- class SQLAlchemyBaseOAuthAccountTable :
38+
39+ class SQLAlchemyBaseOAuthAccountTable (Generic [ID ]):
4240 """Base SQLAlchemy OAuth account table definition."""
4341
4442 __tablename__ = "oauth_account"
4543
46- id = Column (GUID , primary_key = True )
47- oauth_name = Column (String (length = 100 ), index = True , nullable = False )
48- access_token = Column (String (length = 1024 ), nullable = False )
49- expires_at = Column (Integer , nullable = True )
50- refresh_token = Column (String (length = 1024 ), nullable = True )
51- account_id = Column (String (length = 320 ), index = True , nullable = False )
52- account_email = Column (String (length = 320 ), nullable = False )
44+ id : ID
45+ oauth_name : str = Column (String (length = 100 ), index = True , nullable = False )
46+ access_token : str = Column (String (length = 1024 ), nullable = False )
47+ expires_at : Optional [int ] = Column (Integer , nullable = True )
48+ refresh_token : Optional [str ] = Column (String (length = 1024 ), nullable = True )
49+ account_id : str = Column (String (length = 320 ), index = True , nullable = False )
50+ account_email : str = Column (String (length = 320 ), nullable = False )
51+
52+
53+ class SQLAlchemyBaseOAuthAccountTableUUID (SQLAlchemyBaseOAuthAccountTable [UUID_ID ]):
54+ id : UUID_ID = Column (GUID , primary_key = True , default = uuid .uuid4 )
5355
5456 @declared_attr
5557 def user_id (cls ):
5658 return Column (GUID , ForeignKey ("user.id" , ondelete = "cascade" ), nullable = False )
5759
5860
59- class SQLAlchemyUserDatabase (BaseUserDatabase [UD ]):
61+ class SQLAlchemyUserDatabase (
62+ Generic [UP_SQLALCHEMY , ID ], BaseUserDatabase [UP_SQLALCHEMY , ID ]
63+ ):
6064 """
6165 Database adapter for SQLAlchemy.
6266
@@ -67,86 +71,97 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
6771 """
6872
6973 session : AsyncSession
70- user_table : Type [SQLAlchemyBaseUserTable ]
74+ user_table : Type [UP_SQLALCHEMY ]
7175 oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]]
7276
7377 def __init__ (
7478 self ,
75- user_db_model : Type [UD ],
7679 session : AsyncSession ,
77- user_table : Type [SQLAlchemyBaseUserTable ],
80+ user_table : Type [UP_SQLALCHEMY ],
7881 oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]] = None ,
7982 ):
80- super ().__init__ (user_db_model )
8183 self .session = session
8284 self .user_table = user_table
8385 self .oauth_account_table = oauth_account_table
8486
85- async def get (self , id : UUID4 ) -> Optional [UD ]:
87+ async def get (self , id : ID ) -> Optional [UP_SQLALCHEMY ]:
8688 statement = select (self .user_table ).where (self .user_table .id == id )
8789 return await self ._get_user (statement )
8890
89- async def get_by_email (self , email : str ) -> Optional [UD ]:
91+ async def get_by_email (self , email : str ) -> Optional [UP_SQLALCHEMY ]:
9092 statement = select (self .user_table ).where (
9193 func .lower (self .user_table .email ) == func .lower (email )
9294 )
9395 return await self ._get_user (statement )
9496
95- async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
96- if self .oauth_account_table is not None :
97- statement = (
98- select (self .user_table )
99- .join (self .oauth_account_table )
100- .where (self .oauth_account_table .oauth_name == oauth )
101- .where (self .oauth_account_table .account_id == account_id )
102- )
103- return await self ._get_user (statement )
104-
105- async def create (self , user : UD ) -> UD :
106- user_table = self .user_table (** user .dict (exclude = {"oauth_accounts" }))
107- self .session .add (user_table )
108-
109- if self .oauth_account_table is not None :
110- for oauth_account in user .oauth_accounts :
111- oauth_account_table = self .oauth_account_table (
112- ** oauth_account .dict (), user_id = user .id
113- )
114- self .session .add (oauth_account_table )
97+ async def get_by_oauth_account (
98+ self , oauth : str , account_id : str
99+ ) -> Optional [UP_SQLALCHEMY ]:
100+ if self .oauth_account_table is None :
101+ raise NotImplementedError ()
102+
103+ statement = (
104+ select (self .user_table )
105+ .join (self .oauth_account_table )
106+ .where (self .oauth_account_table .oauth_name == oauth )
107+ .where (self .oauth_account_table .account_id == account_id )
108+ )
109+ return await self ._get_user (statement )
115110
111+ async def create (self , create_dict : Dict [str , Any ]) -> UP_SQLALCHEMY :
112+ user = self .user_table (** create_dict )
113+ self .session .add (user )
116114 await self .session .commit ()
117- return await self .get (user .id )
118-
119- async def update (self , user : UD ) -> UD :
120- user_table = await self .session .get (self .user_table , user .id )
121- for key , value in user .dict (exclude = {"oauth_accounts" }).items ():
122- setattr (user_table , key , value )
123- self .session .add (user_table )
124-
125- if self .oauth_account_table is not None :
126- for oauth_account in user .oauth_accounts :
127- statement = update (
128- self .oauth_account_table ,
129- whereclause = self .oauth_account_table .id == oauth_account .id ,
130- values = {** oauth_account .dict (), "user_id" : user .id },
131- )
132- await self .session .execute (statement )
115+ await self .session .refresh (user )
116+ return user
117+
118+ async def update (
119+ self , user : UP_SQLALCHEMY , update_dict : Dict [str , Any ]
120+ ) -> UP_SQLALCHEMY :
121+ for key , value in update_dict .items ():
122+ setattr (user , key , value )
123+ self .session .add (user )
124+ await self .session .commit ()
125+ await self .session .refresh (user )
126+ return user
133127
128+ async def delete (self , user : UP_SQLALCHEMY ) -> None :
129+ await self .session .delete (user )
134130 await self .session .commit ()
135131
136- return await self .get (user .id )
132+ async def add_oauth_account (
133+ self , user : UP_SQLALCHEMY , create_dict : Dict [str , Any ]
134+ ) -> UP_SQLALCHEMY :
135+ if self .oauth_account_table is None :
136+ raise NotImplementedError ()
137+
138+ oauth_account = self .oauth_account_table (** create_dict )
139+ self .session .add (oauth_account )
140+ user .oauth_accounts .append (oauth_account ) # type: ignore
141+ self .session .add (user )
137142
138- async def delete (self , user : UD ) -> None :
139- statement = delete (self .user_table , self .user_table .id == user .id )
140- await self .session .execute (statement )
141143 await self .session .commit ()
144+ await self .session .refresh (user )
142145
143- async def _get_user (self , statement : Select ) -> Optional [UD ]:
144- if self .oauth_account_table is not None :
145- statement = statement .options (joinedload ("oauth_accounts" ))
146+ return user
147+
148+ async def update_oauth_account (
149+ self , user : UP_SQLALCHEMY , oauth_account : OAP , update_dict : Dict [str , Any ]
150+ ) -> UP_SQLALCHEMY :
151+ if self .oauth_account_table is None :
152+ raise NotImplementedError ()
153+
154+ for key , value in update_dict .items ():
155+ setattr (oauth_account , key , value )
156+ self .session .add (oauth_account )
157+ await self .session .commit ()
158+ await self .session .refresh (user )
159+ return user
146160
161+ async def _get_user (self , statement : Select ) -> Optional [UP ]:
147162 results = await self .session .execute (statement )
148163 user = results .first ()
149164 if user is None :
150165 return None
151166
152- return self . user_db_model . from_orm ( user [0 ])
167+ return user [0 ]
0 commit comments