11"""FastAPI Users database adapter for SQLModel."""
22import uuid
3- from typing import Generic , Optional , Type , TypeVar
3+ from typing import TYPE_CHECKING , Any , Dict , Generic , Optional , Type
44
55from fastapi_users .db .base import BaseUserDatabase
6- from fastapi_users .models import BaseOAuthAccount , BaseUserDB
6+ from fastapi_users .models import ID , OAP , UP
77from pydantic import UUID4 , EmailStr
88from sqlalchemy .ext .asyncio import AsyncSession
99from sqlalchemy .orm import selectinload
1212__version__ = "0.1.2"
1313
1414
15- class SQLModelBaseUserDB (BaseUserDB , SQLModel ):
15+ class SQLModelBaseUserDB (SQLModel ):
1616 __tablename__ = "user"
1717
1818 id : UUID4 = Field (default_factory = uuid .uuid4 , primary_key = True , nullable = False )
19- email : EmailStr = Field (
20- sa_column_kwargs = {"unique" : True , "index" : True }, nullable = False
21- )
19+ if TYPE_CHECKING : # pragma: no cover
20+ email : str
21+ else :
22+ email : EmailStr = Field (
23+ sa_column_kwargs = {"unique" : True , "index" : True }, nullable = False
24+ )
25+ hashed_password : str
2226
2327 is_active : bool = Field (True , nullable = False )
2428 is_superuser : bool = Field (False , nullable = False )
@@ -28,68 +32,59 @@ class Config:
2832 orm_mode = True
2933
3034
31- class SQLModelBaseOAuthAccount (BaseOAuthAccount , SQLModel ):
35+ class SQLModelBaseOAuthAccount (SQLModel ):
3236 __tablename__ = "oauthaccount"
3337
3438 id : UUID4 = Field (default_factory = uuid .uuid4 , primary_key = True )
3539 user_id : UUID4 = Field (foreign_key = "user.id" , nullable = False )
40+ oauth_name : str = Field (index = True , nullable = False )
41+ access_token : str = Field (nullable = False )
42+ expires_at : Optional [int ] = Field (nullable = True )
43+ refresh_token : Optional [str ] = Field (nullable = True )
44+ account_id : str = Field (index = True , nullable = False )
45+ account_email : str = Field (nullable = False )
3646
3747 class Config :
3848 orm_mode = True
3949
4050
41- UD = TypeVar ("UD" , bound = SQLModelBaseUserDB )
42- OA = TypeVar ("OA" , bound = SQLModelBaseOAuthAccount )
43-
44-
45- class NotSetOAuthAccountTableError (Exception ):
46- """
47- OAuth table was not set in DB adapter but was needed.
48-
49- Raised when trying to create/update a user with OAuth accounts set
50- but no table were specified in the DB adapter.
51- """
52-
53- pass
54-
55-
56- class SQLModelUserDatabase (Generic [UD , OA ], BaseUserDatabase [UD ]):
51+ class SQLModelUserDatabase (Generic [UP , ID ], BaseUserDatabase [UP , ID ]):
5752 """
5853 Database adapter for SQLModel.
5954
60- :param user_db_model: SQLModel model of a DB representation of a user.
6155 :param session: SQLAlchemy session.
6256 """
6357
6458 session : Session
65- oauth_account_model : Optional [Type [OA ]]
59+ user_model : Type [UP ]
60+ oauth_account_model : Optional [Type [SQLModelBaseOAuthAccount ]]
6661
6762 def __init__ (
6863 self ,
69- user_db_model : Type [UD ],
7064 session : Session ,
71- oauth_account_model : Optional [Type [OA ]] = None ,
65+ user_model : Type [UP ],
66+ oauth_account_model : Optional [Type [SQLModelBaseOAuthAccount ]] = None ,
7267 ):
73- super ().__init__ (user_db_model )
7468 self .session = session
69+ self .user_model = user_model
7570 self .oauth_account_model = oauth_account_model
7671
77- async def get (self , id : UUID4 ) -> Optional [UD ]:
72+ async def get (self , id : ID ) -> Optional [UP ]:
7873 """Get a single user by id."""
79- return self .session .get (self .user_db_model , id )
74+ return self .session .get (self .user_model , id )
8075
81- async def get_by_email (self , email : str ) -> Optional [UD ]:
76+ async def get_by_email (self , email : str ) -> Optional [UP ]:
8277 """Get a single user by email."""
83- statement = select (self .user_db_model ).where (
84- func .lower (self .user_db_model .email ) == func .lower (email )
78+ statement = select (self .user_model ).where (
79+ func .lower (self .user_model .email ) == func .lower (email )
8580 )
8681 results = self .session .exec (statement )
8782 return results .first ()
8883
89- async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
84+ async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UP ]:
9085 """Get a single user by OAuth account id."""
91- if not self .oauth_account_model :
92- raise NotSetOAuthAccountTableError ()
86+ if self .oauth_account_model is None :
87+ raise NotImplementedError ()
9388 statement = (
9489 select (self .oauth_account_model )
9590 .where (self .oauth_account_model .oauth_name == oauth )
@@ -102,72 +97,93 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD
10297 return user
10398 return None
10499
105- async def create (self , user : UD ) -> UD :
100+ async def create (self , create_dict : Dict [ str , Any ] ) -> UP :
106101 """Create a user."""
102+ user = self .user_model (** create_dict )
107103 self .session .add (user )
108- if self .oauth_account_model is not None :
109- for oauth_account in user .oauth_accounts : # type: ignore
110- self .session .add (oauth_account )
111104 self .session .commit ()
112105 self .session .refresh (user )
113106 return user
114107
115- async def update (self , user : UD ) -> UD :
116- """Update a user."""
108+ async def update (self , user : UP , update_dict : Dict [str , Any ]) -> UP :
109+ for key , value in update_dict .items ():
110+ setattr (user , key , value )
117111 self .session .add (user )
118- if self .oauth_account_model is not None :
119- for oauth_account in user .oauth_accounts : # type: ignore
120- self .session .add (oauth_account )
121112 self .session .commit ()
122113 self .session .refresh (user )
123114 return user
124115
125- async def delete (self , user : UD ) -> None :
126- """Delete a user."""
116+ async def delete (self , user : UP ) -> None :
127117 self .session .delete (user )
128118 self .session .commit ()
129119
120+ async def add_oauth_account (self , user : UP , create_dict : Dict [str , Any ]) -> UP :
121+ if self .oauth_account_model is None :
122+ raise NotImplementedError ()
130123
131- class SQLModelUserDatabaseAsync (Generic [UD , OA ], BaseUserDatabase [UD ]):
124+ oauth_account = self .oauth_account_model (** create_dict )
125+ user .oauth_accounts .append (oauth_account ) # type: ignore
126+ self .session .add (user )
127+
128+ self .session .commit ()
129+
130+ return user
131+
132+ async def update_oauth_account (
133+ self , user : UP , oauth_account : OAP , update_dict : Dict [str , Any ]
134+ ) -> UP :
135+ if self .oauth_account_model is None :
136+ raise NotImplementedError ()
137+
138+ for key , value in update_dict .items ():
139+ setattr (oauth_account , key , value )
140+ self .session .add (oauth_account )
141+ self .session .commit ()
142+
143+ return user
144+
145+
146+ class SQLModelUserDatabaseAsync (Generic [UP , ID ], BaseUserDatabase [UP , ID ]):
132147 """
133148 Database adapter for SQLModel working purely asynchronously.
134149
135- :param user_db_model : SQLModel model of a DB representation of a user.
150+ :param user_model : SQLModel model of a DB representation of a user.
136151 :param session: SQLAlchemy async session.
137152 """
138153
139154 session : AsyncSession
140- oauth_account_model : Optional [Type [OA ]]
155+ user_model : Type [UP ]
156+ oauth_account_model : Optional [Type [SQLModelBaseOAuthAccount ]]
141157
142158 def __init__ (
143159 self ,
144- user_db_model : Type [UD ],
145160 session : AsyncSession ,
146- oauth_account_model : Optional [Type [OA ]] = None ,
161+ user_model : Type [UP ],
162+ oauth_account_model : Optional [Type [SQLModelBaseOAuthAccount ]] = None ,
147163 ):
148- super ().__init__ (user_db_model )
149164 self .session = session
165+ self .user_model = user_model
150166 self .oauth_account_model = oauth_account_model
151167
152- async def get (self , id : UUID4 ) -> Optional [UD ]:
168+ async def get (self , id : ID ) -> Optional [UP ]:
153169 """Get a single user by id."""
154- return await self .session .get (self .user_db_model , id )
170+ return await self .session .get (self .user_model , id )
155171
156- async def get_by_email (self , email : str ) -> Optional [UD ]:
172+ async def get_by_email (self , email : str ) -> Optional [UP ]:
157173 """Get a single user by email."""
158- statement = select (self .user_db_model ).where (
159- func .lower (self .user_db_model .email ) == func .lower (email )
174+ statement = select (self .user_model ).where (
175+ func .lower (self .user_model .email ) == func .lower (email )
160176 )
161177 results = await self .session .execute (statement )
162178 object = results .first ()
163179 if object is None :
164180 return None
165181 return object [0 ]
166182
167- async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
183+ async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UP ]:
168184 """Get a single user by OAuth account id."""
169- if not self .oauth_account_model :
170- raise NotSetOAuthAccountTableError ()
185+ if self .oauth_account_model is None :
186+ raise NotImplementedError ()
171187 statement = (
172188 select (self .oauth_account_model )
173189 .where (self .oauth_account_model .oauth_name == oauth )
@@ -177,31 +193,51 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD
177193 results = await self .session .execute (statement )
178194 oauth_account = results .first ()
179195 if oauth_account :
180- user = oauth_account [0 ].user
196+ user = oauth_account [0 ].user # type: ignore
181197 return user
182198 return None
183199
184- async def create (self , user : UD ) -> UD :
200+ async def create (self , create_dict : Dict [ str , Any ] ) -> UP :
185201 """Create a user."""
202+ user = self .user_model (** create_dict )
186203 self .session .add (user )
187- if self .oauth_account_model is not None :
188- for oauth_account in user .oauth_accounts : # type: ignore
189- self .session .add (oauth_account )
190204 await self .session .commit ()
191205 await self .session .refresh (user )
192206 return user
193207
194- async def update (self , user : UD ) -> UD :
195- """Update a user."""
208+ async def update (self , user : UP , update_dict : Dict [str , Any ]) -> UP :
209+ for key , value in update_dict .items ():
210+ setattr (user , key , value )
196211 self .session .add (user )
197- if self .oauth_account_model is not None :
198- for oauth_account in user .oauth_accounts : # type: ignore
199- self .session .add (oauth_account )
200212 await self .session .commit ()
201213 await self .session .refresh (user )
202214 return user
203215
204- async def delete (self , user : UD ) -> None :
205- """Delete a user."""
216+ async def delete (self , user : UP ) -> None :
206217 await self .session .delete (user )
207218 await self .session .commit ()
219+
220+ async def add_oauth_account (self , user : UP , create_dict : Dict [str , Any ]) -> UP :
221+ if self .oauth_account_model is None :
222+ raise NotImplementedError ()
223+
224+ oauth_account = self .oauth_account_model (** create_dict )
225+ user .oauth_accounts .append (oauth_account ) # type: ignore
226+ self .session .add (user )
227+
228+ await self .session .commit ()
229+
230+ return user
231+
232+ async def update_oauth_account (
233+ self , user : UP , oauth_account : OAP , update_dict : Dict [str , Any ]
234+ ) -> UP :
235+ if self .oauth_account_model is None :
236+ raise NotImplementedError ()
237+
238+ for key , value in update_dict .items ():
239+ setattr (oauth_account , key , value )
240+ self .session .add (oauth_account )
241+ await self .session .commit ()
242+
243+ return user
0 commit comments