@@ -62,71 +62,71 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
6262
6363 :param user_db_model: Pydantic model of a DB representation of a user.
6464 :param session: SQLAlchemy session instance.
65- :param user_model : SQLAlchemy user model.
66- :param oauth_account_model : Optional SQLAlchemy OAuth accounts model.
65+ :param user_table : SQLAlchemy user model.
66+ :param oauth_account_table : Optional SQLAlchemy OAuth accounts model.
6767 """
6868
6969 session : AsyncSession
70- user_model : Type [SQLAlchemyBaseUserTable ]
71- oauth_account_model : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]]
70+ user_table : Type [SQLAlchemyBaseUserTable ]
71+ oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]]
7272
7373 def __init__ (
7474 self ,
7575 user_db_model : Type [UD ],
7676 session : AsyncSession ,
77- user_model : Type [SQLAlchemyBaseUserTable ],
78- oauth_account_model : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]] = None ,
77+ user_table : Type [SQLAlchemyBaseUserTable ],
78+ oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]] = None ,
7979 ):
8080 super ().__init__ (user_db_model )
8181 self .session = session
82- self .user_model = user_model
83- self .oauth_account_model = oauth_account_model
82+ self .user_table = user_table
83+ self .oauth_account_table = oauth_account_table
8484
8585 async def get (self , id : UUID4 ) -> Optional [UD ]:
86- statement = select (self .user_model ).where (self .user_model .id == id )
86+ statement = select (self .user_table ).where (self .user_table .id == id )
8787 return await self ._get_user (statement )
8888
8989 async def get_by_email (self , email : str ) -> Optional [UD ]:
90- statement = select (self .user_model ).where (
91- func .lower (self .user_model .email ) == func .lower (email )
90+ statement = select (self .user_table ).where (
91+ func .lower (self .user_table .email ) == func .lower (email )
9292 )
9393 return await self ._get_user (statement )
9494
9595 async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UD ]:
96- if self .oauth_account_model is not None :
96+ if self .oauth_account_table is not None :
9797 statement = (
98- select (self .user_model )
99- .join (self .oauth_account_model )
100- .where (self .oauth_account_model .oauth_name == oauth )
101- .where (self .oauth_account_model .account_id == account_id )
98+ select (self .user_table )
99+ .join (self .oauth_account_table )
100+ .where (self .oauth_account_table .oauth_name == oauth )
101+ .where (self .oauth_account_table .account_id == account_id )
102102 )
103103 return await self ._get_user (statement )
104104
105105 async def create (self , user : UD ) -> UD :
106- user_model = self .user_model (** user .dict (exclude = {"oauth_accounts" }))
107- self .session .add (user_model )
106+ user_table = self .user_table (** user .dict (exclude = {"oauth_accounts" }))
107+ self .session .add (user_table )
108108
109- if self .oauth_account_model is not None :
109+ if self .oauth_account_table is not None :
110110 for oauth_account in user .oauth_accounts :
111- oauth_account_model = self .oauth_account_model (
111+ oauth_account_table = self .oauth_account_table (
112112 ** oauth_account .dict (), user_id = user .id
113113 )
114- self .session .add (oauth_account_model )
114+ self .session .add (oauth_account_table )
115115
116116 await self .session .commit ()
117117 return user
118118
119119 async def update (self , user : UD ) -> UD :
120- user_model = await self .session .get (self .user_model , user .id )
120+ user_table = await self .session .get (self .user_table , user .id )
121121 for key , value in user .dict (exclude = {"oauth_accounts" }).items ():
122- setattr (user_model , key , value )
123- self .session .add (user_model )
122+ setattr (user_table , key , value )
123+ self .session .add (user_table )
124124
125- if self .oauth_account_model is not None :
125+ if self .oauth_account_table is not None :
126126 for oauth_account in user .oauth_accounts :
127127 statement = update (
128- self .oauth_account_model ,
129- whereclause = self .oauth_account_model .id == oauth_account .id ,
128+ self .oauth_account_table ,
129+ whereclause = self .oauth_account_table .id == oauth_account .id ,
130130 values = {** oauth_account .dict (), "user_id" : user .id },
131131 )
132132 await self .session .execute (statement )
@@ -136,11 +136,11 @@ async def update(self, user: UD) -> UD:
136136 return user
137137
138138 async def delete (self , user : UD ) -> None :
139- statement = delete (self .user_model , self .user_model .id == user .id )
139+ statement = delete (self .user_table , self .user_table .id == user .id )
140140 await self .session .execute (statement )
141141
142142 async def _get_user (self , statement : Select ) -> Optional [UD ]:
143- if self .oauth_account_model is not None :
143+ if self .oauth_account_table is not None :
144144 statement = statement .options (joinedload ("oauth_accounts" ))
145145
146146 results = await self .session .execute (statement )
0 commit comments