diff --git a/fastapi_azure_auth/auth.py b/fastapi_azure_auth/auth.py index 5a8541c..3513610 100644 --- a/fastapi_azure_auth/auth.py +++ b/fastapi_azure_auth/auth.py @@ -36,6 +36,8 @@ if TYPE_CHECKING: # pragma: no cover from jwt.algorithms import AllowedPublicKeys + from fastapi_azure_auth.openid_config import HttpClientConfig + log = logging.getLogger('fastapi_azure_auth') @@ -57,6 +59,7 @@ def __init__( openid_config_url: Optional[str] = None, openapi_description: Optional[str] = None, scheme_name: str = "AzureAuthorizationCodeBearerBase", + http_client_config: Optional["HttpClientConfig"] = None, ) -> None: """ Initialize settings. @@ -107,6 +110,8 @@ def __init__( :param scheme_name: str The name of the security scheme to be used in OpenAPI documentation. Default is 'AzureAuthorizationCodeBearerBase'. + :param http_client_config: HttpClientConfig + Configuration for the HTTP client used to fetch the OpenID configuration. """ self.auto_error = auto_error # Validate settings, making sure there's no misconfigured dependencies out there @@ -123,6 +128,7 @@ def __init__( multi_tenant=self.multi_tenant, app_id=app_client_id if openid_config_use_app_id else None, config_url=openid_config_url or None, + http_client_config=http_client_config, ) self.leeway: int = leeway @@ -302,6 +308,7 @@ def __init__( openapi_token_url: Optional[str] = None, openapi_description: Optional[str] = None, scheme_name: str = "AzureAD_PKCE_single_tenant", + http_client_config: Optional["HttpClientConfig"] = None, ) -> None: """ Initialize settings for a single tenant application. @@ -340,6 +347,8 @@ def __init__( :param scheme_name: str The name of the security scheme to be used in OpenAPI documentation. Default is 'AzureAD_PKCE_single_tenant'. + :param http_client_config: HttpClientConfig + Configuration for the HTTP client used to fetch the OpenID configuration. """ super().__init__( app_client_id=app_client_id, @@ -352,6 +361,7 @@ def __init__( openapi_authorization_url=openapi_authorization_url, openapi_token_url=openapi_token_url, openapi_description=openapi_description, + http_client_config=http_client_config, ) self.scheme_name: str = scheme_name @@ -371,6 +381,7 @@ def __init__( openapi_token_url: Optional[str] = None, openapi_description: Optional[str] = None, scheme_name: str = "AzureAD_PKCE_multi_tenant", + http_client_config: Optional["HttpClientConfig"] = None, ) -> None: """ Initialize settings for a multi-tenant application. @@ -414,6 +425,8 @@ def __init__( :param scheme_name: str The name of the security scheme to be used in OpenAPI documentation. Default is 'AzureAD_PKCE_multi_tenant'. + :param http_client_config: HttpClientConfig + Configuration for the HTTP client used to fetch the OpenID configuration. """ super().__init__( app_client_id=app_client_id, @@ -428,6 +441,7 @@ def __init__( openapi_authorization_url=openapi_authorization_url, openapi_token_url=openapi_token_url, openapi_description=openapi_description, + http_client_config=http_client_config, ) self.scheme_name: str = scheme_name @@ -447,6 +461,7 @@ def __init__( openapi_token_url: Optional[str] = None, openapi_description: Optional[str] = None, scheme_name: str = "AzureAD_PKCE_B2C_multi_tenant", + http_client_config: Optional["HttpClientConfig"] = None, ) -> None: """ Initialize settings for a B2C multi-tenant application. @@ -485,6 +500,8 @@ def __init__( :param scheme_name: str The name of the security scheme to be used in OpenAPI documentation. Default is 'AzureAD_PKCE_B2C_multi_tenant'. + :param http_client_config: HttpClientConfig + Configuration for the HTTP client used to fetch the OpenID configuration. """ super().__init__( app_client_id=app_client_id, @@ -500,5 +517,6 @@ def __init__( openapi_authorization_url=openapi_authorization_url, openapi_token_url=openapi_token_url, openapi_description=openapi_description, + http_client_config=http_client_config, ) self.scheme_name: str = scheme_name diff --git a/fastapi_azure_auth/openid_config.py b/fastapi_azure_auth/openid_config.py index 470fc4a..943639a 100644 --- a/fastapi_azure_auth/openid_config.py +++ b/fastapi_azure_auth/openid_config.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging +import ssl from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, TypedDict import jwt from fastapi import HTTPException, status @@ -8,26 +11,42 @@ if TYPE_CHECKING: # pragma: no cover from jwt.algorithms import AllowedPublicKeys + from typing_extensions import NotRequired # added in python 3.11 log = logging.getLogger('fastapi_azure_auth') +class HttpClientConfig(TypedDict): + """ + Configuration for the HTTP client used to fetch the OpenID configuration. + + verify - (optional) Either `True` to use an SSL context with the default CA bundle, + `False` to disable verification, or an instance of `ssl.SSLContext` to use a custom context. + trust_env - (optional) Enables or disables usage of environment variables for configuration. + """ + + verify: NotRequired[ssl.SSLContext | bool] + trust_env: NotRequired[bool] + + class OpenIdConfig: def __init__( self, - tenant_id: Optional[str] = None, + tenant_id: str | None = None, multi_tenant: bool = False, - app_id: Optional[str] = None, - config_url: Optional[str] = None, + app_id: str | None = None, + config_url: str | None = None, + http_client_config: HttpClientConfig | None = None, ) -> None: - self.tenant_id: Optional[str] = tenant_id - self._config_timestamp: Optional[datetime] = None + self.tenant_id: str | None = tenant_id + self._config_timestamp: datetime | None = None self.multi_tenant: bool = multi_tenant self.app_id = app_id self.config_url = config_url + self.http_client_config: HttpClientConfig = http_client_config or HttpClientConfig() self.authorization_endpoint: str - self.signing_keys: dict[str, 'AllowedPublicKeys'] + self.signing_keys: dict[str, AllowedPublicKeys] self.token_endpoint: str self.issuer: str @@ -72,7 +91,7 @@ async def _load_openid_config(self) -> None: if self.app_id: config_url += f'?appid={self.app_id}' - async with AsyncClient(timeout=10) as client: + async with AsyncClient(timeout=10, **self.http_client_config) as client: log.info('Fetching OpenID Connect config from %s', config_url) openid_response = await client.get(config_url) openid_response.raise_for_status() @@ -88,7 +107,7 @@ async def _load_openid_config(self) -> None: jwks_response.raise_for_status() self._load_keys(jwks_response.json()['keys']) - def _load_keys(self, keys: List[Dict[str, Any]]) -> None: + def _load_keys(self, keys: list[dict[str, Any]]) -> None: """ Create certificates based on signing keys and store them """