Skip to content

Commit faf778f

Browse files
authored
feat: pass the issuer_url to manage jkws storage (#18885)
* feat: pass the issuer_url to manage jkws storage Previously using `self.issuer_url` didn't make a difference, as the property is configured on the resolved Publisher class, and has nothing to do with the inbound value. Resolves #18845 Signed-off-by: Mike Fiedler <miketheman@gmail.com> * refactor: raise instead of passing none Raise an exception if the key isn't found, handled by the caller chain in `verify_jwt_signature()` prior to trying to decode a value with `None`. Signed-off-by: Mike Fiedler <miketheman@gmail.com> --------- Signed-off-by: Mike Fiedler <miketheman@gmail.com>
1 parent 78426df commit faf778f

File tree

5 files changed

+191
-106
lines changed

5 files changed

+191
-106
lines changed

tests/unit/oidc/test_services.py

Lines changed: 61 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ def test_interface_matches(self):
5959
)
6060

6161
def test_verify_jwt_signature(self, monkeypatch):
62+
issuer_url = "https://example.com"
6263
service = services.OIDCPublisherService(
6364
session=pretend.stub(),
6465
publisher=pretend.stub(),
65-
issuer_url=pretend.stub(),
66+
issuer_url=issuer_url,
6667
audience="fakeaudience",
6768
cache_url=pretend.stub(),
6869
metrics=pretend.stub(),
@@ -72,12 +73,10 @@ def test_verify_jwt_signature(self, monkeypatch):
7273
decoded = pretend.stub()
7374
jwt = pretend.stub(decode=pretend.call_recorder(lambda t, **kwargs: decoded))
7475
key = pretend.stub(key="fake-key")
75-
monkeypatch.setattr(
76-
service, "_get_key_for_token", pretend.call_recorder(lambda t: key)
77-
)
76+
monkeypatch.setattr(service, "_get_key_for_token", lambda t, i=issuer_url: key)
7877
monkeypatch.setattr(services, "jwt", jwt)
7978

80-
assert service.verify_jwt_signature(token) == decoded
79+
assert service.verify_jwt_signature(token, issuer_url) == decoded
8180
assert jwt.decode.calls == [
8281
pretend.call(
8382
token,
@@ -93,7 +92,7 @@ def test_verify_jwt_signature(self, monkeypatch):
9392
verify_nbf=True,
9493
strict_aud=True,
9594
),
96-
issuer=service.issuer_url,
95+
issuer=issuer_url,
9796
audience="fakeaudience",
9897
leeway=30,
9998
)
@@ -119,7 +118,7 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch
119118
pretend.call_recorder(lambda s: None),
120119
)
121120

