2323
2424import aioboto3
2525from boto3 .dynamodb .conditions import Attr
26+ from botocore .exceptions import ClientError
2627from fastapi_users .db .base import BaseUserDatabase
2728from fastapi_users .models import ID , OAP , UP
2829from pydantic import BaseModel , ConfigDict , Field
2930
3031from fastapi_users_db_dynamodb ._aioboto3_patch import * # noqa: F403
31- from fastapi_users_db_dynamodb .generics import GUID
32+ from fastapi_users_db_dynamodb .generics import UUID_ID
3233
3334__version__ = "1.0.0"
3435
35- UUID_ID = uuid . UUID
36+ DATABASE_USERTABLE_PRIMARY_KEY : str = "id"
3637
3738
3839class DynamoDBBaseUserTable (BaseModel , Generic [ID ]):
@@ -88,7 +89,9 @@ class DynamoDBBaseOAuthAccountTableUUID(DynamoDBBaseUserTable[UUID_ID]):
8889 id : UUID_ID = Field (
8990 default_factory = uuid .uuid4 , description = "The ID for the OAuth account"
9091 )
91- user_id : GUID = Field (..., description = "The user ID this OAuth account belongs to" )
92+ user_id : UUID_ID = Field (
93+ ..., description = "The user ID this OAuth account belongs to"
94+ )
9295
9396
9497class DynamoDBUserDatabase (Generic [UP , ID ], BaseUserDatabase [UP , ID ]):
@@ -109,6 +112,7 @@ class DynamoDBUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
109112 user_table : type [UP ]
110113 oauth_account_table : type [DynamoDBBaseOAuthAccountTable ] | None
111114 user_table_name : str
115+ primary_key : str = DATABASE_USERTABLE_PRIMARY_KEY
112116 oauth_account_table_name : str | None
113117 _resource : Any | None
114118 _resource_region : str | None
@@ -118,6 +122,7 @@ def __init__(
118122 session : aioboto3 .Session ,
119123 user_table : type [UP ],
120124 user_table_name : str ,
125+ primary_key : str = DATABASE_USERTABLE_PRIMARY_KEY ,
121126 oauth_account_table : type [DynamoDBBaseOAuthAccountTable ] | None = None ,
122127 oauth_account_table_name : str | None = None ,
123128 dynamodb_resource : Any | None = None ,
@@ -127,6 +132,7 @@ def __init__(
127132 self .user_table = user_table
128133 self .oauth_account_table = oauth_account_table
129134 self .user_table_name = user_table_name
135+ self .primary_key = primary_key
130136 self .oauth_account_table_name = oauth_account_table_name
131137
132138 self ._resource = dynamodb_resource
@@ -163,9 +169,11 @@ def _extract_id_from_user(self, user_obj: Any) -> str:
163169 if isinstance (user_obj , dict ):
164170 idv = user_obj .get ("id" )
165171
166- elif hasattr (user_obj , "dict" ) and callable (getattr (user_obj , "dict" )):
172+ elif hasattr (user_obj , "model_dump" ) and callable (
173+ getattr (user_obj , "model_dump" )
174+ ):
167175 try :
168- idv = user_obj .dict ().get ("id" )
176+ idv = user_obj .model_dump ().get ("id" )
169177 except Exception :
170178 idv = getattr (user_obj , "id" , None )
171179
@@ -207,7 +215,7 @@ async def get(self, id: ID | str) -> UP | None:
207215 """Get a user by id."""
208216 id_str = self ._ensure_id_str (id )
209217 async with self ._table (self .user_table_name , self ._resource_region ) as table :
210- resp = await table .get_item (Key = {"id" : id_str })
218+ resp = await table .get_item (Key = {self . primary_key : id_str })
211219 item = resp .get ("Item" )
212220 return self ._item_to_user (item )
213221
@@ -258,12 +266,21 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
258266 self ._ensure_email_lower (item )
259267
260268 async with self ._table (self .user_table_name , self ._resource_region ) as table :
261- await table .put_item (Item = item )
262-
263- resp = await table .get_item (Key = {"id" : item ["id" ]})
264- stored = resp .get ("Item" , item )
265-
266- refreshed_user = self ._item_to_user (stored )
269+ try :
270+ await table .put_item (
271+ Item = item ,
272+ ConditionExpression = "attribute_not_exists(#id)" ,
273+ ExpressionAttributeNames = {"#id" : self .primary_key },
274+ )
275+ except ClientError as e :
276+ if (
277+ e .response .get ("Error" , {}).get ("Code" )
278+ == "ConditionalCheckFailedException"
279+ ):
280+ raise ValueError (f"User { item ['id' ]} already exists." )
281+ raise
282+
283+ refreshed_user = self ._item_to_user (item )
267284 if refreshed_user is None :
268285 raise ValueError ("Could not cast DB item to User model" )
269286 return refreshed_user
@@ -272,20 +289,31 @@ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
272289 """Update a user with update_dict and return the updated UP instance."""
273290 user_id = self ._extract_id_from_user (user )
274291 async with self ._table (self .user_table_name , self ._resource_region ) as table :
275- resp = await table .get_item (Key = {"id" : user_id })
276- current = resp .get ("Item" , {})
292+ resp = await table .get_item (Key = {self .primary_key : user_id })
293+ current = resp .get ("Item" , None )
294+
277295 if not current :
278296 raise ValueError ("User not found" )
279297
280298 merged = {** current , ** update_dict }
281299
282300 self ._ensure_email_lower (merged )
283301
284- await table .put_item (Item = merged )
285-
286- refreshed = (await table .get_item (Key = {"id" : user_id })).get ("Item" , merged )
287-
288- refreshed_user = self ._item_to_user (refreshed )
302+ try :
303+ await table .put_item (
304+ Item = merged ,
305+ ConditionExpression = "attribute_exists(#id)" ,
306+ ExpressionAttributeNames = {"#id" : self .primary_key },
307+ )
308+ except ClientError as e :
309+ if (
310+ e .response .get ("Error" , {}).get ("Code" )
311+ == "ConditionalCheckFailedException"
312+ ):
313+ raise ValueError (f"User { user_id } does not exist." )
314+ raise
315+
316+ refreshed_user = self ._item_to_user (merged )
289317 if refreshed_user is None :
290318 raise ValueError ("Could not cast DB item to User model" )
291319 return refreshed_user
@@ -294,7 +322,19 @@ async def delete(self, user: UP) -> None:
294322 """Delete a user."""
295323 user_id = self ._extract_id_from_user (user )
296324 async with self ._table (self .user_table_name , self ._resource_region ) as table :
297- await table .delete_item (Key = {"id" : user_id })
325+ try :
326+ await table .delete_item (
327+ Key = {self .primary_key : user_id },
328+ ConditionExpression = "attribute_exists(#id)" ,
329+ ExpressionAttributeNames = {"#id" : self .primary_key },
330+ )
331+ except ClientError as e :
332+ if (
333+ e .response .get ("Error" , {}).get ("Code" )
334+ == "ConditionalCheckFailedException"
335+ ):
336+ raise ValueError (f"User { user_id } does not exist." )
337+ raise
298338
299339 async def add_oauth_account (self , user : UP , create_dict : dict [str , Any ]) -> UP :
300340 """Add an OAuth account for `user`. Returns the refreshed user (UP)."""
@@ -313,22 +353,15 @@ async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
313353 ) as oauth_table :
314354 await oauth_table .put_item (Item = oauth_item )
315355
316- refreshed_user = await self .get (user_id )
317- if refreshed_user is None :
318- raise ValueError ("Refreshed user is None" )
319-
320- try :
356+ if hasattr (user , "oauth_accounts" ):
321357 oauth_obj = (
322358 self .oauth_account_table (** oauth_item )
323359 if self .oauth_account_table is not None
324360 else oauth_item
325361 )
326- if hasattr (refreshed_user , "oauth_accounts" ):
327- getattr (refreshed_user , "oauth_accounts" ).append (oauth_obj )
328- except Exception :
329- pass
362+ user .oauth_accounts .append (oauth_obj ) # type: ignore
330363
331- return refreshed_user
364+ return user
332365
333366 async def update_oauth_account (
334367 self ,
@@ -343,11 +376,11 @@ async def update_oauth_account(
343376 oauth_id = None
344377 if isinstance (oauth_account , dict ):
345378 oauth_id = oauth_account .get ("id" )
346- elif hasattr (oauth_account , "dict " ) and callable (
347- getattr (oauth_account , "dict " )
379+ elif hasattr (oauth_account , "model_dump " ) and callable (
380+ getattr (oauth_account , "model_dump " )
348381 ):
349382 try :
350- oauth_id = oauth_account .dict ().get ("id" ) # type: ignore
383+ oauth_id = oauth_account .model_dump ().get ("id" ) # type: ignore
351384 except Exception :
352385 oauth_id = getattr (oauth_account , "id" , None )
353386 elif hasattr (oauth_account , "id" ):
@@ -363,16 +396,32 @@ async def update_oauth_account(
363396 async with self ._table (
364397 self .oauth_account_table_name , self ._resource_region
365398 ) as oauth_table :
366- resp = await oauth_table .get_item (Key = {"id" : oauth_id_str })
367- current = resp .get ("Item" , {})
368- if not current :
369- raise ValueError ("OAuth account not found" )
370-
371- merged = {** current , ** update_dict }
372- await oauth_table .put_item (Item = merged )
399+ merged = {** update_dict }
400+ try :
401+ await oauth_table .put_item (
402+ Item = merged ,
403+ ConditionExpression = "attribute_exists(#id)" ,
404+ ExpressionAttributeNames = {"#id" : self .primary_key },
405+ )
406+ except ClientError as e :
407+ if (
408+ e .response .get ("Error" , {}).get ("Code" )
409+ == "ConditionalCheckFailedException"
410+ ):
411+ raise ValueError (f"User { oauth_id_str } already exists." )
412+ raise
413+
414+ if hasattr (user , "oauth_accounts" ):
415+ existing_oauth = next (
416+ (
417+ account
418+ for account in user .oauth_accounts # type: ignore
419+ if account .id == oauth_id_str
420+ ),
421+ None ,
422+ )
423+ if existing_oauth :
424+ index = user .oauth_accounts .index (existing_oauth ) # type: ignore
425+ user .oauth_accounts [index ] = merged # type: ignore
373426
374- user_id = self ._extract_id_from_user (user )
375- refreshed_user = await self .get (user_id )
376- if refreshed_user is None :
377- raise ValueError ("Could not cast DB item to User model" )
378- return refreshed_user
427+ return user
0 commit comments