Skip to content

Commit aa61197

Browse files
committed
Test improvements
1 parent f26033b commit aa61197

File tree

5 files changed

+220
-102
lines changed

5 files changed

+220
-102
lines changed

fastapi_users_db_dynamodb/__init__.py

Lines changed: 94 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@
2323

2424
import aioboto3
2525
from boto3.dynamodb.conditions import Attr
26+
from botocore.exceptions import ClientError
2627
from fastapi_users.db.base import BaseUserDatabase
2728
from fastapi_users.models import ID, OAP, UP
2829
from pydantic import BaseModel, ConfigDict, Field
2930

3031
from fastapi_users_db_dynamodb._aioboto3_patch import * # noqa: F403
31-
from fastapi_users_db_dynamodb.generics import GUID
32+
from fastapi_users_db_dynamodb.generics import UUID_ID
3233

3334
__version__ = "1.0.0"
3435

35-
UUID_ID = uuid.UUID
36+
DATABASE_USERTABLE_PRIMARY_KEY: str = "id"
3637

3738

3839
class DynamoDBBaseUserTable(BaseModel, Generic[ID]):
@@ -88,7 +89,9 @@ class DynamoDBBaseOAuthAccountTableUUID(DynamoDBBaseUserTable[UUID_ID]):
8889
id: UUID_ID = Field(
8990
default_factory=uuid.uuid4, description="The ID for the OAuth account"
9091
)
91-
user_id: GUID = Field(..., description="The user ID this OAuth account belongs to")
92+
user_id: UUID_ID = Field(
93+
..., description="The user ID this OAuth account belongs to"
94+
)
9295

9396

