Skip to content

Commit 969ab9f

Browse files
committed
Fix last test
1 parent 599787e commit 969ab9f

File tree

1 file changed

+77
-46
lines changed

1 file changed

+77
-46
lines changed

fastapi_users_db_dynamodb/__init__.py

Lines changed: 77 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class DynamoDBBaseOAuthAccountTable(Generic[ID]):
8585
)
8686

8787

88-
class DynamoDBBaseOAuthAccountTableUUID(DynamoDBBaseUserTable[UUID_ID]):
88+
class DynamoDBBaseOAuthAccountTableUUID(DynamoDBBaseOAuthAccountTable[UUID_ID]):
8989
id: UUID_ID = Field(
9090
default_factory=uuid.uuid4, description="The ID for the OAuth account"
9191
)
@@ -159,6 +159,20 @@ async def _table(self, table_name: str, region: str | None = None):
159159
table = await dynamodb.Table(table_name)
160160
yield table
161161

162+
def _serialize_for_dynamodb(self, data: dict[str, Any]) -> dict[str, Any]:
163+
"""Convert UUIDs and other incompatible types for DynamoDB."""
164+
result = {}
165+
for key, value in data.items():
166+
if isinstance(value, uuid.UUID):
167+
result[key] = str(value)
168+
elif isinstance(value, list):
169+
result[key] = [str(v) if isinstance(v, uuid.UUID) else v for v in value]
170+
elif isinstance(value, dict):
171+
result[key] = self._serialize_for_dynamodb(value)
172+
else:
173+
result[key] = value
174+
return result
175+
162176
def _ensure_id_str(self, value: Any) -> str:
163177
"""Normalize id to string for DynamoDB keys."""
164178
return str(value)
@@ -212,12 +226,30 @@ def _ensure_email_lower(self, data: dict[str, Any]) -> None:
212226
data["email"] = data["email"].lower()
213227

214228
async def get(self, id: ID | str) -> UP | None:
215-
"""Get a user by id."""
229+
"""Get a user by id and hydrate oauth_accounts if available."""
216230
id_str = self._ensure_id_str(id)
231+
217232
async with self._table(self.user_table_name, self._resource_region) as table:
218233
resp = await table.get_item(Key={self.primary_key: id_str})
219234
item = resp.get("Item")
220-
return self._item_to_user(item)
235+
user = self._item_to_user(item)
236+
237+
if user is None:
238+
return None
239+
240+
if self.oauth_account_table and self.oauth_account_table_name:
241+
async with self._table(
242+
self.oauth_account_table_name, self._resource_region
243+
) as oauth_table:
244+
resp = await oauth_table.scan(
245+
FilterExpression=Attr("user_id").eq(id_str)
246+
)
247+
accounts = resp.get("Items", [])
248+
user.oauth_accounts = [ # type: ignore
249+
self.oauth_account_table(**acc) for acc in accounts
250+
]
251+
252+
return user
221253

222254
async def get_by_email(self, email: str) -> UP | None:
223255
"""Get a user by email (case-insensitive: emails are stored lowercased)."""
@@ -230,7 +262,25 @@ async def get_by_email(self, email: str) -> UP | None:
230262
items = resp.get("Items", [])
231263
if not items:
232264
return None
233-
return self._item_to_user(items[0])
265+
user = self._item_to_user(items[0])
266+
267+
if user is None:
268+
return None
269+
270+
user_id = self._ensure_id_str(user.id)
271+
if self.oauth_account_table and self.oauth_account_table_name:
272+
async with self._table(
273+
self.oauth_account_table_name, self._resource_region
274+
) as oauth_table:
275+
resp = await oauth_table.scan(
276+
FilterExpression=Attr("user_id").eq(user_id)
277+
)
278+
accounts = resp.get("Items", [])
279+
user.oauth_accounts = [ # type: ignore
280+
self.oauth_account_table(**acc) for acc in accounts
281+
]
282+
283+
return user
234284

235285
async def get_by_oauth_account(self, oauth: str, account_id: str) -> UP | None:
236286
"""Find a user by oauth provider and provider account id."""
@@ -250,7 +300,7 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> UP | None:
250300
return None
251301

252302
user_id = items[0].get("user_id")
253-
if user_id is None:
303+
if not user_id:
254304
return None
255305

