From 4f1f0b0ebe2d1d76d10cdfa727cce5ef45830c82 Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Sun, 11 Dec 2022 13:43:32 +0400 Subject: [PATCH 01/12] Fix CORS by passing 'Origin' header to OAuthLib It is possible to control CORS by overriding is_origin_allowed method of RequestValidator class. OAuthLib allows origin if: - is_origin_allowed returns True for particular request - Request connection is secure - Request has 'Origin' header --- tests/test_cors.py | 117 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/test_cors.py diff --git a/tests/test_cors.py b/tests/test_cors.py new file mode 100644 index 000000000..4ddc0e141 --- /dev/null +++ b/tests/test_cors.py @@ -0,0 +1,117 @@ +from urllib.parse import parse_qs, urlparse + +import pytest +from django.contrib.auth import get_user_model +from django.test import RequestFactory, TestCase +from django.urls import reverse + +from oauth2_provider.models import get_application_model +from oauth2_provider.oauth2_validators import OAuth2Validator + +from . import presets +from .utils import get_basic_auth_header + + +class CorsOAuth2Validator(OAuth2Validator): + def is_origin_allowed(self, client_id, origin, request, *args, **kwargs): + """Enable CORS in OAuthLib""" + return True + + +Application = get_application_model() +UserModel = get_user_model() + +CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz" + +# CORS is allowed for https only +CLIENT_URI = "https://example.org" + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) +class CorsTest(TestCase): + """ + Test that CORS headers can be managed by OAuthLib. + The objective is: http request 'Origin' header should be passed to OAuthLib + """ + + def setUp(self): + self.factory = RequestFactory() + self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") + + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"] + self.oauth2_settings.PKCE_REQUIRED = False + + self.application = Application.objects.create( + name="Test Application", + redirect_uris=(CLIENT_URI), + user=self.dev_user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + client_secret=CLEARTEXT_SECRET, + ) + + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"] + self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator + + def tearDown(self): + self.application.delete() + self.test_user.delete() + self.dev_user.delete() + + def test_cors_header(self): + """ + Test that /token endpoint has Access-Control-Allow-Origin + """ + authorization_code = self._get_authorization_code() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CLIENT_URI, + } + + auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) + auth_headers["origin"] = CLIENT_URI + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) + + def test_no_cors_header(self): + """ + Test that /token endpoint does not have Access-Control-Allow-Origin + """ + authorization_code = self._get_authorization_code() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CLIENT_URI, + } + + auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + # No CORS headers, because request did not have Origin + self.assertFalse(response.has_header("Access-Control-Allow-Origin")) + + def _get_authorization_code(self): + self.client.login(username="test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "https://example.org", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + return query_dict["code"].pop() From 6cd4185012c9750550d8eeae4b6f87d180a5de47 Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Sun, 19 Feb 2023 15:34:51 +0300 Subject: [PATCH 02/12] Fixed tests for Access-Control-Allow-Origin header returned by oauthlib --- tests/test_cors.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_cors.py b/tests/test_cors.py index 4ddc0e141..9d7260bc9 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -29,7 +29,7 @@ def is_origin_allowed(self, client_id, origin, request, *args, **kwargs): @pytest.mark.usefixtures("oauth2_settings") @pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) -class CorsTest(TestCase): +class TestCors(TestCase): """ Test that CORS headers can be managed by OAuthLib. The objective is: http request 'Origin' header should be passed to OAuthLib @@ -74,8 +74,7 @@ def test_cors_header(self): } auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) - auth_headers["origin"] = CLIENT_URI - + auth_headers["HTTP_ORIGIN"] = CLIENT_URI response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) From 09588a90ff9dff9d4439d3ae77ce8a8df37a02e2 Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Fri, 29 Sep 2023 22:12:25 +0300 Subject: [PATCH 03/12] Added Allowed Origins application setting --- oauth2_provider/models.py | 2 -- tests/conftest.py | 1 + tests/test_cors.py | 44 +++++++++++++++++++++++++++++++-------- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index c37057e49..a1e7fda52 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -22,7 +22,6 @@ from .utils import jwk_from_pem from .validators import RedirectURIValidator, URIValidator, WildcardSet - logger = logging.getLogger(__name__) @@ -137,7 +136,6 @@ class AbstractApplication(models.Model): help_text=_("Allowed origins list to enable CORS, space separated"), default="", ) - class Meta: abstract = True diff --git a/tests/conftest.py b/tests/conftest.py index d620c3f59..2cc3c3901 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,6 +108,7 @@ def application(): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, algorithm=Application.RS256_ALGORITHM, client_secret=CLEARTEXT_SECRET, + allowed_origins="https://example.com", ) diff --git a/tests/test_cors.py b/tests/test_cors.py index 9d7260bc9..64f2a5fec 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -1,3 +1,4 @@ +import json from urllib.parse import parse_qs, urlparse import pytest @@ -6,18 +7,11 @@ from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.oauth2_validators import OAuth2Validator from . import presets from .utils import get_basic_auth_header -class CorsOAuth2Validator(OAuth2Validator): - def is_origin_allowed(self, client_id, origin, request, *args, **kwargs): - """Enable CORS in OAuthLib""" - return True - - Application = get_application_model() UserModel = get_user_model() @@ -50,10 +44,10 @@ def setUp(self): client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, client_secret=CLEARTEXT_SECRET, + allowed_origins=CLIENT_URI, ) self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"] - self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator def tearDown(self): self.application.delete() @@ -76,10 +70,42 @@ def test_cors_header(self): auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) auth_headers["HTTP_ORIGIN"] = CLIENT_URI response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + + content = json.loads(response.content.decode("utf-8")) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) + + token_request_data = { + "grant_type": "refresh_token", + "refresh_token": content["refresh_token"], + "scope": content["scope"], + } + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) - def test_no_cors_header(self): + def test_no_cors_header_origin_not_allowed(self): + """ + Test that /token endpoint does not have Access-Control-Allow-Origin + when request origin is not in Application.allowed_origins + """ + authorization_code = self._get_authorization_code() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CLIENT_URI, + } + + auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) + auth_headers["HTTP_ORIGIN"] = "another_example.org" + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + self.assertFalse(response.has_header("Access-Control-Allow-Origin")) + + def test_no_cors_header_no_origin(self): """ Test that /token endpoint does not have Access-Control-Allow-Origin """ From 86ec8ff11cbf8d5d563b31b8fa75858237938926 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Sep 2023 19:22:20 +0000 Subject: [PATCH 04/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- oauth2_provider/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index a1e7fda52..c37057e49 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -22,6 +22,7 @@ from .utils import jwk_from_pem from .validators import RedirectURIValidator, URIValidator, WildcardSet + logger = logging.getLogger(__name__) @@ -136,6 +137,7 @@ class AbstractApplication(models.Model): help_text=_("Allowed origins list to enable CORS, space separated"), default="", ) + class Meta: abstract = True From c04ca48ed5ef32a805f64bc99deea81e8d485aff Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Sun, 1 Oct 2023 09:43:01 +0300 Subject: [PATCH 05/12] Code and docs cleanup --- tests/test_cors.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/test_cors.py b/tests/test_cors.py index 64f2a5fec..e8eff07a1 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -20,6 +20,8 @@ # CORS is allowed for https only CLIENT_URI = "https://example.org" +CLIENT_URI_HTTP = "http://example.org" + @pytest.mark.usefixtures("oauth2_settings") @pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) @@ -39,7 +41,7 @@ def setUp(self): self.application = Application.objects.create( name="Test Application", - redirect_uris=(CLIENT_URI), + redirect_uris=CLIENT_URI, user=self.dev_user, client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, @@ -85,6 +87,26 @@ def test_cors_header(self): self.assertEqual(response.status_code, 200) self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) + def test_cors_header_no_https(self): + """ + Test that CORS is not allowed if origin uri does not have https:// schema + """ + authorization_code = self._get_authorization_code() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CLIENT_URI, + } + + auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) + auth_headers["HTTP_ORIGIN"] = CLIENT_URI_HTTP + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + + self.assertEqual(response.status_code, 200) + self.assertFalse(response.has_header("Access-Control-Allow-Origin")) + def test_no_cors_header_origin_not_allowed(self): """ Test that /token endpoint does not have Access-Control-Allow-Origin From 3e72f8e036b87cbb1974e36c8b210e5c1a92d6a4 Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Sun, 1 Oct 2023 10:13:19 +0300 Subject: [PATCH 06/12] Code cleanup --- tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2cc3c3901..d620c3f59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,7 +108,6 @@ def application(): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, algorithm=Application.RS256_ALGORITHM, client_secret=CLEARTEXT_SECRET, - allowed_origins="https://example.com", ) From 70928638acd8d65bafb0cfa897b22cd3161d4b30 Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Tue, 3 Oct 2023 21:18:04 +0300 Subject: [PATCH 07/12] Code review: update docs and test names --- tests/test_cors.py | 164 --------------------------------------------- 1 file changed, 164 deletions(-) delete mode 100644 tests/test_cors.py diff --git a/tests/test_cors.py b/tests/test_cors.py deleted file mode 100644 index e8eff07a1..000000000 --- a/tests/test_cors.py +++ /dev/null @@ -1,164 +0,0 @@ -import json -from urllib.parse import parse_qs, urlparse - -import pytest -from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase -from django.urls import reverse - -from oauth2_provider.models import get_application_model - -from . import presets -from .utils import get_basic_auth_header - - -Application = get_application_model() -UserModel = get_user_model() - -CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz" - -# CORS is allowed for https only -CLIENT_URI = "https://example.org" - -CLIENT_URI_HTTP = "http://example.org" - - -@pytest.mark.usefixtures("oauth2_settings") -@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) -class TestCors(TestCase): - """ - Test that CORS headers can be managed by OAuthLib. - The objective is: http request 'Origin' header should be passed to OAuthLib - """ - - def setUp(self): - self.factory = RequestFactory() - self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") - self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") - - self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"] - self.oauth2_settings.PKCE_REQUIRED = False - - self.application = Application.objects.create( - name="Test Application", - redirect_uris=CLIENT_URI, - user=self.dev_user, - client_type=Application.CLIENT_CONFIDENTIAL, - authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, - client_secret=CLEARTEXT_SECRET, - allowed_origins=CLIENT_URI, - ) - - self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"] - - def tearDown(self): - self.application.delete() - self.test_user.delete() - self.dev_user.delete() - - def test_cors_header(self): - """ - Test that /token endpoint has Access-Control-Allow-Origin - """ - authorization_code = self._get_authorization_code() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": CLIENT_URI, - } - - auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) - auth_headers["HTTP_ORIGIN"] = CLIENT_URI - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - - content = json.loads(response.content.decode("utf-8")) - - self.assertEqual(response.status_code, 200) - self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) - - token_request_data = { - "grant_type": "refresh_token", - "refresh_token": content["refresh_token"], - "scope": content["scope"], - } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI) - - def test_cors_header_no_https(self): - """ - Test that CORS is not allowed if origin uri does not have https:// schema - """ - authorization_code = self._get_authorization_code() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": CLIENT_URI, - } - - auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) - auth_headers["HTTP_ORIGIN"] = CLIENT_URI_HTTP - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - - self.assertEqual(response.status_code, 200) - self.assertFalse(response.has_header("Access-Control-Allow-Origin")) - - def test_no_cors_header_origin_not_allowed(self): - """ - Test that /token endpoint does not have Access-Control-Allow-Origin - when request origin is not in Application.allowed_origins - """ - authorization_code = self._get_authorization_code() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": CLIENT_URI, - } - - auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) - auth_headers["HTTP_ORIGIN"] = "another_example.org" - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - self.assertFalse(response.has_header("Access-Control-Allow-Origin")) - - def test_no_cors_header_no_origin(self): - """ - Test that /token endpoint does not have Access-Control-Allow-Origin - """ - authorization_code = self._get_authorization_code() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": CLIENT_URI, - } - - auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - # No CORS headers, because request did not have Origin - self.assertFalse(response.has_header("Access-Control-Allow-Origin")) - - def _get_authorization_code(self): - self.client.login(username="test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "https://example.org", - "response_type": "code", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - query_dict = parse_qs(urlparse(response["Location"]).query) - return query_dict["code"].pop() From 4015e4dc2e40406aedd7466507189ea4ab0d1bcb Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Wed, 18 Oct 2023 12:50:26 +0300 Subject: [PATCH 08/12] Added ALLOWED_SCHEMES setting for Allowed Orgins validation --- docs/settings.rst | 11 ++++++ oauth2_provider/models.py | 9 +++-- oauth2_provider/settings.py | 1 + oauth2_provider/validators.py | 26 +++++++++++++ tests/test_validators.py | 71 ++++++++++++++++++++++++++++++++++- 5 files changed, 114 insertions(+), 4 deletions(-) diff --git a/docs/settings.rst b/docs/settings.rst index f31aff533..a7cac94a1 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -63,6 +63,17 @@ assigned ports. Note that you may override ``Application.get_allowed_schemes()`` to set this on a per-application basis. +ALLOWED_SCHEMES +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Default: ``["https"]`` + +A list of schemes that the ``allowed_origins`` field will be validated against. +Setting this to ``["https"]`` only in production is strongly recommended. +Adding ``"http"`` to the list is considered to be safe only for local development and testing. +Note that `OAUTHLIB_INSECURE_TRANSPORT `_ +environment variable should be also set to allow http origins. + APPLICATION_MODEL ~~~~~~~~~~~~~~~~~ diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index c37057e49..e09b41664 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -20,8 +20,7 @@ from .scopes import get_scopes_backend from .settings import oauth2_settings from .utils import jwk_from_pem -from .validators import RedirectURIValidator, URIValidator, WildcardSet - +from .validators import RedirectURIValidator, URIValidator, WildcardSet, AllowedURIValidator logger = logging.getLogger(__name__) @@ -218,7 +217,7 @@ def clean(self): allowed_origins = self.allowed_origins.strip().split() if allowed_origins: # oauthlib allows only https scheme for CORS - validator = URIValidator({"https"}) + validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin") for uri in allowed_origins: validator(uri) @@ -808,6 +807,10 @@ def is_origin_allowed(origin, allowed_origins): """ parsed_origin = urlparse(origin) + + if parsed_origin.scheme not in oauth2_settings.ALLOWED_SCHEMES: + return False + for allowed_origin in allowed_origins: parsed_allowed_origin = urlparse(allowed_origin) if ( diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index aa7de7351..c5af9ebae 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -68,6 +68,7 @@ "REFRESH_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.RefreshTokenAdmin", "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], + "ALLOWED_SCHEMES": ["https"], "OIDC_ENABLED": False, "OIDC_ISS_ENDPOINT": "", "OIDC_USERINFO_ENDPOINT": "", diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index 6c8fa3839..9ecced631 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -31,6 +31,32 @@ def __call__(self, value): raise ValidationError("Redirect URIs must not contain fragments") +class AllowedURIValidator(URIValidator): + def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False): + """ + :params schemes: List of allowed schemes. E.g.: ["https"] + :params name: Name of the validater URI required for validation message. E.g.: "Origin" + :params allow_path: If URI can contain path part + :params allow_query: If URI can contain query part + :params allow_fragments: If URI can contain fragments part + """ + super().__init__(schemes=schemes) + self.name = name + self.allow_path = allow_path + self.allow_query = allow_query + self.allow_fragments = allow_fragments + + def __call__(self, value): + super().__call__(value) + value = force_str(value) + scheme, netloc, path, query, fragment = urlsplit(value) + if path and not self.allow_path: + raise ValidationError("{} URIs must not contain path".format(self.name)) + if query and not self.allow_query: + raise ValidationError("{} URIs must not contain query".format(self.name)) + if fragment and not self.allow_fragments: + raise ValidationError("{} URIs must not contain fragments".format(self.name)) + ## # WildcardSet is a special set that contains everything. # This is required in order to move validation of the scheme from diff --git a/tests/test_validators.py b/tests/test_validators.py index 0760e0290..d77e128a3 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -2,7 +2,7 @@ from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.validators import RedirectURIValidator +from oauth2_provider.validators import RedirectURIValidator, AllowedURIValidator @pytest.mark.usefixtures("oauth2_settings") @@ -36,6 +36,11 @@ def test_validate_custom_uri_scheme(self): # Check ValidationError not thrown validator(uri) + validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin") + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] @@ -61,3 +66,67 @@ def test_validate_bad_uris(self): for uri in bad_uris: with self.assertRaises(ValidationError): validator(uri) + + def test_validate_good_origin_uris(self): + """ + Test AllowedURIValidator validates origin URIs if they match requirements + """ + validator = AllowedURIValidator( + ["https"], + "Origin", + allow_path=False, + allow_query=False, + allow_fragments=False, + ) + good_uris = [ + "https://example.com", + "https://example.com:8080", + "https://example", + "https://localhost", + "https://1.1.1.1", + "https://127.0.0.1", + "https://255.255.255.255", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_validate_bad_origin_uris(self): + """ + Test AllowedURIValidator rejects origin URIs if they do not match requirements + """ + validator = AllowedURIValidator( + ["https"], + "Origin", + allow_path=False, + allow_query=False, + allow_fragments=False, + ) + bad_uris = [ + "http:/example.com", + "HTTP://localhost", + "HTTP://example.com", + "HTTP://example.com.", + "http://example.com/#fragment", + "123://example.com", + "http://fe80::1", + "git+ssh://example.com", + "my-scheme://example.com", + "uri-without-a-scheme", + "https://example.com/#fragment", + "good://example.com/#fragment", + " ", + "", + # Bad IPv6 URL, urlparse behaves differently for these + 'https://[">', + # Origin uri should not contain path, query of fragment parts + # https://www.rfc-editor.org/rfc/rfc6454#section-7.1 + "https:/example.com/", + "https:/example.com/test", + "https:/example.com/?q=test", + "https:/example.com/#test", + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri) From 2f8789d982dc699e02d05814b521e384b015e8fe Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Wed, 18 Oct 2023 22:13:34 +0300 Subject: [PATCH 09/12] Code cleanup --- oauth2_provider/models.py | 2 +- oauth2_provider/validators.py | 1 + tests/test_validators.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index e09b41664..a50972728 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -20,7 +20,7 @@ from .scopes import get_scopes_backend from .settings import oauth2_settings from .utils import jwk_from_pem -from .validators import RedirectURIValidator, URIValidator, WildcardSet, AllowedURIValidator +from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet logger = logging.getLogger(__name__) diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index 9ecced631..e69bb27b2 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -57,6 +57,7 @@ def __call__(self, value): if fragment and not self.allow_fragments: raise ValidationError("{} URIs must not contain fragments".format(self.name)) + ## # WildcardSet is a special set that contains everything. # This is required in order to move validation of the scheme from diff --git a/tests/test_validators.py b/tests/test_validators.py index d77e128a3..247e97baa 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -2,7 +2,7 @@ from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.validators import RedirectURIValidator, AllowedURIValidator +from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator @pytest.mark.usefixtures("oauth2_settings") From 84c8c4c16a58ce363ff77035de87c3cee20614a6 Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Wed, 18 Oct 2023 23:01:34 +0300 Subject: [PATCH 10/12] Add more tests for origin validators --- oauth2_provider/validators.py | 14 +++++++------- tests/conftest.py | 12 ++++++++++++ tests/presets.py | 8 ++++++++ tests/test_models.py | 16 ++++++++++++++++ 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index e69bb27b2..df3d9e753 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -34,11 +34,11 @@ def __call__(self, value): class AllowedURIValidator(URIValidator): def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False): """ - :params schemes: List of allowed schemes. E.g.: ["https"] - :params name: Name of the validater URI required for validation message. E.g.: "Origin" - :params allow_path: If URI can contain path part - :params allow_query: If URI can contain query part - :params allow_fragments: If URI can contain fragments part + :param schemes: List of allowed schemes. E.g.: ["https"] + :param name: Name of the validated URI. It is required for validation message. E.g.: "Origin" + :param allow_path: If URI can contain path part + :param allow_query: If URI can contain query part + :param allow_fragments: If URI can contain fragments part """ super().__init__(schemes=schemes) self.name = name @@ -50,12 +50,12 @@ def __call__(self, value): super().__call__(value) value = force_str(value) scheme, netloc, path, query, fragment = urlsplit(value) - if path and not self.allow_path: - raise ValidationError("{} URIs must not contain path".format(self.name)) if query and not self.allow_query: raise ValidationError("{} URIs must not contain query".format(self.name)) if fragment and not self.allow_fragments: raise ValidationError("{} URIs must not contain fragments".format(self.name)) + if path and not self.allow_path: + raise ValidationError("{} URIs must not contain path".format(self.name)) ## diff --git a/tests/conftest.py b/tests/conftest.py index d620c3f59..eff48f7fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,6 +124,18 @@ def public_application(): ) +@pytest.fixture +def cors_application(): + return Application.objects.create( + name="Test CORS Application", + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + algorithm=Application.RS256_ALGORITHM, + client_secret=CLEARTEXT_SECRET, + allowed_origins="https://example.com http://example.com", + ) + + @pytest.fixture def logged_in_client(test_user): from django.test.client import Client diff --git a/tests/presets.py b/tests/presets.py index 1ac8d3279..4538c64eb 100644 --- a/tests/presets.py +++ b/tests/presets.py @@ -57,3 +57,11 @@ "READ_SCOPE": "read", "WRITE_SCOPE": "write", } + +ALLOWED_SCHEMES_DEFAULT = { + "ALLOWED_SCHEMES": ["https"], +} + +ALLOWED_SCHEMES_HTTP = { + "ALLOWED_SCHEMES": ["https", "http"], +} diff --git a/tests/test_models.py b/tests/test_models.py index 4de823b8d..8c62e2c99 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -594,3 +594,19 @@ def test_application_clean(oauth2_settings, application): assert "Enter a valid URL" in str(exc.value) application.allowed_origins = "https://example.com" application.clean() + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) +def test_application_origin_allowed_default_https(oauth2_settings, cors_application): + """Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https""" + assert cors_application.origin_allowed("https://example.com") + assert not cors_application.origin_allowed("http://example.com") + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP) +def test_application_origin_allowed_http(oauth2_settings, cors_application): + """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" + assert cors_application.origin_allowed("https://example.com") + assert cors_application.origin_allowed("http://example.com") From f7ae512662feea6cca1b976cb11038f7658d947c Mon Sep 17 00:00:00 2001 From: Aliaksei Kanstantsinau Date: Wed, 18 Oct 2023 23:18:10 +0300 Subject: [PATCH 11/12] fix coverage --- tests/test_validators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_validators.py b/tests/test_validators.py index 247e97baa..6cbc23172 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -121,10 +121,10 @@ def test_validate_bad_origin_uris(self): 'https://[">', # Origin uri should not contain path, query of fragment parts # https://www.rfc-editor.org/rfc/rfc6454#section-7.1 - "https:/example.com/", - "https:/example.com/test", - "https:/example.com/?q=test", - "https:/example.com/#test", + "https://example.com/", + "https://example.com/test", + "https://example.com/?q=test", + "https://example.com/#test", ] for uri in bad_uris: From afdfa4561440427ac1887d03ddeeee8011a04fdf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 18:52:23 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- oauth2_provider/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index a50972728..80d8f3487 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -22,6 +22,7 @@ from .utils import jwk_from_pem from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet + logger = logging.getLogger(__name__)