9497
class DynamoDBUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
@@ -109,6 +112,7 @@ class DynamoDBUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
109112
user_table: type[UP]
110113
oauth_account_table: type[DynamoDBBaseOAuthAccountTable] | None
111114
user_table_name: str
115+
primary_key: str = DATABASE_USERTABLE_PRIMARY_KEY
112116
oauth_account_table_name: str | None
113117
_resource: Any | None
114118
_resource_region: str | None
@@ -118,6 +122,7 @@ def __init__(
118122
session: aioboto3.Session,
119123
user_table: type[UP],
120124
user_table_name: str,
125+
primary_key: str = DATABASE_USERTABLE_PRIMARY_KEY,
121126
oauth_account_table: type[DynamoDBBaseOAuthAccountTable] | None = None,
122127
oauth_account_table_name: str | None = None,
123128
dynamodb_resource: Any | None = None,
@@ -127,6 +132,7 @@ def __init__(
127132
self.user_table = user_table
128133
self.oauth_account_table = oauth_account_table
129134
self.user_table_name = user_table_name
135+
self.primary_key = primary_key
130136
self.oauth_account_table_name = oauth_account_table_name
131137

132138
self._resource = dynamodb_resource
@@ -163,9 +169,11 @@ def _extract_id_from_user(self, user_obj: Any) -> str:
163169
if isinstance(user_obj, dict):
164170
idv = user_obj.get("id")
165171

166-
elif hasattr(user_obj, "dict") and callable(getattr(user_obj, "dict")):
172+
elif hasattr(user_obj, "model_dump") and callable(
173+
getattr(user_obj, "model_dump")
174+
):
167175
try:
168-
idv = user_obj.dict().get("id")
176+
idv = user_obj.model_dump().get("id")
169177
except Exception:
170178
idv = getattr(user_obj, "id", None)
171179

@@ -207,7 +215,7 @@ async def get(self, id: ID | str) -> UP | None:
207215
"""Get a user by id."""
208216
id_str = self._ensure_id_str(id)
209217
async with self._table(self.user_table_name, self._resource_region) as table:
210-
resp = await table.get_item(Key={"id": id_str})
218+
resp = await table.get_item(Key={self.primary_key: id_str})
211219
item = resp.get("Item")
212220
return self._item_to_user(item)
213221

@@ -258,12 +266,21 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
258266
self._ensure_email_lower(item)
259267

260268
async with self._table(self.user_table_name, self._resource_region) as table:
261-
await table.put_item(Item=item)
262-
263-
resp = await table.get_item(Key={"id": item["id"]})
264-
stored = resp.get("Item", item)
265-
266-
refreshed_user = self._item_to_user(stored)
269+
try:
270+
await table.put_item(
271+
Item=item,
272+
ConditionExpression="attribute_not_exists(#id)",
273+
ExpressionAttributeNames={"#id": self.primary_key},
274+
)
275+
except ClientError as e:
276+
if (
277+
e.response.get("Error", {}).get("Code")
278+
== "ConditionalCheckFailedException"
279+
):
280+
raise ValueError(f"User {item['id']} already exists.")
281+
raise
282+
283+
refreshed_user = self._item_to_user(item)
267284
if refreshed_user is None:
268285
raise ValueError("Could not cast DB item to User model")
269286
return refreshed_user
@@ -272,20 +289,31 @@ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
272289
"""Update a user with update_dict and return the updated UP instance."""
273290
user_id = self._extract_id_from_user(user)
274291
async with self._table(self.user_table_name, self._resource_region) as table:
275-
resp = await table.get_item(Key={"id": user_id})
276-
current = resp.get("Item", {})
292+
resp = await table.get_item(Key={self.primary_key: user_id})
293+
current = resp.get("Item", None)
294+
277295
if not current:
278296
raise ValueError("User not found")
279297

280298
merged = {**current, **update_dict}
281299

282300
self._ensure_email_lower(merged)
283301

284-
await table.put_item(Item=merged)
285-
286-
refreshed = (await table.get_item(Key={"id": user_id})).get("Item", merged)
287-
288-
refreshed_user = self._item_to_user(refreshed)
302+
try:
303+
await table.put_item(
304+
Item=merged,
305+
ConditionExpression="attribute_exists(#id)",
306+
ExpressionAttributeNames={"#id": self.primary_key},
307+
)
308+
except ClientError as e:
309+
if (
310+
e.response.get("Error", {}).get("Code")
311+
== "ConditionalCheckFailedException"
312+
):
313+
raise ValueError(f"User {user_id} does not exist.")
314+
raise
315+
316+
refreshed_user = self._item_to_user(merged)
289317
if refreshed_user is None:
290318
raise ValueError("Could not cast DB item to User model")
291319
return refreshed_user
@@ -294,7 +322,19 @@ async def delete(self, user: UP) -> None:
294322
"""Delete a user."""
295323
user_id = self._extract_id_from_user(user)
296324
async with self._table(self.user_table_name, self._resource_region) as table:
297-
await table.delete_item(Key={"id": user_id})
325+
try:
326+
await table.delete_item(
327+
Key={self.primary_key: user_id},
328+
ConditionExpression="attribute_exists(#id)",
329+
ExpressionAttributeNames={"#id": self.primary_key},
330+
)
331+
except ClientError as e:
332+
if (
333+
e.response.get("Error", {}).get("Code")
334+
== "ConditionalCheckFailedException"
335+
):
336+
raise ValueError(f"User {user_id} does not exist.")
337+
raise
298338

299339
async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
300340
"""Add an OAuth account for `user`. Returns the refreshed user (UP)."""
@@ -313,22 +353,15 @@ async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
313353
) as oauth_table:
314354
await oauth_table.put_item(Item=oauth_item)
315355

316-
refreshed_user = await self.get(user_id)
317-
if refreshed_user is None:
318-
raise ValueError("Refreshed user is None")
319-
320-
try:
356+
if hasattr(user, "oauth_accounts"):
321357
oauth_obj = (
322358
self.oauth_account_table(**oauth_item)
323359
if self.oauth_account_table is not None
324360
else oauth_item
325361
)
326-
if hasattr(refreshed_user, "oauth_accounts"):
327-
getattr(refreshed_user, "oauth_accounts").append(oauth_obj)
328-
except Exception:
329-
pass
362+
user.oauth_accounts.append(oauth_obj) # type: ignore
330363

331-
return refreshed_user
364+
return user
332365

333366
async def update_oauth_account(
334367
self,
@@ -343,11 +376,11 @@ async def update_oauth_account(
343376
oauth_id = None
344377
if isinstance(oauth_account, dict):
345378
oauth_id = oauth_account.get("id")
346-
elif hasattr(oauth_account, "dict") and callable(
347-
getattr(oauth_account, "dict")
379+
elif hasattr(oauth_account, "model_dump") and callable(
380+
getattr(oauth_account, "model_dump")
348381
):
349382
try:
350-
oauth_id = oauth_account.dict().get("id") # type: ignore
383+
oauth_id = oauth_account.model_dump().get("id") # type: ignore
351384
except Exception:
352385
oauth_id = getattr(oauth_account, "id", None)
353386
elif hasattr(oauth_account, "id"):
@@ -363,16 +396,32 @@ async def update_oauth_account(
363396
async with self._table(
364397
self.oauth_account_table_name, self._resource_region
365398
) as oauth_table:
366-
resp = await oauth_table.get_item(Key={"id": oauth_id_str})
367-
current = resp.get("Item", {})
368-
if not current:
369-
raise ValueError("OAuth account not found")
370-
371-
merged = {**current, **update_dict}
372-
await oauth_table.put_item(Item=merged)
399+
merged = {**update_dict}
400+
try:
401+
await oauth_table.put_item(
402+
Item=merged,
403+
ConditionExpression="attribute_exists(#id)",
404+
ExpressionAttributeNames={"#id": self.primary_key},
405+
)
406+
except ClientError as e:
407+
if (
408+
e.response.get("Error", {}).get("Code")
409+
== "ConditionalCheckFailedException"
410+
):
411+
raise ValueError(f"User {oauth_id_str} already exists.")
412+
raise
413+
414+
if hasattr(user, "oauth_accounts"):
415+
existing_oauth = next(
416+
(
417+
account
418+
for account in user.oauth_accounts # type: ignore
419+
if account.id == oauth_id_str
420+
),
421+
None,
422+
)
423+
if existing_oauth:
424+
index = user.oauth_accounts.index(existing_oauth) # type: ignore
425+
user.oauth_accounts[index] = merged # type: ignore
373426

374-
user_id = self._extract_id_from_user(user)
375-
refreshed_user = await self.get(user_id)
376-
if refreshed_user is None:
377-
raise ValueError("Could not cast DB item to User model")
378-
return refreshed_user
427+
return user

0 commit comments

Comments
 (0)