@@ -90,28 +90,31 @@ class DynamoDBUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
9090 user_table : type [UP ]
9191 oauth_account_table : type [DynamoDBBaseOAuthAccountTable ] | None
9292 user_table_name : str
93- oauth_table_name : str | None
93+ oauth_account_table_name : str | None
9494 _resource : Any | None
95+ _resource_region : str | None
9596
9697 def __init__ (
9798 self ,
9899 session : aioboto3 .Session ,
99100 user_table : type [UP ],
100101 user_table_name : str ,
101102 oauth_account_table : type [DynamoDBBaseOAuthAccountTable ] | None = None ,
102- oauth_table_name : str | None = None ,
103+ oauth_account_table_name : str | None = None ,
103104 dynamodb_resource : Any | None = None ,
105+ dynamodb_resource_region : str | None = None ,
104106 ):
105107 self .session = session
106108 self .user_table = user_table
107109 self .oauth_account_table = oauth_account_table
108110 self .user_table_name = user_table_name
109- self .oauth_table_name = oauth_table_name
111+ self .oauth_account_table_name = oauth_account_table_name
110112
111113 self ._resource = dynamodb_resource
114+ self ._resource_region = dynamodb_resource_region
112115
113116 @asynccontextmanager
114- async def _table (self , table_name : str ):
117+ async def _table (self , table_name : str , region : str | None = None ):
115118 """Async context manager that yields a Table object.
116119
117120 If a long-lived resource was provided at init, it's reused (no enter/exit).
@@ -121,7 +124,13 @@ async def _table(self, table_name: str):
121124 table = await self ._resource .Table (table_name )
122125 yield table
123126 else :
124- async with self .session .resource ("dynamodb" ) as dynamodb :
127+ if region is None :
128+ raise ValueError (
129+ "Parameter `region` must be specified when `dynamodb_resource` is omitted"
130+ )
131+ async with self .session .resource (
132+ "dynamodb" , region_name = region
133+ ) as dynamodb :
125134 table = await dynamodb .Table (table_name )
126135 yield table
127136
@@ -178,15 +187,15 @@ def _ensure_email_lower(self, data: dict[str, Any]) -> None:
178187 async def get (self , id : ID | str ) -> UP | None :
179188 """Get a user by id."""
180189 id_str = self ._ensure_id_str (id )
181- async with self ._table (self .user_table_name ) as table :
190+ async with self ._table (self .user_table_name , self . _resource_region ) as table :
182191 resp = await table .get_item (Key = {"id" : id_str })
183192 item = resp .get ("Item" )
184193 return self ._item_to_user (item )
185194
186195 async def get_by_email (self , email : str ) -> UP | None :
187196 """Get a user by email (case-insensitive: emails are stored lowercased)."""
188197 email_norm = email .lower ()
189- async with self ._table (self .user_table_name ) as table :
198+ async with self ._table (self .user_table_name , self . _resource_region ) as table :
190199 resp = await table .scan (
191200 FilterExpression = Attr ("email" ).eq (email_norm ),
192201 Limit = 1 ,
@@ -198,10 +207,12 @@ async def get_by_email(self, email: str) -> UP | None:
198207
199208 async def get_by_oauth_account (self , oauth : str , account_id : str ) -> UP | None :
200209 """Find a user by oauth provider and provider account id."""
201- if self .oauth_account_table is None or self .oauth_table_name is None :
210+ if self .oauth_account_table is None or self .oauth_account_table_name is None :
202211 raise NotImplementedError ()
203212
204- async with self ._table (self .oauth_table_name ) as oauth_table :
213+ async with self ._table (
214+ self .oauth_account_table_name , self ._resource_region
215+ ) as oauth_table :
205216 resp = await oauth_table .scan (
206217 FilterExpression = Attr ("oauth_name" ).eq (oauth )
207218 & Attr ("account_id" ).eq (account_id ),
@@ -227,7 +238,7 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
227238
228239 self ._ensure_email_lower (item )
229240
230- async with self ._table (self .user_table_name ) as table :
241+ async with self ._table (self .user_table_name , self . _resource_region ) as table :
231242 await table .put_item (Item = item )
232243
233244 resp = await table .get_item (Key = {"id" : item ["id" ]})
@@ -241,7 +252,7 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
241252 async def update (self , user : UP , update_dict : dict [str , Any ]) -> UP :
242253 """Update a user with update_dict and return the updated UP instance."""
243254 user_id = self ._extract_id_from_user (user )
244- async with self ._table (self .user_table_name ) as table :
255+ async with self ._table (self .user_table_name , self . _resource_region ) as table :
245256 resp = await table .get_item (Key = {"id" : user_id })
246257 current = resp .get ("Item" , {})
247258 if not current :
@@ -263,12 +274,12 @@ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
263274 async def delete (self , user : UP ) -> None :
264275 """Delete a user."""
265276 user_id = self ._extract_id_from_user (user )
266- async with self ._table (self .user_table_name ) as table :
277+ async with self ._table (self .user_table_name , self . _resource_region ) as table :
267278 await table .delete_item (Key = {"id" : user_id })
268279
269280 async def add_oauth_account (self , user : UP , create_dict : dict [str , Any ]) -> UP :
270281 """Add an OAuth account for `user`. Returns the refreshed user (UP)."""
271- if self .oauth_account_table is None or self .oauth_table_name is None :
282+ if self .oauth_account_table is None or self .oauth_account_table_name is None :
272283 raise NotImplementedError ()
273284
274285 oauth_item = dict (create_dict )
@@ -278,7 +289,9 @@ async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
278289 user_id = self ._extract_id_from_user (user )
279290 oauth_item ["user_id" ] = user_id
280291
281- async with self ._table (self .oauth_table_name ) as oauth_table :
292+ async with self ._table (
293+ self .oauth_account_table_name , self ._resource_region
294+ ) as oauth_table :
282295 await oauth_table .put_item (Item = oauth_item )
283296
284297 refreshed_user = await self .get (user_id )
@@ -305,7 +318,7 @@ async def update_oauth_account(
305318 update_dict : dict [str , Any ],
306319 ) -> UP :
307320 """Update an OAuth account and return the refreshed user (UP)."""
308- if self .oauth_account_table is None or self .oauth_table_name is None :
321+ if self .oauth_account_table is None or self .oauth_account_table_name is None :
309322 raise NotImplementedError ()
310323
311324 oauth_id = None
@@ -328,7 +341,9 @@ async def update_oauth_account(
328341
329342 oauth_id_str = self ._ensure_id_str (oauth_id )
330343
331- async with self ._table (self .oauth_table_name ) as oauth_table :
344+ async with self ._table (
345+ self .oauth_account_table_name , self ._resource_region
346+ ) as oauth_table :
332347 resp = await oauth_table .get_item (Key = {"id" : oauth_id_str })
333348 current = resp .get ("Item" , {})
334349 if not current :
0 commit comments