@@ -59,10 +59,11 @@ def test_interface_matches(self):
5959 )
6060
6161 def test_verify_jwt_signature (self , monkeypatch ):
62+ issuer_url = "https://example.com"
6263 service = services .OIDCPublisherService (
6364 session = pretend .stub (),
6465 publisher = pretend .stub (),
65- issuer_url = pretend . stub () ,
66+ issuer_url = issuer_url ,
6667 audience = "fakeaudience" ,
6768 cache_url = pretend .stub (),
6869 metrics = pretend .stub (),
@@ -72,12 +73,10 @@ def test_verify_jwt_signature(self, monkeypatch):
7273 decoded = pretend .stub ()
7374 jwt = pretend .stub (decode = pretend .call_recorder (lambda t , ** kwargs : decoded ))
7475 key = pretend .stub (key = "fake-key" )
75- monkeypatch .setattr (
76- service , "_get_key_for_token" , pretend .call_recorder (lambda t : key )
77- )
76+ monkeypatch .setattr (service , "_get_key_for_token" , lambda t , i = issuer_url : key )
7877 monkeypatch .setattr (services , "jwt" , jwt )
7978
80- assert service .verify_jwt_signature (token ) == decoded
79+ assert service .verify_jwt_signature (token , issuer_url ) == decoded
8180 assert jwt .decode .calls == [
8281 pretend .call (
8382 token ,
@@ -93,7 +92,7 @@ def test_verify_jwt_signature(self, monkeypatch):
9392 verify_nbf = True ,
9493 strict_aud = True ,
9594 ),
96- issuer = service . issuer_url ,
95+ issuer = issuer_url ,
9796 audience = "fakeaudience" ,
9897 leeway = 30 ,
9998 )
@@ -119,7 +118,7 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch
119118 pretend .call_recorder (lambda s : None ),
120119 )
121120
122- assert service .verify_jwt_signature (token ) is None
121+ assert service .verify_jwt_signature (token , "https://none" ) is None
123122 assert service .metrics .increment .calls == [
124123 pretend .call (
125124 "warehouse.oidc.verify_jwt_signature.malformed_jwt" ,
@@ -130,10 +129,11 @@ def test_verify_jwt_signature_get_key_for_token_fails(self, metrics, monkeypatch
130129
131130 @pytest .mark .parametrize ("exc" , [PyJWTError , TypeError ("foo" )])
132131 def test_verify_jwt_signature_fails (self , metrics , monkeypatch , exc ):
132+ issuer_url = "https://none"
133133 service = services .OIDCPublisherService (
134134 session = pretend .stub (),
135135 publisher = "fakepublisher" ,
136- issuer_url = "https://none" ,
136+ issuer_url = issuer_url ,
137137 audience = "fakeaudience" ,
138138 cache_url = pretend .stub (),
139139 metrics = metrics ,
@@ -143,7 +143,9 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
143143 jwt = pretend .stub (decode = pretend .raiser (exc ), PyJWTError = PyJWTError )
144144 key = pretend .stub (key = "fake-key" )
145145 monkeypatch .setattr (
146- service , "_get_key_for_token" , pretend .call_recorder (lambda t : key )
146+ service ,
147+ "_get_key_for_token" ,
148+ pretend .call_recorder (lambda t , i = issuer_url : key ),
147149 )
148150 monkeypatch .setattr (services , "jwt" , jwt )
149151 monkeypatch .setattr (
@@ -152,7 +154,7 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
152154 pretend .call_recorder (lambda s : None ),
153155 )
154156
155- assert service .verify_jwt_signature (token ) is None
157+ assert service .verify_jwt_signature (token , issuer_url ) is None
156158 assert service .metrics .increment .calls == [
157159 pretend .call (
158160 "warehouse.oidc.verify_jwt_signature.invalid_signature" ,
@@ -168,16 +170,17 @@ def test_verify_jwt_signature_fails(self, metrics, monkeypatch, exc):
168170 assert services .sentry_sdk .capture_message .calls == []
169171
170172 def test_find_publisher (self , metrics , monkeypatch ):
173+ issuer_url = "https://none"
171174 service = services .OIDCPublisherService (
172175 session = pretend .stub (),
173176 publisher = "fakepublisher" ,
174- issuer_url = "https://none" ,
177+ issuer_url = issuer_url ,
175178 audience = "fakeaudience" ,
176179 cache_url = pretend .stub (),
177180 metrics = metrics ,
178181 )
179182
180- token = SignedClaims ({})
183+ token = SignedClaims ({"iss" : issuer_url })
181184
182185 publisher = pretend .stub (verify_claims = pretend .call_recorder (lambda c , s : True ))
183186 find_publisher_by_issuer = pretend .call_recorder (lambda * a , ** kw : publisher )
@@ -198,10 +201,11 @@ def test_find_publisher(self, metrics, monkeypatch):
198201 ]
199202
200203 def test_find_publisher_issuer_lookup_fails (self , metrics , monkeypatch ):
204+ issuer_url = "https://none"
201205 service = services .OIDCPublisherService (
202206 session = pretend .stub (),
203207 publisher = "fakepublisher" ,
204- issuer_url = "https://none" ,
208+ issuer_url = issuer_url ,
205209 audience = "fakeaudience" ,
206210 cache_url = pretend .stub (),
207211 metrics = metrics ,
@@ -212,7 +216,7 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
212216 services , "find_publisher_by_issuer" , find_publisher_by_issuer
213217 )
214218
215- claims = pretend . stub ( )
219+ claims = SignedClaims ({ "iss" : issuer_url } )
216220 with pytest .raises (errors .InvalidPublisherError ):
217221 service .find_publisher (claims )
218222 assert service .metrics .increment .calls == [
@@ -227,10 +231,11 @@ def test_find_publisher_issuer_lookup_fails(self, metrics, monkeypatch):
227231 ]
228232
229233 def test_find_publisher_verify_claims_fails (self , metrics , monkeypatch ):
234+ issuer_url = "https://none"
230235 service = services .OIDCPublisherService (
231236 session = pretend .stub (),
232237 publisher = "fakepublisher" ,
233- issuer_url = "https://none" ,
238+ issuer_url = issuer_url ,
234239 audience = "fakeaudience" ,
235240 cache_url = pretend .stub (),
236241 metrics = metrics ,
@@ -246,7 +251,7 @@ def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
246251 services , "find_publisher_by_issuer" , find_publisher_by_issuer
247252 )
248253
249- claims = SignedClaims ({})
254+ claims = SignedClaims ({"iss" : issuer_url })
250255 with pytest .raises (errors .InvalidPublisherError ):
251256 service .find_publisher (claims )
252257 assert service .metrics .increment .calls == [
@@ -311,7 +316,7 @@ def test_get_keyset_not_cached(self, monkeypatch, mockredis):
311316
312317 monkeypatch .setattr (services .redis , "StrictRedis" , mockredis )
313318
314- keys , timeout = service ._get_keyset ()
319+ keys , timeout = service ._get_keyset ("https://example.com" )
315320
316321 assert not keys
317322 assert timeout is False
@@ -329,17 +334,18 @@ def test_get_keyset_cached(self, monkeypatch, mockredis):
329334 monkeypatch .setattr (services .redis , "StrictRedis" , mockredis )
330335
331336 keyset = {"fake-key-id" : {"foo" : "bar" }}
332- service ._store_keyset (keyset )
333- keys , timeout = service ._get_keyset ()
337+ service ._store_keyset ("https://example.com" , keyset )
338+ keys , timeout = service ._get_keyset ("https://example.com" )
334339
335340 assert keys == keyset
336341 assert timeout is True
337342
338343 def test_refresh_keyset_timeout (self , metrics , monkeypatch , mockredis ):
344+ issuer_url = "https://example.com"
339345 service = services .OIDCPublisherService (
340346 session = pretend .stub (),
341347 publisher = "example" ,
342- issuer_url = "https://example.com" ,
348+ issuer_url = issuer_url ,
343349 audience = "fakeaudience" ,
344350 cache_url = "rediss://fake.example.com" ,
345351 metrics = metrics ,
@@ -348,9 +354,9 @@ def test_refresh_keyset_timeout(self, metrics, monkeypatch, mockredis):
348354 monkeypatch .setattr (services .redis , "StrictRedis" , mockredis )
349355
350356 keyset = {"fake-key-id" : {"foo" : "bar" }}
351- service ._store_keyset (keyset )
357+ service ._store_keyset (issuer_url , keyset )
352358
353- keys = service ._refresh_keyset ()
359+ keys = service ._refresh_keyset (issuer_url )
354360 assert keys == keyset
355361 assert metrics .increment .calls == [
356362 pretend .call (
@@ -380,7 +386,7 @@ def test_refresh_keyset_oidc_config_fails(self, metrics, monkeypatch, mockredis)
380386 monkeypatch .setattr (services , "requests" , requests )
381387 monkeypatch .setattr (services , "sentry_sdk" , sentry_sdk )
382388
383- keys = service ._refresh_keyset ()
389+ keys = service ._refresh_keyset ("https://example.com" )
384390
385391 assert keys == {}
386392 assert metrics .increment .calls == []
@@ -421,7 +427,7 @@ def test_refresh_keyset_oidc_config_no_jwks_uri(
421427 monkeypatch .setattr (services , "requests" , requests )
422428 monkeypatch .setattr (services , "sentry_sdk" , sentry_sdk )
423429
424- keys = service ._refresh_keyset ()
430+ keys = service ._refresh_keyset ("https://example.com" )
425431
426432 assert keys == {}
427433 assert metrics .increment .calls == []
@@ -472,7 +478,7 @@ def get(url, timeout=5):
472478 monkeypatch .setattr (services , "requests" , requests )
473479 monkeypatch .setattr (services , "sentry_sdk" , sentry_sdk )
474480
475- keys = service ._refresh_keyset ()
481+ keys = service ._refresh_keyset ("https://example.com" )
476482
477483 assert keys == {}
478484 assert metrics .increment .calls == []
@@ -524,7 +530,7 @@ def get(url, timeout=5):
524530 monkeypatch .setattr (services , "requests" , requests )
525531 monkeypatch .setattr (services , "sentry_sdk" , sentry_sdk )
526532
527- keys = service ._refresh_keyset ()
533+ keys = service ._refresh_keyset ("https://example.com" )
528534
529535 assert keys == {}
530536 assert metrics .increment .calls == []
@@ -573,7 +579,7 @@ def get(url, timeout=5):
573579 monkeypatch .setattr (services , "requests" , requests )
574580 monkeypatch .setattr (services , "sentry_sdk" , sentry_sdk )
575581
576- keys = service ._refresh_keyset ()
582+ keys = service ._refresh_keyset ("https://example.com" )
577583
578584 assert keys == {"fake-key-id" : {"kid" : "fake-key-id" , "foo" : "bar" }}
579585 assert metrics .increment .calls == []
@@ -586,7 +592,7 @@ def get(url, timeout=5):
586592 assert sentry_sdk .capture_message .calls == []
587593
588594 # Ensure that we also cached the updated keyset as part of refreshing.
589- keys , timeout = service ._get_keyset ()
595+ keys , timeout = service ._get_keyset ("https://example.com" )
590596 assert keys == {"fake-key-id" : {"kid" : "fake-key-id" , "foo" : "bar" }}
591597 assert timeout is True
592598
@@ -612,9 +618,11 @@ def test_get_key_cached(self, metrics, monkeypatch):
612618 "x5t" : "dummy" ,
613619 }
614620 }
615- monkeypatch .setattr (service , "_get_keyset" , lambda : (keyset , True ))
621+ monkeypatch .setattr (
622+ service , "_get_keyset" , lambda issuer_url = None : (keyset , True )
623+ )
616624
617- key = service ._get_key ("fake-key-id" )
625+ key = service ._get_key ("fake-key-id" , "https://example.com" )
618626 assert isinstance (key , PyJWK )
619627 assert key .key_id == "fake-key-id"
620628
@@ -642,10 +650,10 @@ def test_get_key_uncached(self, metrics, monkeypatch):
642650 "x5t" : "dummy" ,
643651 }
644652 }
645- monkeypatch .setattr (service , "_get_keyset" , lambda : ({}, False ))
646- monkeypatch .setattr (service , "_refresh_keyset" , lambda : keyset )
653+ monkeypatch .setattr (service , "_get_keyset" , lambda issuer_url = None : ({}, False ))
654+ monkeypatch .setattr (service , "_refresh_keyset" , lambda issuer_url = None : keyset )
647655
648- key = service ._get_key ("fake-key-id" )
656+ key = service ._get_key ("fake-key-id" , "https://example.com" )
649657 assert isinstance (key , PyJWK )
650658 assert key .key_id == "fake-key-id"
651659
@@ -661,11 +669,14 @@ def test_get_key_refresh_fails(self, metrics, monkeypatch):
661669 metrics = metrics ,
662670 )
663671
664- monkeypatch .setattr (service , "_get_keyset" , lambda : ({}, False ))
665- monkeypatch .setattr (service , "_refresh_keyset" , lambda : {})
672+ monkeypatch .setattr (service , "_get_keyset" , lambda issuer_url = None : ({}, False ))
673+ monkeypatch .setattr (service , "_refresh_keyset" , lambda issuer_url = None : {})
666674
667- key = service ._get_key ("fake-key-id" )
668- assert key is None
675+ with pytest .raises (
676+ jwt .PyJWTError ,
677+ match = r"Key ID 'fake-key-id' not found for issuer 'https://example.com'" ,
678+ ):
679+ service ._get_key ("fake-key-id" , "https://example.com" )
669680
670681 assert metrics .increment .calls == [
671682 pretend .call (
@@ -690,16 +701,20 @@ def test_get_key_for_token(self, monkeypatch):
690701 cache_url = "rediss://fake.example.com" ,
691702 metrics = pretend .stub (),
692703 )
693- monkeypatch .setattr (service , "_get_key" , pretend .call_recorder (lambda kid : key ))
704+ monkeypatch .setattr (
705+ service , "_get_key" , pretend .call_recorder (lambda kid , i : key )
706+ )
694707
695708 monkeypatch .setattr (
696709 services .jwt ,
697710 "get_unverified_header" ,
698711 pretend .call_recorder (lambda token : {"kid" : "fake-key-id" }),
699712 )
700713
701- assert service ._get_key_for_token (token ) == key
702- assert service ._get_key .calls == [pretend .call ("fake-key-id" )]
714+ assert service ._get_key_for_token (token , "https://example.com" ) == key
715+ assert service ._get_key .calls == [
716+ pretend .call ("fake-key-id" , "https://example.com" )
717+ ]
703718 assert services .jwt .get_unverified_header .calls == [pretend .call (token )]
704719
705720 def test_reify_publisher (self , monkeypatch ):
@@ -802,7 +817,9 @@ def test_verify_jwt_signature_malformed_jwt(self):
802817 metrics = pretend .stub (),
803818 )
804819
805- assert service .verify_jwt_signature ("malformed-jwt" ) is None
820+ assert (
821+ service .verify_jwt_signature ("malformed-jwt" , "https://example.com" ) is None
822+ )
806823
807824 def test_verify_jwt_signature_missing_aud (self ):
808825 # {
@@ -830,7 +847,7 @@ def test_verify_jwt_signature_missing_aud(self):
830847 metrics = pretend .stub (),
831848 )
832849
833- assert service .verify_jwt_signature (jwt ) is None
850+ assert service .verify_jwt_signature (jwt , "https://example.com" ) is None
834851
835852 def test_verify_jwt_signature_wrong_aud (self ):
836853 # {
@@ -860,7 +877,7 @@ def test_verify_jwt_signature_wrong_aud(self):
860877 metrics = pretend .stub (),
861878 )
862879
863- assert service .verify_jwt_signature (jwt ) is None
880+ assert service .verify_jwt_signature (jwt , "https://example.com" ) is None
864881
865882 def test_verify_jwt_signature_strict_aud (self ):
866883 # {
@@ -885,7 +902,7 @@ def test_verify_jwt_signature_strict_aud(self):
885902 metrics = pretend .stub (),
886903 )
887904
888- assert service .verify_jwt_signature (jwt ) is None
905+ assert service .verify_jwt_signature (jwt , "https://example.com" ) is None
889906
890907 def test_find_publisher (self , monkeypatch ):
891908 claims = SignedClaims (
0 commit comments