Skip to content

Commit 52b2a35

Browse files
authored
Feat: Added audience override if contains Project ID (#674)
* added audience override if contains project ID and tests * fixed tests
1 parent 2554dd9 commit 52b2a35

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

descope/auth.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,37 @@ def _validate_token(
439439
"Algorithm signature in JWT header does not match the algorithm signature in the public key",
440440
)
441441

442+
# Check if we need to auto-detect audience from token
443+
validation_audience = audience
444+
if audience is None:
445+
try:
446+
unverified_claims = jwt.decode(
447+
jwt=token,
448+
key=copy_key[0].key,
449+
algorithms=[alg_header],
450+
options={"verify_aud": False}, # Skip audience verification for now
451+
leeway=self.jwt_validation_leeway,
452+
)
453+
token_audience = unverified_claims.get("aud")
454+
455+
# If token has audience claim and it matches our project ID, use it
456+
if token_audience and self.project_id:
457+
if isinstance(token_audience, list):
458+
if self.project_id in token_audience:
459+
validation_audience = self.project_id
460+
else:
461+
if token_audience == self.project_id:
462+
validation_audience = self.project_id
463+
except Exception:
464+
# If we can't decode the token to check audience, proceed with original audience (None)
465+
pass
466+
442467
try:
443468
claims = jwt.decode(
444469
jwt=token,
445470
key=copy_key[0].key,
446471
algorithms=[alg_header],
447-
audience=audience,
472+
audience=validation_audience,
448473
leeway=self.jwt_validation_leeway,
449474
)
450475
except ImmatureSignatureError:

tests/test_auth.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from unittest import mock
88
from unittest.mock import patch
99

10+
import jwt
11+
1012
from descope import (
1113
API_RATE_LIMIT_RETRY_AFTER_HEADER,
1214
ERROR_TYPE_API_RATE_LIMIT,
@@ -948,6 +950,133 @@ def test_raise_from_response(self):
948950
"""{"errorCode":"E062108","errorDescription":"User not found","errorMessage":"Cannot find user"}""",
949951
)
950952

953+
def test_validate_session_audience_auto_detection(self):
954+
"""Test that validate_session automatically detects audience when token audience matches project ID"""
955+
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
956+
957+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
958+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
959+
mock_decode.side_effect = [
960+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999},
961+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}
962+
]
963+
964+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
965+
with patch.object(auth, '_fetch_public_keys'):
966+
result = auth.validate_session("dummy_token")
967+
968+
self.assertEqual(mock_decode.call_count, 2)
969+
first_call = mock_decode.call_args_list[0]
970+
self.assertIn("options", first_call.kwargs)
971+
self.assertIn("verify_aud", first_call.kwargs["options"])
972+
self.assertFalse(first_call.kwargs["options"]["verify_aud"])
973+
second_call = mock_decode.call_args_list[1]
974+
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)
975+
976+
def test_validate_session_audience_auto_detection_list(self):
977+
"""Test that validate_session automatically detects audience when token audience is a list containing project ID"""
978+
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
979+
980+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
981+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
982+
mock_decode.side_effect = [
983+
{"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999},
984+
{"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999}
985+
]
986+
987+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
988+
with patch.object(auth, '_fetch_public_keys'):
989+
result = auth.validate_session("dummy_token")
990+
991+
self.assertEqual(mock_decode.call_count, 2)
992+
second_call = mock_decode.call_args_list[1]
993+
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)
994+
995+
def test_validate_session_audience_auto_detection_no_match(self):
996+
"""Test that validate_session does not auto-detect audience when token audience doesn't match project ID"""
997+
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
998+
999+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
1000+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
1001+
mock_decode.side_effect = [
1002+
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999},
1003+
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999}
1004+
]
1005+
1006+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
1007+
with patch.object(auth, '_fetch_public_keys'):
1008+
result = auth.validate_session("dummy_token")
1009+
1010+
self.assertEqual(mock_decode.call_count, 2)
1011+
second_call = mock_decode.call_args_list[1]
1012+
self.assertIsNone(second_call.kwargs["audience"])
1013+
1014+
def test_validate_session_explicit_audience(self):
1015+
"""Test that validate_session uses explicit audience parameter instead of auto-detection"""
1016+
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
1017+
explicit_audience = "explicit-audience"
1018+
1019+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
1020+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
1021+
mock_decode.return_value = {"aud": explicit_audience, "sub": "user123", "exp": 9999999999}
1022+
1023+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
1024+
with patch.object(auth, '_fetch_public_keys'):
1025+
result = auth.validate_session("dummy_token", audience=explicit_audience)
1026+
1027+
self.assertEqual(mock_decode.call_count, 1)
1028+
call_args = mock_decode.call_args
1029+
self.assertEqual(call_args.kwargs["audience"], explicit_audience)
1030+
1031+
def test_validate_and_refresh_session_audience_auto_detection(self):
1032+
"""Test that validate_and_refresh_session automatically detects audience when token audience matches project ID"""
1033+
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
1034+
1035+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
1036+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
1037+
mock_decode.side_effect = [
1038+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999},
1039+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}
1040+
]
1041+
1042+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
1043+
with patch.object(auth, '_fetch_public_keys'):
1044+
with patch("requests.post") as mock_post:
1045+
mock_post.return_value.ok = True
1046+
mock_post.return_value.json.return_value = {"sessionJwt": "new_token"}
1047+
mock_post.return_value.cookies = {}
1048+
1049+
result = auth.validate_and_refresh_session("dummy_session_token", "dummy_refresh_token")
1050+
1051+
self.assertEqual(mock_decode.call_count, 2)
1052+
first_call = mock_decode.call_args_list[0]
1053+
self.assertIn("options", first_call.kwargs)
1054+
self.assertIn("verify_aud", first_call.kwargs["options"])
1055+
self.assertFalse(first_call.kwargs["options"]["verify_aud"])
1056+
second_call = mock_decode.call_args_list[1]
1057+
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)
1058+
1059+
def test_validate_session_audience_mismatch_fails(self):
1060+
"""Test that validate_session fails when token audience doesn't match project ID and no explicit audience is provided"""
1061+
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
1062+
1063+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
1064+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
1065+
# First call succeeds (for audience detection), second call fails (for validation with None audience)
1066+
mock_decode.side_effect = [
1067+
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, # First call for audience detection
1068+
jwt.InvalidAudienceError("Invalid audience") # Second call fails because audience doesn't match
1069+
]
1070+
1071+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
1072+
with patch.object(auth, '_fetch_public_keys'):
1073+
with self.assertRaises(jwt.InvalidAudienceError) as cm:
1074+
auth.validate_session("dummy_token")
1075+
1076+
# Verify the error is about invalid audience
1077+
self.assertIn("Invalid audience", str(cm.exception))
1078+
self.assertEqual(mock_decode.call_count, 2)
1079+
9511080
def test_http_client_authorization_header_variants(self):
9521081
# Base client without management key
9531082
client = self.make_http_client()

0 commit comments

Comments
 (0)