Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions fastapi_azure_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
if TYPE_CHECKING: # pragma: no cover
from jwt.algorithms import AllowedPublicKeys

from fastapi_azure_auth.openid_config import HttpClientConfig

log = logging.getLogger('fastapi_azure_auth')


Expand All @@ -57,6 +59,7 @@ def __init__(
openid_config_url: Optional[str] = None,
openapi_description: Optional[str] = None,
scheme_name: str = "AzureAuthorizationCodeBearerBase",
http_client_config: Optional["HttpClientConfig"] = None,
) -> None:
"""
Initialize settings.
Expand Down Expand Up @@ -107,6 +110,8 @@ def __init__(
:param scheme_name: str
The name of the security scheme to be used in OpenAPI documentation.
Default is 'AzureAuthorizationCodeBearerBase'.
:param http_client_config: HttpClientConfig
Configuration for the HTTP client used to fetch the OpenID configuration.
"""
self.auto_error = auto_error
# Validate settings, making sure there's no misconfigured dependencies out there
Expand All @@ -123,6 +128,7 @@ def __init__(
multi_tenant=self.multi_tenant,
app_id=app_client_id if openid_config_use_app_id else None,
config_url=openid_config_url or None,
http_client_config=http_client_config,
)

self.leeway: int = leeway
Expand Down Expand Up @@ -302,6 +308,7 @@ def __init__(
openapi_token_url: Optional[str] = None,
openapi_description: Optional[str] = None,
scheme_name: str = "AzureAD_PKCE_single_tenant",
http_client_config: Optional["HttpClientConfig"] = None,
) -> None:
"""
Initialize settings for a single tenant application.
Expand Down Expand Up @@ -340,6 +347,8 @@ def __init__(
:param scheme_name: str
The name of the security scheme to be used in OpenAPI documentation.
Default is 'AzureAD_PKCE_single_tenant'.
:param http_client_config: HttpClientConfig
Configuration for the HTTP client used to fetch the OpenID configuration.
"""
super().__init__(
app_client_id=app_client_id,
Expand All @@ -352,6 +361,7 @@ def __init__(
openapi_authorization_url=openapi_authorization_url,
openapi_token_url=openapi_token_url,
openapi_description=openapi_description,
http_client_config=http_client_config,
)
self.scheme_name: str = scheme_name

Expand All @@ -371,6 +381,7 @@ def __init__(
openapi_token_url: Optional[str] = None,
openapi_description: Optional[str] = None,
scheme_name: str = "AzureAD_PKCE_multi_tenant",
http_client_config: Optional["HttpClientConfig"] = None,
) -> None:
"""
Initialize settings for a multi-tenant application.
Expand Down Expand Up @@ -414,6 +425,8 @@ def __init__(
:param scheme_name: str
The name of the security scheme to be used in OpenAPI documentation.
Default is 'AzureAD_PKCE_multi_tenant'.
:param http_client_config: HttpClientConfig
Configuration for the HTTP client used to fetch the OpenID configuration.
"""
super().__init__(
app_client_id=app_client_id,
Expand All @@ -428,6 +441,7 @@ def __init__(
openapi_authorization_url=openapi_authorization_url,
openapi_token_url=openapi_token_url,
openapi_description=openapi_description,
http_client_config=http_client_config,
)
self.scheme_name: str = scheme_name

Expand All @@ -447,6 +461,7 @@ def __init__(
openapi_token_url: Optional[str] = None,
openapi_description: Optional[str] = None,
scheme_name: str = "AzureAD_PKCE_B2C_multi_tenant",
http_client_config: Optional["HttpClientConfig"] = None,
) -> None:
"""
Initialize settings for a B2C multi-tenant application.
Expand Down Expand Up @@ -485,6 +500,8 @@ def __init__(
:param scheme_name: str
The name of the security scheme to be used in OpenAPI documentation.
Default is 'AzureAD_PKCE_B2C_multi_tenant'.
:param http_client_config: HttpClientConfig
Configuration for the HTTP client used to fetch the OpenID configuration.
"""
super().__init__(
app_client_id=app_client_id,
Expand All @@ -500,5 +517,6 @@ def __init__(
openapi_authorization_url=openapi_authorization_url,
openapi_token_url=openapi_token_url,
openapi_description=openapi_description,
http_client_config=http_client_config,
)
self.scheme_name: str = scheme_name
37 changes: 28 additions & 9 deletions fastapi_azure_auth/openid_config.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,52 @@
from __future__ import annotations

import logging
import ssl
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, TypedDict

import jwt
from fastapi import HTTPException, status
from httpx import AsyncClient

if TYPE_CHECKING: # pragma: no cover
from jwt.algorithms import AllowedPublicKeys
from typing_extensions import NotRequired # added in python 3.11

log = logging.getLogger('fastapi_azure_auth')


class HttpClientConfig(TypedDict):
"""
Configuration for the HTTP client used to fetch the OpenID configuration.

verify - (optional) Either `True` to use an SSL context with the default CA bundle,
`False` to disable verification, or an instance of `ssl.SSLContext` to use a custom context.
trust_env - (optional) Enables or disables usage of environment variables for configuration.
"""

verify: NotRequired[ssl.SSLContext | bool]
trust_env: NotRequired[bool]


class OpenIdConfig:
def __init__(
self,
tenant_id: Optional[str] = None,
tenant_id: str | None = None,
multi_tenant: bool = False,
app_id: Optional[str] = None,
config_url: Optional[str] = None,
app_id: str | None = None,
config_url: str | None = None,
http_client_config: HttpClientConfig | None = None,
) -> None:
self.tenant_id: Optional[str] = tenant_id
self._config_timestamp: Optional[datetime] = None
self.tenant_id: str | None = tenant_id
self._config_timestamp: datetime | None = None
self.multi_tenant: bool = multi_tenant
self.app_id = app_id
self.config_url = config_url
self.http_client_config: HttpClientConfig = http_client_config or HttpClientConfig()

self.authorization_endpoint: str
self.signing_keys: dict[str, 'AllowedPublicKeys']
self.signing_keys: dict[str, AllowedPublicKeys]
self.token_endpoint: str
self.issuer: str

Expand Down Expand Up @@ -72,7 +91,7 @@ async def _load_openid_config(self) -> None:
if self.app_id:
config_url += f'?appid={self.app_id}'

async with AsyncClient(timeout=10) as client:
async with AsyncClient(timeout=10, **self.http_client_config) as client:
log.info('Fetching OpenID Connect config from %s', config_url)
openid_response = await client.get(config_url)
openid_response.raise_for_status()
Expand All @@ -88,7 +107,7 @@ async def _load_openid_config(self) -> None:
jwks_response.raise_for_status()
self._load_keys(jwks_response.json()['keys'])

def _load_keys(self, keys: List[Dict[str, Any]]) -> None:
def _load_keys(self, keys: list[dict[str, Any]]) -> None:
"""
Create certificates based on signing keys and store them
"""
Expand Down
Loading