122-
assert service.verify_jwt_signature(token) is None
121+
assert service.verify_jwt_signature(token, "https://none") is None
123122
assert service.metrics.increment.calls == [
124123
pretend.call(
125124
"warehouse.oidc.verify_jwt_signature.malformed_jwt",
@@ -130,10 +129,11 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch
130129

131130
@pytest.mark.parametrize("exc", [PyJWTError, TypeError("foo")])
132131
def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
132+
issuer_url = "https://none"
133133
service = services.OIDCPublisherService(
134134
session=pretend.stub(),
135135
publisher="fakepublisher",
136-
issuer_url="https://none",
136+
issuer_url=issuer_url,
137137
audience="fakeaudience",
138138
cache_url=pretend.stub(),
139139
metrics=metrics,
@@ -143,7 +143,9 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
143143
jwt = pretend.stub(decode=pretend.raiser(exc), PyJWTError=PyJWTError)
144144
key = pretend.stub(key="fake-key")
145145
monkeypatch.setattr(
146-
service, "_get_key_for_token", pretend.call_recorder(lambda t: key)
146+
service,
147+
"_get_key_for_token",
148+
pretend.call_recorder(lambda t, i=issuer_url: key),
147149
)
148150
monkeypatch.setattr(services, "jwt", jwt)
149151
monkeypatch.setattr(
@@ -152,7 +154,7 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
152154
pretend.call_recorder(lambda s: None),
153155
)
154156

155-
assert service.verify_jwt_signature(token) is None
157+
assert service.verify_jwt_signature(token, issuer_url) is None
156158
assert service.metrics.increment.calls == [
157159
pretend.call(
158160
"warehouse.oidc.verify_jwt_signature.invalid_signature",
@@ -168,16 +170,17 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
168170
assert services.sentry_sdk.capture_message.calls == []
169171

170172
def test_find_publisher(self, metrics, monkeypatch):
173+
issuer_url = "https://none"
171174
service = services.OIDCPublisherService(
172175
session=pretend.stub(),
173176
publisher="fakepublisher",
174-
issuer_url="https://none",
177+
issuer_url=issuer_url,
175178
audience="fakeaudience",
176179
cache_url=pretend.stub(),
177180
metrics=metrics,
178181
)
179182

180-
token = SignedClaims({})
183+
token = SignedClaims({"iss": issuer_url})
181184

182185
publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c, s: True))
183186
find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher)
@@ -198,10 +201,11 @@ def test_find_publisher(self, metrics, monkeypatch):
198201
]
199202

200203
def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
204+
issuer_url = "https://none"
201205
service = services.OIDCPublisherService(
202206
session=pretend.stub(),
203207
publisher="fakepublisher",
204-
issuer_url="https://none",
208+
issuer_url=issuer_url,
205209
audience="fakeaudience",
206210
cache_url=pretend.stub(),
207211
metrics=metrics,
@@ -212,7 +216,7 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
212216
services, "find_publisher_by_issuer", find_publisher_by_issuer
213217
)
214218

215-
claims = pretend.stub()
219+
claims = SignedClaims({"iss": issuer_url})
216220
with pytest.raises(errors.InvalidPublisherError):
217221
service.find_publisher(claims)
218222
assert service.metrics.increment.calls == [
@@ -227,10 +231,11 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
227231
]
228232

229233
def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
234+
issuer_url = "https://none"
230235
service = services.OIDCPublisherService(
231236
session=pretend.stub(),
232237
publisher="fakepublisher",
233-
issuer_url="https://none",
238+
issuer_url=issuer_url,
234239
audience="fakeaudience",
235240
cache_url=pretend.stub(),
236241
metrics=metrics,
@@ -246,7 +251,7 @@ def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
246251
services, "find_publisher_by_issuer", find_publisher_by_issuer
247252
)
248253

249-
claims = SignedClaims({})
254+
claims = SignedClaims({"iss": issuer_url})
250255
with pytest.raises(errors.InvalidPublisherError):
251256
service.find_publisher(claims)
252257
assert service.metrics.increment.calls == [
@@ -311,7 +316,7 @@ def test_get_keyset_not_cached(self, monkeypatch, mockredis):
311316

312317
monkeypatch.setattr(services.redis, "StrictRedis", mockredis)
313318

314-
keys, timeout = service._get_keyset()
319+
keys, timeout = service._get_keyset("https://example.com")
315320

316321
assert not keys
317322
assert timeout is False
@@ -329,17 +334,18 @@ def test_get_keyset_cached(self, monkeypatch, mockredis):
329334
monkeypatch.setattr(services.redis, "StrictRedis", mockredis)
330335

331336
keyset = {"fake-key-id": {"foo": "bar"}}
332-
service._store_keyset(keyset)
333-
keys, timeout = service._get_keyset()
337+
service._store_keyset("https://example.com", keyset)
338+
keys, timeout = service._get_keyset("https://example.com")
334339

335340
assert keys == keyset
336341
assert timeout is True
337342

338343
def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis):
344+
issuer_url = "https://example.com"
339345
service = services.OIDCPublisherService(
340346
session=pretend.stub(),
341347
publisher="example",
342-
issuer_url="https://example.com",
348+
issuer_url=issuer_url,
343349
audience="fakeaudience",
344350
cache_url="rediss://fake.example.com",
345351
metrics=metrics,
@@ -348,9 +354,9 @@ def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis):
348354
monkeypatch.setattr(services.redis, "StrictRedis", mockredis)
349355

350356
keyset = {"fake-key-id": {"foo": "bar"}}
351-
service._store_keyset(keyset)
357+
service._store_keyset(issuer_url, keyset)
352358

353-
keys = service._refresh_keyset()
359+
keys = service._refresh_keyset(issuer_url)
354360
assert keys == keyset
355361
assert metrics.increment.calls == [
356362
pretend.call(
@@ -380,7 +386,7 @@ def test_refresh_keyset_oidc_config_fails(self, metrics, monkeypatch, mockredis)
380386
monkeypatch.setattr(services, "requests", requests)
381387
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)
382388

383-
keys = service._refresh_keyset()
389+
keys = service._refresh_keyset("https://example.com")
384390

385391
assert keys == {}
386392
assert metrics.increment.calls == []
@@ -421,7 +427,7 @@ def test_refresh_keyset_oidc_config_no_jwks_uri(
421427
monkeypatch.setattr(services, "requests", requests)
422428
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)
423429

424-
keys = service._refresh_keyset()
430+
keys = service._refresh_keyset("https://example.com")
425431

426432
assert keys == {}
427433
assert metrics.increment.calls == []
@@ -472,7 +478,7 @@ def get(url, timeout=5):
472478
monkeypatch.setattr(services, "requests", requests)
473479
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)
474480

475-
keys = service._refresh_keyset()
481+
keys = service._refresh_keyset("https://example.com")
476482

477483
assert keys == {}
478484
assert metrics.increment.calls == []
@@ -524,7 +530,7 @@ def get(url, timeout=5):
524530
monkeypatch.setattr(services, "requests", requests)
525531
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)
526532

