From 0e50c304cac40014c42795e6f06a4a19a3a07214 Mon Sep 17 00:00:00 2001 From: Tristan Sweeney Date: Tue, 8 Nov 2022 15:18:04 -0500 Subject: [PATCH 1/4] Select JWK by `kid` to get around python-jose bug Python jose has a bug where it'll crash with a JWK type mismatch when validating a JWT against a heterogeneous set of JWKs. I opened a PR against the project to manage that case better in accordance with the JOSE RFC, but that project is is poorly maintained and so it may be worthwhile to work around it for now. Also, it may be worth switching to PyJWT, which supports this use case just as well and is better maintained. --- fastapi_third_party_auth/auth.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/fastapi_third_party_auth/auth.py b/fastapi_third_party_auth/auth.py index 850211d..27b72d1 100644 --- a/fastapi_third_party_auth/auth.py +++ b/fastapi_third_party_auth/auth.py @@ -33,9 +33,8 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)): from fastapi.security import OAuth2 from fastapi.security import SecurityScopes from jose import ExpiredSignatureError -from jose import JWTError from jose import jwt -from jose.exceptions import JWTClaimsError +from jose.exceptions import JWTClaimsError, JWKError, JWTError, JWSError from fastapi_third_party_auth import discovery from fastapi_third_party_auth.grant_types import GrantType @@ -189,6 +188,30 @@ def optional( auto_error=False, ) + + def _find_key(self, token: str) -> dict: + oidc_discoveries = self.discover.auth_server( + openid_connect_url=self.openid_connect_url + ) + keys = self.discover.public_keys(oidc_discoveries) + + header = jwt.get_unverified_header(token) + try: + kid = header['kid'] + except KeyError as e: + raise JWTError("field 'kid' is missing from JWT headers") from e + + for key in keys: + try: + key_kid = key['kid'] + except KeyError as e: + raise JWKError("field 'kid' is missing from JWK") from e + if key_kid == kid: + return key + + raise JWKError(f"Could not find JWK 'kid'={kid}") + + def authenticate_user( self, security_scopes: SecurityScopes, @@ -226,8 +249,8 @@ def authenticate_user( oidc_discoveries = self.discover.auth_server( openid_connect_url=self.openid_connect_url ) - key = self.discover.public_keys(oidc_discoveries) algorithms = self.discover.signing_algos(oidc_discoveries) + key = self._find_key(authorization_credentials.credentials) try: id_token = jwt.decode( From de8ea2fcd8f3a81b48d4973b2c8b5c4a37f75f99 Mon Sep 17 00:00:00 2001 From: Tristan Sweeney <76963169+tsweeney-dust@users.noreply.github.com> Date: Wed, 9 Nov 2022 12:32:30 -0500 Subject: [PATCH 2/4] Update auth.py --- fastapi_third_party_auth/auth.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fastapi_third_party_auth/auth.py b/fastapi_third_party_auth/auth.py index 27b72d1..dd31534 100644 --- a/fastapi_third_party_auth/auth.py +++ b/fastapi_third_party_auth/auth.py @@ -193,7 +193,10 @@ def _find_key(self, token: str) -> dict: oidc_discoveries = self.discover.auth_server( openid_connect_url=self.openid_connect_url ) - keys = self.discover.public_keys(oidc_discoveries) + try: + keys = self.discover.public_keys(oidc_discoveries)["keys"] + except KeyError as e: + raise JWKError("Badly formed JWKs_uri") from e header = jwt.get_unverified_header(token) try: From d6fb15d55a889f3399e73ad8976dd0e0d8b61811 Mon Sep 17 00:00:00 2001 From: Tristan Sweeney Date: Wed, 21 Dec 2022 15:57:29 -0500 Subject: [PATCH 3/4] Avoid keyerror when "aud" isn't set --- fastapi_third_party_auth/auth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastapi_third_party_auth/auth.py b/fastapi_third_party_auth/auth.py index dd31534..5e43ad4 100644 --- a/fastapi_third_party_auth/auth.py +++ b/fastapi_third_party_auth/auth.py @@ -271,7 +271,8 @@ def authenticate_user( ) if ( - type(id_token["aud"]) == list + "aud" in id_token + and type(id_token["aud"]) == list and len(id_token["aud"]) >= 1 and "azp" not in id_token ): From 2ca789e7b5c376f4fe846ec9d924d74d4aba9d08 Mon Sep 17 00:00:00 2001 From: Tristan Sweeney Date: Mon, 9 Jan 2023 16:37:38 -0500 Subject: [PATCH 4/4] Fail more cleanly on inability to talk to auth server --- fastapi_third_party_auth/auth.py | 45 +++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/fastapi_third_party_auth/auth.py b/fastapi_third_party_auth/auth.py index 5e43ad4..5dd792d 100644 --- a/fastapi_third_party_auth/auth.py +++ b/fastapi_third_party_auth/auth.py @@ -15,6 +15,7 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)): return f"Hello {authenticated_user.preferred_username}" """ +from logging import getLogger from typing import List from typing import Optional from typing import Type @@ -35,11 +36,14 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)): from jose import ExpiredSignatureError from jose import jwt from jose.exceptions import JWTClaimsError, JWKError, JWTError, JWSError +from requests.exceptions import ConnectionError from fastapi_third_party_auth import discovery from fastapi_third_party_auth.grant_types import GrantType from fastapi_third_party_auth.idtoken_types import IDToken +logger = getLogger(__name__) + class Auth(OAuth2): def __init__( @@ -80,8 +84,19 @@ def __init__( self.client_id = client_id self.idtoken_model = idtoken_model self.scopes = scopes - + self.discover = discovery.configure(cache_ttl=signature_cache_ttl) + self.grant_types = grant_types + + try: + flows = self.get_flows() + except ConnectionError as e: + logger.warning("Could not discover OIDC flows %s", e) + flows = OAuthFlows() + + super().__init__(scheme_name="OIDC", flows=flows, auto_error=False) + + def get_flows(self) -> OAuthFlows: oidc_discoveries = self.discover.auth_server( openid_connect_url=self.openid_connect_url ) @@ -90,36 +105,32 @@ def __init__( # } flows = OAuthFlows() - if GrantType.AUTHORIZATION_CODE in grant_types: + if GrantType.AUTHORIZATION_CODE in self.grant_types: flows.authorizationCode = OAuthFlowAuthorizationCode( authorizationUrl=self.discover.authorization_url(oidc_discoveries), tokenUrl=self.discover.token_url(oidc_discoveries), # scopes=scopes_dict, ) - if GrantType.CLIENT_CREDENTIALS in grant_types: + if GrantType.CLIENT_CREDENTIALS in self.grant_types: flows.clientCredentials = OAuthFlowClientCredentials( tokenUrl=self.discover.token_url(oidc_discoveries), # scopes=scopes_dict, ) - if GrantType.PASSWORD in grant_types: + if GrantType.PASSWORD in self.grant_types: flows.password = OAuthFlowPassword( tokenUrl=self.discover.token_url(oidc_discoveries), # scopes=scopes_dict, ) - if GrantType.IMPLICIT in grant_types: + if GrantType.IMPLICIT in self.grant_types: flows.implicit = OAuthFlowImplicit( authorizationUrl=self.discover.authorization_url(oidc_discoveries), # scopes=scopes_dict, ) - - super().__init__( - scheme_name="OIDC", - flows=flows, - auto_error=False, - ) + + return flows async def __call__(self, request: Request) -> None: return None @@ -248,10 +259,14 @@ def authenticate_user( ) else: return None - - oidc_discoveries = self.discover.auth_server( - openid_connect_url=self.openid_connect_url - ) + + try: + oidc_discoveries = self.discover.auth_server( + openid_connect_url=self.openid_connect_url + ) + except ConnectionError as e: + logger.warning("Could not reach auth server %e", e) + raise HTTPException(503, detail="Could not reach auth server") from e algorithms = self.discover.signing_algos(oidc_discoveries) key = self._find_key(authorization_credentials.credentials)