256306
return await self.get(user_id)
@@ -268,7 +318,7 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
268318
async with self._table(self.user_table_name, self._resource_region) as table:
269319
try:
270320
await table.put_item(
271-
Item=item,
321+
Item=self._serialize_for_dynamodb(item),
272322
ConditionExpression="attribute_not_exists(#id)",
273323
ExpressionAttributeNames={"#id": self.primary_key},
274324
)
@@ -301,7 +351,7 @@ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
301351

302352
try:
303353
await table.put_item(
304-
Item=merged,
354+
Item=self._serialize_for_dynamodb(merged),
305355
ConditionExpression="attribute_exists(#id)",
306356
ExpressionAttributeNames={"#id": self.primary_key},
307357
)
@@ -351,14 +401,10 @@ async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
351401
async with self._table(
352402
self.oauth_account_table_name, self._resource_region
353403
) as oauth_table:
354-
await oauth_table.put_item(Item=oauth_item)
404+
await oauth_table.put_item(Item=self._serialize_for_dynamodb(oauth_item))
355405

356406
if hasattr(user, "oauth_accounts"):
357-
oauth_obj = (
358-
self.oauth_account_table(**oauth_item)
359-
if self.oauth_account_table is not None
360-
else oauth_item
361-
)
407+
oauth_obj = self.oauth_account_table(**oauth_item)
362408
user.oauth_accounts.append(oauth_obj) # type: ignore
363409

364410
return user
@@ -368,38 +414,28 @@ async def update_oauth_account(
368414
user: UP,
369415
oauth_account: OAP, # type: ignore
370416
update_dict: dict[str, Any],
371-
) -> UP:
417+
) -> UP | None:
372418
"""Update an OAuth account and return the refreshed user (UP)."""
373419
if self.oauth_account_table is None or self.oauth_account_table_name is None:
374420
raise NotImplementedError()
375421

376-
oauth_id = None
377-
if isinstance(oauth_account, dict):
378-
oauth_id = oauth_account.get("id")
379-
elif hasattr(oauth_account, "model_dump") and callable(
380-
getattr(oauth_account, "model_dump")
381-
):
382-
try:
383-
oauth_id = oauth_account.model_dump().get("id") # type: ignore
384-
except Exception:
385-
oauth_id = getattr(oauth_account, "id", None)
386-
elif hasattr(oauth_account, "id"):
387-
oauth_id = getattr(oauth_account, "id", None)
388-
elif hasattr(oauth_account, "__dict__"):
389-
oauth_id = vars(oauth_account).get("id")
422+
oauth_item = (
423+
oauth_account.model_dump() # type: ignore
424+
if hasattr(oauth_account, "model_dump")
425+
else vars(oauth_account)
426+
)
390427

391-
if oauth_id is None:
392-
raise ValueError("oauth_account has no 'id' field")
428+
updated_item = {**oauth_item, **update_dict}
393429

394-
oauth_id_str = self._ensure_id_str(oauth_id)
430+
for field in ("id", "user_id", "oauth_name", "account_id"):
431+
updated_item[field] = getattr(oauth_account, field, oauth_item.get(field))
395432

396433
async with self._table(
397434
self.oauth_account_table_name, self._resource_region
398435
) as oauth_table:
399-
merged = {**update_dict}
400436
try:
401437
await oauth_table.put_item(
402-
Item=merged,
438+
Item=self._serialize_for_dynamodb(updated_item),
403439
ConditionExpression="attribute_exists(#id)",
404440
ExpressionAttributeNames={"#id": self.primary_key},
405441
)
@@ -408,20 +444,15 @@ async def update_oauth_account(
408444
e.response.get("Error", {}).get("Code")
409445
== "ConditionalCheckFailedException"
410446
):
411-
raise ValueError(f"User {oauth_id_str} already exists.")
447+
raise ValueError(
448+
f"OAuth account with ID {updated_item['id']} does not exist."
449+
)
412450
raise
413451

414452
if hasattr(user, "oauth_accounts"):
415-
existing_oauth = next(
416-
(
417-
account
418-
for account in user.oauth_accounts # type: ignore
419-
if account.id == oauth_id_str
420-
),
421-
None,
422-
)
423-
if existing_oauth:
424-
index = user.oauth_accounts.index(existing_oauth) # type: ignore
425-
user.oauth_accounts[index] = merged # type: ignore
453+
for idx, account in enumerate(user.oauth_accounts): # type: ignore
454+
if str(getattr(account, "id", None)) == str(updated_item["id"]):
455+
user.oauth_accounts[idx] = type(oauth_account)(**updated_item) # type: ignore
456+
break
426457

427-
return user
458+
return await self.get(user.id)

0 commit comments

Comments
 (0)