527-
keys = service._refresh_keyset()
533+
keys = service._refresh_keyset("https://example.com")
528534

529535
assert keys == {}
530536
assert metrics.increment.calls == []
@@ -573,7 +579,7 @@ def get(url, timeout=5):
573579
monkeypatch.setattr(services, "requests", requests)
574580
monkeypatch.setattr(services, "sentry_sdk", sentry_sdk)
575581

576-
keys = service._refresh_keyset()
582+
keys = service._refresh_keyset("https://example.com")
577583

578584
assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}}
579585
assert metrics.increment.calls == []
@@ -586,7 +592,7 @@ def get(url, timeout=5):
586592
assert sentry_sdk.capture_message.calls == []
587593

588594
# Ensure that we also cached the updated keyset as part of refreshing.
589-
keys, timeout = service._get_keyset()
595+
keys, timeout = service._get_keyset("https://example.com")
590596
assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}}
591597
assert timeout is True
592598

@@ -612,9 +618,11 @@ def test_get_key_cached(self, metrics, monkeypatch):
612618
"x5t": "dummy",
613619
}
614620
}
615-
monkeypatch.setattr(service, "_get_keyset", lambda: (keyset, True))
621+
monkeypatch.setattr(
622+
service, "_get_keyset", lambda issuer_url=None: (keyset, True)
623+
)
616624

617-
key = service._get_key("fake-key-id")
625+
key = service._get_key("fake-key-id", "https://example.com")
618626
assert isinstance(key, PyJWK)
619627
assert key.key_id == "fake-key-id"
620628

@@ -642,10 +650,10 @@ def test_get_key_uncached(self, metrics, monkeypatch):
642650
"x5t": "dummy",
643651
}
644652
}
645-
monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False))
646-
monkeypatch.setattr(service, "_refresh_keyset", lambda: keyset)
653+
monkeypatch.setattr(service, "_get_keyset", lambda issuer_url=None: ({}, False))
654+
monkeypatch.setattr(service, "_refresh_keyset", lambda issuer_url=None: keyset)
647655

