@@ -85,7 +85,7 @@ class DynamoDBBaseOAuthAccountTable(Generic[ID]):
8585 )
8686
8787
88- class DynamoDBBaseOAuthAccountTableUUID (DynamoDBBaseUserTable [UUID_ID ]):
88+ class DynamoDBBaseOAuthAccountTableUUID (DynamoDBBaseOAuthAccountTable [UUID_ID ]):
8989 id : UUID_ID = Field (
9090 default_factory = uuid .uuid4 , description = "The ID for the OAuth account"
9191 )
@@ -159,6 +159,20 @@ async def _table(self, table_name: str, region: str | None = None):
159159 table = await dynamodb .Table (table_name )
160160 yield table
161161
162+ def _serialize_for_dynamodb (self , data : dict [str , Any ]) -> dict [str , Any ]:
163+ """Convert UUIDs and other incompatible types for DynamoDB."""
164+ result = {}
165+ for key , value in data .items ():
166+ if isinstance (value , uuid .UUID ):
167+ result [key ] = str (value )
168+ elif isinstance (value , list ):
169+ result [key ] = [str (v ) if isinstance (v , uuid .UUID ) else v for v in value ]
170+ elif isinstance (value , dict ):
171+ result [key ] = self ._serialize_for_dynamodb (value )
172+ else :
173+ result [key ] = value
174+ return result
175+
162176 def _ensure_id_str (self , value : Any ) -> str :
163177 """Normalize id to string for DynamoDB keys."""
164178 return str (value )
@@ -212,12 +226,30 @@ def _ensure_email_lower(self, data: dict[str, Any]) -> None:
212226 data ["email" ] = data ["email" ].lower ()
213227
214228 async def get (self , id : ID | str ) -> UP | None :
215- """Get a user by id."""
229+ """Get a user by id and hydrate oauth_accounts if available ."""
216230 id_str = self ._ensure_id_str (id )
231+
217232 async with self ._table (self .user_table_name , self ._resource_region ) as table :
218233 resp = await table .get_item (Key = {self .primary_key : id_str })
219234 item = resp .get ("Item" )
220- return self ._item_to_user (item )
235+ user = self ._item_to_user (item )
236+
237+ if user is None :
238+ return None
239+
240+ if self .oauth_account_table and self .oauth_account_table_name :
241+ async with self ._table (
242+ self .oauth_account_table_name , self ._resource_region
243+ ) as oauth_table :
244+ resp = await oauth_table .scan (
245+ FilterExpression = Attr ("user_id" ).eq (id_str )
246+ )
247+ accounts = resp .get ("Items" , [])
248+ user .oauth_accounts = [ # type: ignore
249+ self .oauth_account_table (** acc ) for acc in accounts
250+ ]
251+
252+ return user
221253
222254 async def get_by_email (self , email : str ) -> UP | None :
223255 """Get a user by email (case-insensitive: emails are stored lowercased)."""
@@ -230,7 +262,25 @@ async def get_by_email(self, email: str) -> UP | None:
230262 items = resp .get ("Items" , [])
231263 if not items :
232264 return None
233- return self ._item_to_user (items [0 ])
265+ user = self ._item_to_user (items [0 ])
266+
267+ if user is None :
268+ return None
269+
270+ user_id = self ._ensure_id_str (user .id )
271+ if self .oauth_account_table and self .oauth_account_table_name :
272+ async with self ._table (
273+ self .oauth_account_table_name , self ._resource_region
274+ ) as oauth_table :
275+ resp = await oauth_table .scan (
276+ FilterExpression = Attr ("user_id" ).eq (user_id )
277+ )
278+ accounts = resp .get ("Items" , [])
279+ user .oauth_accounts = [ # type: ignore
280+ self .oauth_account_table (** acc ) for acc in accounts
281+ ]
282+
283+ return user
234284
235285 async def get_by_oauth_account (self , oauth : str , account_id : str ) -> UP | None :
236286 """Find a user by oauth provider and provider account id."""
@@ -250,7 +300,7 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> UP | None:
250300 return None
251301
252302 user_id = items [0 ].get ("user_id" )
253- if user_id is None :
303+ if not user_id :
254304 return None
255305
256306 return await self .get (user_id )
@@ -268,7 +318,7 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
268318 async with self ._table (self .user_table_name , self ._resource_region ) as table :
269319 try :
270320 await table .put_item (
271- Item = item ,
321+ Item = self . _serialize_for_dynamodb ( item ) ,
272322 ConditionExpression = "attribute_not_exists(#id)" ,
273323 ExpressionAttributeNames = {"#id" : self .primary_key },
274324 )
@@ -301,7 +351,7 @@ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
301351
302352 try :
303353 await table .put_item (
304- Item = merged ,
354+ Item = self . _serialize_for_dynamodb ( merged ) ,
305355 ConditionExpression = "attribute_exists(#id)" ,
306356 ExpressionAttributeNames = {"#id" : self .primary_key },
307357 )
@@ -351,14 +401,10 @@ async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
351401 async with self ._table (
352402 self .oauth_account_table_name , self ._resource_region
353403 ) as oauth_table :
354- await oauth_table .put_item (Item = oauth_item )
404+ await oauth_table .put_item (Item = self . _serialize_for_dynamodb ( oauth_item ) )
355405
356406 if hasattr (user , "oauth_accounts" ):
357- oauth_obj = (
358- self .oauth_account_table (** oauth_item )
359- if self .oauth_account_table is not None
360- else oauth_item
361- )
407+ oauth_obj = self .oauth_account_table (** oauth_item )
362408 user .oauth_accounts .append (oauth_obj ) # type: ignore
363409
364410 return user
@@ -368,38 +414,28 @@ async def update_oauth_account(
368414 user : UP ,
369415 oauth_account : OAP , # type: ignore
370416 update_dict : dict [str , Any ],
371- ) -> UP :
417+ ) -> UP | None :
372418 """Update an OAuth account and return the refreshed user (UP)."""
373419 if self .oauth_account_table is None or self .oauth_account_table_name is None :
374420 raise NotImplementedError ()
375421
376- oauth_id = None
377- if isinstance (oauth_account , dict ):
378- oauth_id = oauth_account .get ("id" )
379- elif hasattr (oauth_account , "model_dump" ) and callable (
380- getattr (oauth_account , "model_dump" )
381- ):
382- try :
383- oauth_id = oauth_account .model_dump ().get ("id" ) # type: ignore
384- except Exception :
385- oauth_id = getattr (oauth_account , "id" , None )
386- elif hasattr (oauth_account , "id" ):
387- oauth_id = getattr (oauth_account , "id" , None )
388- elif hasattr (oauth_account , "__dict__" ):
389- oauth_id = vars (oauth_account ).get ("id" )
422+ oauth_item = (
423+ oauth_account .model_dump () # type: ignore
424+ if hasattr (oauth_account , "model_dump" )
425+ else vars (oauth_account )
426+ )
390427
391- if oauth_id is None :
392- raise ValueError ("oauth_account has no 'id' field" )
428+ updated_item = {** oauth_item , ** update_dict }
393429
394- oauth_id_str = self ._ensure_id_str (oauth_id )
430+ for field in ("id" , "user_id" , "oauth_name" , "account_id" ):
431+ updated_item [field ] = getattr (oauth_account , field , oauth_item .get (field ))
395432
396433 async with self ._table (
397434 self .oauth_account_table_name , self ._resource_region
398435 ) as oauth_table :
399- merged = {** update_dict }
400436 try :
401437 await oauth_table .put_item (
402- Item = merged ,
438+ Item = self . _serialize_for_dynamodb ( updated_item ) ,
403439 ConditionExpression = "attribute_exists(#id)" ,
404440 ExpressionAttributeNames = {"#id" : self .primary_key },
405441 )
@@ -408,20 +444,15 @@ async def update_oauth_account(
408444 e .response .get ("Error" , {}).get ("Code" )
409445 == "ConditionalCheckFailedException"
410446 ):
411- raise ValueError (f"User { oauth_id_str } already exists." )
447+ raise ValueError (
448+ f"OAuth account with ID { updated_item ['id' ]} does not exist."
449+ )
412450 raise
413451
414452 if hasattr (user , "oauth_accounts" ):
415- existing_oauth = next (
416- (
417- account
418- for account in user .oauth_accounts # type: ignore
419- if account .id == oauth_id_str
420- ),
421- None ,
422- )
423- if existing_oauth :
424- index = user .oauth_accounts .index (existing_oauth ) # type: ignore
425- user .oauth_accounts [index ] = merged # type: ignore
453+ for idx , account in enumerate (user .oauth_accounts ): # type: ignore
454+ if str (getattr (account , "id" , None )) == str (updated_item ["id" ]):
455+ user .oauth_accounts [idx ] = type (oauth_account )(** updated_item ) # type: ignore
456+ break
426457
427- return user
458+ return await self . get ( user . id )
0 commit comments