Skip to content

Commit f6f6ecd

Browse files
committed
fix region param
1 parent 1741b9a commit f6f6ecd

File tree

5 files changed

+69
-27
lines changed

5 files changed

+69
-27
lines changed

fastapi_users_db_dynamodb/__init__.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,28 +90,31 @@ class DynamoDBUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
9090
user_table: type[UP]
9191
oauth_account_table: type[DynamoDBBaseOAuthAccountTable] | None
9292
user_table_name: str
93-
oauth_table_name: str | None
93+
oauth_account_table_name: str | None
9494
_resource: Any | None
95+
_resource_region: str | None
9596

9697
def __init__(
9798
self,
9899
session: aioboto3.Session,
99100
user_table: type[UP],
100101
user_table_name: str,
101102
oauth_account_table: type[DynamoDBBaseOAuthAccountTable] | None = None,
102-
oauth_table_name: str | None = None,
103+
oauth_account_table_name: str | None = None,
103104
dynamodb_resource: Any | None = None,
105+
dynamodb_resource_region: str | None = None,
104106
):
105107
self.session = session
106108
self.user_table = user_table
107109
self.oauth_account_table = oauth_account_table
108110
self.user_table_name = user_table_name
109-
self.oauth_table_name = oauth_table_name
111+
self.oauth_account_table_name = oauth_account_table_name
110112

111113
self._resource = dynamodb_resource
114+
self._resource_region = dynamodb_resource_region
112115

