diff --git a/.gitignore b/.gitignore index 8ed2328..245d384 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ .cache .python-version .venv +venv/ .idea/ /build/ @@ -17,3 +18,4 @@ # Editors .idea/ +.vscode/ diff --git a/README.rst b/README.rst index a8f9f6b..895c6e8 100644 --- a/README.rst +++ b/README.rst @@ -72,3 +72,57 @@ you can use the ``COGNITO_USER_MODEL`` setting. .. code-block:: python COGNITO_USER_MODEL = "myproject.AppUser" + +The library by default uses id token. To use access token, add the following lines to your Django ``settings.py`` file: + +.. code-block:: python + + COGNITO_TOKEN_TYPE = "access" # {'id', 'access'}, default 'id' + + +As the payload of access token only contains basic user info, we could obtain further info from the `UserInfo endpoint`. +You need to specify the Cognito domain in the ``settings.py`` file to obtain the user info from the endpoint, as follows: + +.. code-block:: python + + COGNITO_DOMAIN = "your-user-pool-domain" # eg, exampledomain.auth.ap-southeast-1.amazoncognito.com + +To use the backend functions, at the DJANGO_USER_MODEL, could define methods as follows: + +.. code-block:: python + + class CustomizedUserManager(UserManager): + def get_user(self, payload): + cognito_id = payload['sub'] + try: + return self.get(cognito_id=cognito_id) + except self.model.DoesNotExist: + return None + + def create_for_cognito(self, payload): + """Get any value from `payload` here + ipdb> pprint(payload) + {'aud': '159ufjrihgehb67sn373aotli7', + 'auth_time': 1583503962, + 'cognito:username': 'john-rambo', + 'email': 'foggygiga@gmail.com', + 'email_verified': True, + 'event_id': 'd92a99c2-c49e-4312-8a57-c0dccb84f1c3', + 'exp': 1583507562, + 'iat': 1583503962, + 'iss': 'https://cognito-idp.us-west-2.amazonaws.com/us-west-2_flCJaoDig', + 'sub': '2e4790a0-35a4-45d7-b10c-ced79be22e94', + 'token_use': 'id'} + """ + cognito_id = payload['sub'] + + try: + user = self.create( + username= payload["cognito:username"] if payload.get("cognito:username") else payload["username"], + cognito_id=cognito_id, + email=payload['email'], + is_active=True) + except IntegrityError: + user = self.get(cognito_id=cognito_id) + + return user \ No newline at end of file diff --git a/src/django_cognito_jwt/backend.py b/src/django_cognito_jwt/backend.py index 941c0f5..8c5a3f4 100644 --- a/src/django_cognito_jwt/backend.py +++ b/src/django_cognito_jwt/backend.py @@ -1,4 +1,6 @@ import logging +import requests +import json from django.apps import apps as django_apps from django.conf import settings @@ -29,12 +31,31 @@ def authenticate(self, request): raise exceptions.AuthenticationFailed() USER_MODEL = self.get_user_model() - user = USER_MODEL.objects.get_or_create_for_cognito(jwt_payload) + user = USER_MODEL.objects.get_user(jwt_payload) + if not user: + # Create new user if not exists + payload = jwt_payload + if settings.COGNITO_TOKEN_TYPE == "access": + user_info = self.get_user_info(jwt_token.decode("UTF-8")) + user_info = json.loads(user_info.decode("UTF-8")) + payload = user_info + + user = USER_MODEL.objects.create_for_cognito(payload) + return (user, jwt_token) def get_user_model(self): user_model = getattr(settings, "COGNITO_USER_MODEL", settings.AUTH_USER_MODEL) return django_apps.get_model(user_model, require_ready=False) + + def get_user_info(self, access_token): + if settings.COGNITO_TOKEN_TYPE == "access": + url = f"https://{settings.COGNITO_DOMAIN}/oauth2/userInfo" + + headers = {'Authorization': f'Bearer {access_token}'} + + res = requests.get(url, headers=headers) + return res.content def get_jwt_token(self, request): auth = get_authorization_header(request).split() @@ -58,6 +79,7 @@ def get_token_validator(self, request): settings.COGNITO_AWS_REGION, settings.COGNITO_USER_POOL, settings.COGNITO_AUDIENCE, + settings.COGNITO_TOKEN_TYPE, ) def authenticate_header(self, request): diff --git a/src/django_cognito_jwt/validator.py b/src/django_cognito_jwt/validator.py index 80d7546..df3e120 100644 --- a/src/django_cognito_jwt/validator.py +++ b/src/django_cognito_jwt/validator.py @@ -1,4 +1,5 @@ import json +from typing import Literal import jwt import requests @@ -13,10 +14,14 @@ class TokenError(Exception): class TokenValidator: - def __init__(self, aws_region, aws_user_pool, audience): + def __init__(self, aws_region, aws_user_pool, audience, token_type: Literal["id", "access"] = "id"): self.aws_region = aws_region self.aws_user_pool = aws_user_pool self.audience = audience + self.token_type = token_type + + if token_type not in ["id", "access"]: + raise TokenError("Invalid token type. Choose either id or access token.") @cached_property def pool_url(self): @@ -58,13 +63,21 @@ def validate(self, token): raise TokenError("No key found for this token") try: - jwt_data = jwt.decode( - token, - public_key, - audience=self.audience, - issuer=self.pool_url, - algorithms=["RS256"], - ) + params = { + "jwt": token, + "key": public_key, + "issuer": self.pool_url, + "algorithms": ["RS256"] + } + if self.token_type == "id": + params.update({"audience": self.audience}) + + jwt_data = jwt.decode(**params) + if self.token_type == "access": + if "access" not in jwt_data["token_use"]: + raise TokenError("Incorrect token use") + if jwt_data["client_id"] not in self.audience: + raise TokenError("Incorrect client_id") except ( jwt.InvalidTokenError, jwt.ExpiredSignatureError,