648-
key = service._get_key("fake-key-id")
656+
key = service._get_key("fake-key-id", "https://example.com")
649657
assert isinstance(key, PyJWK)
650658
assert key.key_id == "fake-key-id"
651659

@@ -661,11 +669,14 @@ def test_get_key_refresh_fails(self, metrics, monkeypatch):
661669
metrics=metrics,
662670
)
663671

664-
monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False))
665-
monkeypatch.setattr(service, "_refresh_keyset", lambda: {})
672+
monkeypatch.setattr(service, "_get_keyset", lambda issuer_url=None: ({}, False))
673+
monkeypatch.setattr(service, "_refresh_keyset", lambda issuer_url=None: {})
666674

667-
key = service._get_key("fake-key-id")
668-
assert key is None
675+
with pytest.raises(
676+
jwt.PyJWTError,
677+
match=r"Key ID 'fake-key-id' not found for issuer 'https://example.com'",
678+
):
679+
service._get_key("fake-key-id", "https://example.com")
669680

670681
assert metrics.increment.calls == [
671682
pretend.call(
@@ -690,16 +701,20 @@ def test_get_key_for_token(self, monkeypatch):
690701
cache_url="rediss://fake.example.com",
691702
metrics=pretend.stub(),
692703
)
693-
monkeypatch.setattr(service, "_get_key", pretend.call_recorder(lambda kid: key))
704+
monkeypatch.setattr(
705+
service, "_get_key", pretend.call_recorder(lambda kid, i: key)
706+
)
694707

695708
monkeypatch.setattr(
696709
services.jwt,
697710
"get_unverified_header",
698711
pretend.call_recorder(lambda token: {"kid": "fake-key-id"}),
699712
)
700713

701-
assert service._get_key_for_token(token) == key
702-
assert service._get_key.calls == [pretend.call("fake-key-id")]
714+
assert service._get_key_for_token(token, "https://example.com") == key
715+
assert service._get_key.calls == [
716+
pretend.call("fake-key-id", "https://example.com")
717+
]
703718
assert services.jwt.get_unverified_header.calls == [pretend.call(token)]
704719

705720
def test_reify_publisher(self, monkeypatch):
@@ -802,7 +817,9 @@ def test_verify_jwt_signature_malformed_jwt(self):
802817
metrics=pretend.stub(),
803818
)
804819

805-
assert service.verify_jwt_signature("malformed-jwt") is None
820+
assert (
821+
service.verify_jwt_signature("malformed-jwt", "https://example.com") is None
822+
)
806823

807824
def test_verify_jwt_signature_missing_aud(self):
808825
# {
@@ -830,7 +847,7 @@ def test_verify_jwt_signature_missing_aud(self):
830847
metrics=pretend.stub(),
831848
)
832849

833-
assert service.verify_jwt_signature(jwt) is None
850+
assert service.verify_jwt_signature(jwt, "https://example.com") is None
834851

835852
def test_verify_jwt_signature_wrong_aud(self):
836853
# {
@@ -860,7 +877,7 @@ def test_verify_jwt_signature_wrong_aud(self):
860877
metrics=pretend.stub(),
861878
)
862879

863-
assert service.verify_jwt_signature(jwt) is None
880+
assert service.verify_jwt_signature(jwt, "https://example.com") is None
864881

865882
def test_verify_jwt_signature_strict_aud(self):
866883
# {
@@ -885,7 +902,7 @@ def test_verify_jwt_signature_strict_aud(self):
885902
metrics=pretend.stub(),
886903
)
887904

888-
assert service.verify_jwt_signature(jwt) is None
905+
assert service.verify_jwt_signature(jwt, "https://example.com") is None
889906

890907
def test_find_publisher(self, monkeypatch):
891908
claims = SignedClaims(

0 commit comments

Comments
 (0)