113116
@asynccontextmanager
114-
async def _table(self, table_name: str):
117+
async def _table(self, table_name: str, region: str | None = None):
115118
"""Async context manager that yields a Table object.
116119
117120
If a long-lived resource was provided at init, it's reused (no enter/exit).
@@ -121,7 +124,13 @@ async def _table(self, table_name: str):
121124
table = await self._resource.Table(table_name)
122125
yield table
123126
else:
124-
async with self.session.resource("dynamodb") as dynamodb:
127+
if region is None:
128+
raise ValueError(
129+
"Parameter `region` must be specified when `dynamodb_resource` is omitted"
130+
)
131+
async with self.session.resource(
132+
"dynamodb", region_name=region
133+
) as dynamodb:
125134
table = await dynamodb.Table(table_name)
126135
yield table
127136

@@ -178,15 +187,15 @@ def _ensure_email_lower(self, data: dict[str, Any]) -> None:
178187
async def get(self, id: ID | str) -> UP | None:
179188
"""Get a user by id."""
180189
id_str = self._ensure_id_str(id)
181-
async with self._table(self.user_table_name) as table:
190+
async with self._table(self.user_table_name, self._resource_region) as table:
182191
resp = await table.get_item(Key={"id": id_str})
183192
item = resp.get("Item")
184193
return self._item_to_user(item)
185194

186195
async def get_by_email(self, email: str) -> UP | None:
187196
"""Get a user by email (case-insensitive: emails are stored lowercased)."""
188197
email_norm = email.lower()
189-
async with self._table(self.user_table_name) as table:
198+
async with self._table(self.user_table_name, self._resource_region) as table:
190199
resp = await table.scan(
191200
FilterExpression=Attr("email").eq(email_norm),
192201
Limit=1,
@@ -198,10 +207,12 @@ async def get_by_email(self, email: str) -> UP | None:
198207

199208
async def get_by_oauth_account(self, oauth: str, account_id: str) -> UP | None:
200209
"""Find a user by oauth provider and provider account id."""
201-
if self.oauth_account_table is None or self.oauth_table_name is None:
210+
if self.oauth_account_table is None or self.oauth_account_table_name is None:
202211
raise NotImplementedError()
203212

204-
async with self._table(self.oauth_table_name) as oauth_table:
213+
async with self._table(
214+
self.oauth_account_table_name, self._resource_region
215+
) as oauth_table:
205216
resp = await oauth_table.scan(
206217
FilterExpression=Attr("oauth_name").eq(oauth)
207218
& Attr("account_id").eq(account_id),
@@ -227,7 +238,7 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
227238

228239
self._ensure_email_lower(item)
229240

230-
async with self._table(self.user_table_name) as table:
241+
async with self._table(self.user_table_name, self._resource_region) as table:
231242
await table.put_item(Item=item)
232243

233244
resp = await table.get_item(Key={"id": item["id"]})
@@ -241,7 +252,7 @@ async def create(self, create_dict: dict[str, Any]) -> UP:
241252
async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
242253
"""Update a user with update_dict and return the updated UP instance."""
243254
user_id = self._extract_id_from_user(user)
244-
async with self._table(self.user_table_name) as table:
255+
async with self._table(self.user_table_name, self._resource_region) as table:
245256
resp = await table.get_item(Key={"id": user_id})
246257
current = resp.get("Item", {})
247258
if not current:
@@ -263,12 +274,12 @@ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
263274
async def delete(self, user: UP) -> None:
264275
"""Delete a user."""
265276
user_id = self._extract_id_from_user(user)
266-
async with self._table(self.user_table_name) as table:
277+
async with self._table(self.user_table_name, self._resource_region) as table:
267278
await table.delete_item(Key={"id": user_id})
268279

269280
async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
270281
"""Add an OAuth account for `user`. Returns the refreshed user (UP)."""
271-
if self.oauth_account_table is None or self.oauth_table_name is None:
282+
if self.oauth_account_table is None or self.oauth_account_table_name is None:
272283
raise NotImplementedError()
273284

274285
oauth_item = dict(create_dict)
@@ -278,7 +289,9 @@ async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP:
278289
user_id = self._extract_id_from_user(user)
279290
oauth_item["user_id"] = user_id
280291

281-
async with self._table(self.oauth_table_name) as oauth_table:
292+
async with self._table(
293+
self.oauth_account_table_name, self._resource_region
294+
) as oauth_table:
282295
await oauth_table.put_item(Item=oauth_item)
283296

284297
refreshed_user = await self.get(user_id)
@@ -305,7 +318,7 @@ async def update_oauth_account(
305318
update_dict: dict[str, Any],
306319
) -> UP:
307320
"""Update an OAuth account and return the refreshed user (UP)."""
308-
if self.oauth_account_table is None or self.oauth_table_name is None:
321+
if self.oauth_account_table is None or self.oauth_account_table_name is None:
309322
raise NotImplementedError()
310323

311324
oauth_id = None
@@ -328,7 +341,9 @@ async def update_oauth_account(
328341

329342
oauth_id_str = self._ensure_id_str(oauth_id)
330343

331-
async with self._table(self.oauth_table_name) as oauth_table:
344+
async with self._table(
345+
self.oauth_account_table_name, self._resource_region
346+
) as oauth_table:
332347
resp = await oauth_table.get_item(Key={"id": oauth_id_str})
333348
current = resp.get("Item", {})
334349
if not current:

fastapi_users_db_dynamodb/access_token.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import uuid
1010
from contextlib import asynccontextmanager
11-
from datetime import datetime
11+
from datetime import datetime, timezone
1212
from typing import TYPE_CHECKING, Any, Generic, get_type_hints
1313

1414
import aioboto3
@@ -45,27 +45,36 @@ class DynamoDBAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]):
4545
access_token_table: type[AP]
4646
table_name: str
4747
_resource: Any | None
48+
_resource_region: str | None
4849

4950
def __init__(
5051
self,
5152
session: aioboto3.Session,
5253
access_token_table: type[AP],
5354
table_name: str,
5455
dynamodb_resource: Any | None = None,
56+
dynamodb_resource_region: Any | None = None,
5557
):
5658
self.session = session
5759
self.access_token_table = access_token_table
5860
self.table_name = table_name
5961
self._resource = dynamodb_resource
62+
self._resource_region = dynamodb_resource_region
6063

6164
@asynccontextmanager
62-
async def _table(self, table_name: str):
65+
async def _table(self, table_name: str, region: str | None = None):
6366
"""Async context manager that yields a Table object."""
6467
if self._resource is not None:
6568
table = await self._resource.Table(table_name)
6669
yield table
6770
else:
68-
async with self.session.resource("dynamodb") as dynamodb:
71+
if region is None:
72+
raise ValueError(
73+
"Parameter `region` must be specified when `dynamodb_resource` is omitted"
74+
)
75+
async with self.session.resource(
76+
"dynamodb", region_name=region
77+
) as dynamodb:
6978
table = await dynamodb.Table(table_name)
7079
yield table
7180

@@ -98,7 +107,7 @@ async def get_by_token(
98107
self, token: str, max_age: datetime | None = None
99108
) -> AP | None:
100109
"""Retrieve an access token by token string."""
101-
async with self._table(self.table_name) as table:
110+
async with self._table(self.table_name, self._resource_region) as table:
102111
resp = await table.get_item(Key={"token": self._ensure_token(token)})
103112
item = resp.get("Item")
104113

@@ -119,11 +128,11 @@ async def create(self, create_dict: dict[str, Any]) -> AP:
119128
if "token" not in item or item["token"] is None:
120129
item["token"] = uuid.uuid4().hex[:43]
121130
if "created_at" not in item or not isinstance(item["created_at"], str):
122-
item["created_at"] = datetime.utcnow().isoformat()
131+
item["created_at"] = datetime.now(timezone.utc).isoformat()
123132
if isinstance(item.get("user_id"), uuid.UUID):
124133
item["user_id"] = str(item["user_id"])
125134

126-
async with self._table(self.table_name) as table:
135+
async with self._table(self.table_name, self._resource_region) as table:
127136
await table.put_item(Item=item)
128137

129138
resp = await table.get_item(Key={"token": item["token"]})
@@ -151,7 +160,7 @@ async def update(self, access_token: AP, update_dict: dict[str, Any]) -> AP:
151160
if isinstance(token_dict.get("created_at"), datetime):
152161
token_dict["created_at"] = token_dict["created_at"].isoformat()
153162

154-
async with self._table(self.table_name) as table:
163+
async with self._table(self.table_name, self._resource_region) as table:
155164
await table.put_item(Item=token_dict)
156165

157166
resp = await table.get_item(Key={"token": token_dict["token"]})
@@ -170,5 +179,5 @@ async def delete(self, access_token: AP) -> None:
170179
if token is None:
171180
raise ValueError("access_token has no 'token' field")
172181

173-
async with self._table(self.table_name) as table:
182+
async with self._table(self.table_name, self._resource_region) as table:
174183
await table.delete_item(Key={"token": self._ensure_token(token)})

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pytest
55
from fastapi_users import schemas
66

7+
DATABASE_REGION: str = "eu-central-1"
8+
79

810
class User(schemas.BaseUser):
911
first_name: Optional[str]

tests/test_access_token.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
DynamoDBAccessTokenDatabase,
1414
DynamoDBBaseAccessTokenTableUUID,
1515
)
16+
from tests.conftest import DATABASE_REGION
1617

1718

1819
class Base:
@@ -42,7 +43,10 @@ async def dynamodb_access_token_db(
4243
token_table_name = "access_tokens_test"
4344

4445
user_db = DynamoDBUserDatabase(
45-
session, DynamoDBBaseUserTableUUID, user_table_name
46+
session,
47+
DynamoDBBaseUserTableUUID,
48+
user_table_name,
49+
dynamodb_resource_region=DATABASE_REGION,
4650
)
4751
user = await user_db.create(
4852
{
@@ -52,7 +56,12 @@ async def dynamodb_access_token_db(
5256
}
5357
)
5458

55-
token_db = DynamoDBAccessTokenDatabase(session, AccessToken, token_table_name)
59+
token_db = DynamoDBAccessTokenDatabase(
60+
session,
61+
AccessToken,
62+
token_table_name,
63+
dynamodb_resource_region=DATABASE_REGION,
64+
)
5665

5766
# Vorherigen Token löschen, falls er existiert
5867
token_obj = await token_db.get_by_token("TOKEN")

tests/test_users.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
DynamoDBBaseUserTableUUID,
1212
DynamoDBUserDatabase,
1313
)
14+
from tests.conftest import DATABASE_REGION
1415

1516

1617
class Base:
@@ -40,7 +41,12 @@ async def dynamodb_user_db() -> AsyncGenerator[DynamoDBUserDatabase, None]:
4041
session = aioboto3.Session()
4142
table_name = "users_test"
4243

43-
db = DynamoDBUserDatabase(session, DynamoDBBaseUserTableUUID, table_name)
44+
db = DynamoDBUserDatabase(
45+
session,
46+
DynamoDBBaseUserTableUUID,
47+
table_name,
48+
dynamodb_resource_region=DATABASE_REGION,
49+
)
4450
yield db
4551

4652

@@ -57,6 +63,7 @@ async def dynamodb_user_db_oauth() -> AsyncGenerator[DynamoDBUserDatabase, None]
5763
user_table_name,
5864
OAuthAccount,
5965
oauth_table_name,
66+
dynamodb_resource_region=DATABASE_REGION,
6067
)
6168
yield db
6269

0 commit comments

Comments
 (0)