@@ -26,9 +26,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity
2626 // Central, process-local cache for mTLS binding (cert + endpoint + canonical client_id).
2727 internal static readonly ICertificateCache s_mtlsCertificateCache = new InMemoryCertificateCache ( ) ;
2828
29- // Per-key async de-duplication so concurrent callers don’t double-mint.
30- internal static readonly ConcurrentDictionary < string , SemaphoreSlim > s_perKeyGates =
31- new ConcurrentDictionary < string , SemaphoreSlim > ( StringComparer . Ordinal ) ;
29+ private readonly IMtlsCertificateCache _mtlsCache ;
3230
3331 // used in unit tests
3432 public const string ImdsV2ApiVersion = "2.0" ;
@@ -193,9 +191,20 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
193191 return new ImdsV2ManagedIdentitySource ( requestContext ) ;
194192 }
195193
196- internal ImdsV2ManagedIdentitySource ( RequestContext requestContext ) :
197- base ( requestContext , ManagedIdentitySource . ImdsV2 )
198- { }
194+ internal ImdsV2ManagedIdentitySource ( RequestContext requestContext )
195+ : this ( requestContext ,
196+ new MtlsBindingCache ( s_mtlsCertificateCache , PersistentCertificateCacheFactory
197+ . Create ( requestContext . Logger ) ) )
198+ {
199+ }
200+
201+ internal ImdsV2ManagedIdentitySource (
202+ RequestContext requestContext ,
203+ IMtlsCertificateCache mtlsCache )
204+ : base ( requestContext , ManagedIdentitySource . ImdsV2 )
205+ {
206+ _mtlsCache = mtlsCache ?? throw new ArgumentNullException ( nameof ( mtlsCache ) ) ;
207+ }
199208
200209 private async Task < CertificateRequestResponse > ExecuteCertificateRequestAsync (
201210 string clientId ,
@@ -291,11 +300,11 @@ private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
291300
292301 protected override async Task < ManagedIdentityRequest > CreateRequestAsync ( string resource )
293302 {
294- var csrMetadata = await GetCsrMetadataAsync ( _requestContext , false ) . ConfigureAwait ( false ) ;
303+ CsrMetadata csrMetadata = await GetCsrMetadataAsync ( _requestContext , false ) . ConfigureAwait ( false ) ;
295304
296305 string certCacheKey = _requestContext . ServiceBundle . Config . ClientId ;
297306
298- var certEndpointAndClientId = await GetOrCreateMtlsBindingAsync (
307+ MtlsBindingInfo mtlsBinding = await GetOrCreateMtlsBindingAsync (
299308 cacheKey : certCacheKey ,
300309 async ( ) =>
301310 {
@@ -333,15 +342,16 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
333342 // Canonical GUID to use as client_id in the token call
334343 string clientIdGuid = certificateRequestResponse . ClientId ;
335344
336- return Tuple . Create ( mtlsCertificate , endpointBase , clientIdGuid ) ;
345+ return new MtlsBindingInfo ( mtlsCertificate , endpointBase , clientIdGuid ) ;
346+
337347 } ,
338- _requestContext . UserCancellationToken ,
348+ _requestContext . UserCancellationToken ,
339349 _requestContext . Logger )
340350 . ConfigureAwait ( false ) ;
341351
342- X509Certificate2 bindingCertificate = certEndpointAndClientId . Item1 ;
343- string endpointBaseForToken = certEndpointAndClientId . Item2 ;
344- string clientIdForToken = certEndpointAndClientId . Item3 ;
352+ X509Certificate2 bindingCertificate = mtlsBinding . Certificate ;
353+ string endpointBaseForToken = mtlsBinding . Endpoint ;
354+ string clientIdForToken = mtlsBinding . ClientId ;
345355
346356 ManagedIdentityRequest request = new ManagedIdentityRequest (
347357 HttpMethod . Post ,
@@ -440,65 +450,13 @@ private async Task<string> GetAttestationJwtAsync(
440450 return response . AttestationToken ;
441451 }
442452
443- // ...unchanged usings and class header...
444-
445- /// <summary>
446- /// Read-through cache: try cache; if missing, run async factory once (per key),
447- /// store the result, and return it. Thread-safe for the given cacheKey.
448- /// </summary>
449- private static async Task < Tuple < X509Certificate2 , string , string > > GetOrCreateMtlsBindingAsync (
453+ private Task < MtlsBindingInfo > GetOrCreateMtlsBindingAsync (
450454 string cacheKey ,
451- Func < Task < Tuple < X509Certificate2 , string , string > > > factory ,
455+ Func < Task < MtlsBindingInfo > > factory ,
452456 CancellationToken cancellationToken ,
453457 ILoggerAdapter logger )
454458 {
455- if ( string . IsNullOrWhiteSpace ( cacheKey ) )
456- throw new ArgumentException ( "cacheKey must be non-empty." , nameof ( cacheKey ) ) ;
457- if ( factory is null )
458- throw new ArgumentNullException ( nameof ( factory ) ) ;
459-
460- X509Certificate2 cachedCertificate ;
461- string cachedEndpointBase ;
462- string cachedClientId ;
463-
464- // 1) Only lookup by cacheKey
465- if ( s_mtlsCertificateCache . TryGet ( cacheKey , out var cached , logger ) )
466- {
467- cachedCertificate = cached . Certificate ;
468- cachedEndpointBase = cached . Endpoint ;
469- cachedClientId = cached . ClientId ;
470-
471- return Tuple . Create ( cachedCertificate , cachedEndpointBase , cachedClientId ) ;
472- }
473-
474- // 2) Gate per cacheKey
475- var gate = s_perKeyGates . GetOrAdd ( cacheKey , _ => new SemaphoreSlim ( 1 , 1 ) ) ;
476- await gate . WaitAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
477-
478- try
479- {
480- // Re-check after acquiring the gate
481- if ( s_mtlsCertificateCache . TryGet ( cacheKey , out cached , logger ) )
482- {
483- cachedCertificate = cached . Certificate ;
484- cachedEndpointBase = cached . Endpoint ;
485- cachedClientId = cached . ClientId ;
486- return Tuple . Create ( cachedCertificate , cachedEndpointBase , cachedClientId ) ;
487- }
488-
489- // 3) Mint + cache under the provided cacheKey
490- var created = await factory ( ) . ConfigureAwait ( false ) ;
491-
492- s_mtlsCertificateCache . Set ( cacheKey ,
493- new CertificateCacheValue ( created . Item1 , created . Item2 , created . Item3 ) ,
494- logger ) ;
495-
496- return created ;
497- }
498- finally
499- {
500- gate . Release ( ) ;
501- }
459+ return _mtlsCache . GetOrCreateAsync ( cacheKey , factory , cancellationToken , logger ) ;
502460 }
503461
504462 internal static void ResetCertCacheForTest ( )
@@ -508,14 +466,6 @@ internal static void ResetCertCacheForTest()
508466 {
509467 s_mtlsCertificateCache . Clear ( ) ;
510468 }
511-
512- foreach ( var gate in s_perKeyGates . Values )
513- {
514- try
515- { gate . Dispose ( ) ; }
516- catch { }
517- }
518- s_perKeyGates . Clear ( ) ;
519469 }
520470 }
521471}
0 commit comments