Skip to content

Commit a43ec82

Browse files
authored
IMDSv2 mTLS PoP: add best‑effort persisted binding‑cert cache (Windows CurrentUser\My) layered over in‑memory cache; per‑alias mutex; tests (#5566)
* peristed cert * more tests * pr comments * address pr comments * pr comments
1 parent 2a38dd0 commit a43ec82

17 files changed

+2273
-76
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Security.Cryptography.X509Certificates;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Microsoft.Identity.Client.Core;
9+
10+
namespace Microsoft.Identity.Client.ManagedIdentity.V2
11+
{
12+
/// <summary>
13+
/// Abstraction over the in-memory + persisted cache for IMDSv2 mTLS binding certificates.
14+
/// </summary>
15+
internal interface IMtlsCertificateCache
16+
{
17+
/// <summary>
18+
/// Returns a cached binding certificate for the given <paramref name="cacheKey"/>,
19+
/// or uses <paramref name="factory"/> to create, persist and return one when needed.
20+
/// </summary>
21+
Task<MtlsBindingInfo> GetOrCreateAsync(
22+
string cacheKey,
23+
Func<Task<MtlsBindingInfo>> factory,
24+
CancellationToken cancellationToken,
25+
ILoggerAdapter logger);
26+
}
27+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System.Security.Cryptography.X509Certificates;
5+
using Microsoft.Identity.Client.Core;
6+
7+
namespace Microsoft.Identity.Client.ManagedIdentity.V2
8+
{
9+
/// <summary>
10+
/// Persistence interface for IMDSv2 mTLS binding certificates.
11+
/// Implementations must be best-effort and non-throwing so that
12+
/// certificate persistence never blocks authentication.
13+
/// </summary>
14+
internal interface IPersistentCertificateCache
15+
{
16+
/// <summary>
17+
/// Reads the newest valid (≥24h remaining, has private key) entry for the alias.
18+
/// Returns <c>true</c> on cache hit, <c>false</c> otherwise.
19+
/// </summary>
20+
bool Read(string alias, out CertificateCacheValue value, ILoggerAdapter logger);
21+
22+
/// <summary>
23+
/// Persists the certificate for the alias (best-effort).
24+
/// Implementations should log failures but must not throw; callers do not
25+
/// depend on persistence succeeding and fall back to in-memory cache only.
26+
/// </summary>
27+
void Write(string alias, X509Certificate2 cert, string endpointBase, ILoggerAdapter logger);
28+
29+
/// <summary>
30+
/// Prunes expired entries for the alias (best-effort).
31+
/// Implementations should remove stale/expired entries while leaving the
32+
/// latest valid binding for the alias in place.
33+
/// </summary>
34+
void Delete(string alias, ILoggerAdapter logger);
35+
}
36+
}

src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs

Lines changed: 26 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity
2626
// Central, process-local cache for mTLS binding (cert + endpoint + canonical client_id).
2727
internal static readonly ICertificateCache s_mtlsCertificateCache = new InMemoryCertificateCache();
2828

29-
// Per-key async de-duplication so concurrent callers don’t double-mint.
30-
internal static readonly ConcurrentDictionary<string, SemaphoreSlim> s_perKeyGates =
31-
new ConcurrentDictionary<string, SemaphoreSlim>(StringComparer.Ordinal);
29+
private readonly IMtlsCertificateCache _mtlsCache;
3230

3331
// used in unit tests
3432
public const string ImdsV2ApiVersion = "2.0";
@@ -193,9 +191,20 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
193191
return new ImdsV2ManagedIdentitySource(requestContext);
194192
}
195193

196-
internal ImdsV2ManagedIdentitySource(RequestContext requestContext) :
197-
base(requestContext, ManagedIdentitySource.ImdsV2)
198-
{ }
194+
internal ImdsV2ManagedIdentitySource(RequestContext requestContext)
195+
: this(requestContext,
196+
new MtlsBindingCache(s_mtlsCertificateCache, PersistentCertificateCacheFactory
197+
.Create(requestContext.Logger)))
198+
{
199+
}
200+
201+
internal ImdsV2ManagedIdentitySource(
202+
RequestContext requestContext,
203+
IMtlsCertificateCache mtlsCache)
204+
: base(requestContext, ManagedIdentitySource.ImdsV2)
205+
{
206+
_mtlsCache = mtlsCache ?? throw new ArgumentNullException(nameof(mtlsCache));
207+
}
199208

200209
private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
201210
string clientId,
@@ -291,11 +300,11 @@ private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
291300

292301
protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string resource)
293302
{
294-
var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false);
303+
CsrMetadata csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false);
295304

296305
string certCacheKey = _requestContext.ServiceBundle.Config.ClientId;
297306

298-
var certEndpointAndClientId = await GetOrCreateMtlsBindingAsync(
307+
MtlsBindingInfo mtlsBinding = await GetOrCreateMtlsBindingAsync(
299308
cacheKey: certCacheKey,
300309
async () =>
301310
{
@@ -333,15 +342,16 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
333342
// Canonical GUID to use as client_id in the token call
334343
string clientIdGuid = certificateRequestResponse.ClientId;
335344

336-
return Tuple.Create(mtlsCertificate, endpointBase, clientIdGuid);
345+
return new MtlsBindingInfo(mtlsCertificate, endpointBase, clientIdGuid);
346+
337347
},
338-
_requestContext.UserCancellationToken,
348+
_requestContext.UserCancellationToken,
339349
_requestContext.Logger)
340350
.ConfigureAwait(false);
341351

342-
X509Certificate2 bindingCertificate = certEndpointAndClientId.Item1;
343-
string endpointBaseForToken = certEndpointAndClientId.Item2;
344-
string clientIdForToken = certEndpointAndClientId.Item3;
352+
X509Certificate2 bindingCertificate = mtlsBinding.Certificate;
353+
string endpointBaseForToken = mtlsBinding.Endpoint;
354+
string clientIdForToken = mtlsBinding.ClientId;
345355

346356
ManagedIdentityRequest request = new ManagedIdentityRequest(
347357
HttpMethod.Post,
@@ -440,65 +450,13 @@ private async Task<string> GetAttestationJwtAsync(
440450
return response.AttestationToken;
441451
}
442452

443-
// ...unchanged usings and class header...
444-
445-
/// <summary>
446-
/// Read-through cache: try cache; if missing, run async factory once (per key),
447-
/// store the result, and return it. Thread-safe for the given cacheKey.
448-
/// </summary>
449-
private static async Task<Tuple<X509Certificate2, string, string>> GetOrCreateMtlsBindingAsync(
453+
private Task<MtlsBindingInfo> GetOrCreateMtlsBindingAsync(
450454
string cacheKey,
451-
Func<Task<Tuple<X509Certificate2, string, string>>> factory,
455+
Func<Task<MtlsBindingInfo>> factory,
452456
CancellationToken cancellationToken,
453457
ILoggerAdapter logger)
454458
{
455-
if (string.IsNullOrWhiteSpace(cacheKey))
456-
throw new ArgumentException("cacheKey must be non-empty.", nameof(cacheKey));
457-
if (factory is null)
458-
throw new ArgumentNullException(nameof(factory));
459-
460-
X509Certificate2 cachedCertificate;
461-
string cachedEndpointBase;
462-
string cachedClientId;
463-
464-
// 1) Only lookup by cacheKey
465-
if (s_mtlsCertificateCache.TryGet(cacheKey, out var cached, logger))
466-
{
467-
cachedCertificate = cached.Certificate;
468-
cachedEndpointBase = cached.Endpoint;
469-
cachedClientId = cached.ClientId;
470-
471-
return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
472-
}
473-
474-
// 2) Gate per cacheKey
475-
var gate = s_perKeyGates.GetOrAdd(cacheKey, _ => new SemaphoreSlim(1, 1));
476-
await gate.WaitAsync(cancellationToken).ConfigureAwait(false);
477-
478-
try
479-
{
480-
// Re-check after acquiring the gate
481-
if (s_mtlsCertificateCache.TryGet(cacheKey, out cached, logger))
482-
{
483-
cachedCertificate = cached.Certificate;
484-
cachedEndpointBase = cached.Endpoint;
485-
cachedClientId = cached.ClientId;
486-
return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
487-
}
488-
489-
// 3) Mint + cache under the provided cacheKey
490-
var created = await factory().ConfigureAwait(false);
491-
492-
s_mtlsCertificateCache.Set(cacheKey,
493-
new CertificateCacheValue(created.Item1, created.Item2, created.Item3),
494-
logger);
495-
496-
return created;
497-
}
498-
finally
499-
{
500-
gate.Release();
501-
}
459+
return _mtlsCache.GetOrCreateAsync(cacheKey, factory, cancellationToken, logger);
502460
}
503461

504462
internal static void ResetCertCacheForTest()
@@ -508,14 +466,6 @@ internal static void ResetCertCacheForTest()
508466
{
509467
s_mtlsCertificateCache.Clear();
510468
}
511-
512-
foreach (var gate in s_perKeyGates.Values)
513-
{
514-
try
515-
{ gate.Dispose(); }
516-
catch { }
517-
}
518-
s_perKeyGates.Clear();
519469
}
520470
}
521471
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Diagnostics;
6+
using System.Threading;
7+
using Microsoft.Identity.Client.PlatformsCommon.Shared;
8+
9+
namespace Microsoft.Identity.Client.ManagedIdentity.V2
10+
{
11+
/// <summary>
12+
/// Executes paramref name="action"/ under a cross-process, per-alias mutex.
13+
/// We attempt 2 namespaces, in order:
14+
/// 1) <c>Global\</c> — preferred so we dedupe across all sessions on the machine
15+
/// (e.g., service + user session). This can be denied by OS policy or missing
16+
/// SeCreateGlobalPrivilege in some contexts.
17+
/// 2) <c>Local\</c> — fallback to still dedupe within the current session when
18+
/// <c>Global\</c> is not permitted.
19+
/// Using both ensures we never throw (persistence is best-effort) while getting
20+
/// machine-wide dedupe when allowed and session-local dedupe otherwise.
21+
/// Notes:
22+
/// - The mutex name is derived from <c>alias</c> (= cacheKey) via SHA-256 hex (truncated)
23+
/// to avoid invalid characters / length issues.
24+
/// - On non-Windows runtimes the Global/Local prefixes are treated as part of the name;
25+
/// behavior remains correct but dedupe scope is platform-defined.
26+
/// - Abandoned mutexes are treated as acquired to avoid blocking after a crash.
27+
/// </summary>
28+
internal static class InterprocessLock
29+
{
30+
// Prefer Global\ for cross-session dedupe; fall back to Local\
31+
// if ACLs block Global\ to remain non-throwing.
32+
public static bool TryWithAliasLock(
33+
string alias,
34+
TimeSpan timeout,
35+
Action action,
36+
Action<string> logVerbose)
37+
{
38+
var globalName = GetMutexNameForAlias(alias, preferGlobal: true);
39+
var localName = GetMutexNameForAlias(alias, preferGlobal: false);
40+
41+
// Try to acquire and run under the named mutex scope.
42+
// Returns true if action ran, false if lock busy or failure.
43+
// first try Global\, then Local\ if Global\ unauthorized.
44+
bool TryScope(string name, out bool unauthorized)
45+
{
46+
unauthorized = false;
47+
try
48+
{
49+
using var mutex = new Mutex(initiallyOwned: false, name);
50+
51+
bool entered;
52+
var waitTimer = Stopwatch.StartNew();
53+
try
54+
{
55+
entered = mutex.WaitOne(timeout);
56+
}
57+
catch (AbandonedMutexException ex)
58+
{
59+
entered = true;
60+
logVerbose.Invoke($"[PersistentCert] Abandoned mutex '{name}', treating as acquired. {ex.Message}");
61+
}
62+
finally
63+
{
64+
waitTimer.Stop();
65+
}
66+
67+
if (!entered)
68+
{
69+
logVerbose.Invoke(
70+
$"[PersistentCert] Skip persist (lock busy '{name}', waited {waitTimer.Elapsed.TotalMilliseconds:F0} ms).");
71+
return false;
72+
}
73+
74+
try
75+
{
76+
action();
77+
}
78+
catch (Exception ex)
79+
{
80+
logVerbose.Invoke($"[PersistentCert] Action failed under '{name}': {ex.Message}");
81+
return false;
82+
}
83+
finally
84+
{
85+
try
86+
{ mutex.ReleaseMutex(); }
87+
catch { /* best-effort */ }
88+
}
89+
90+
return true;
91+
}
92+
catch (UnauthorizedAccessException)
93+
{
94+
logVerbose.Invoke($"[PersistentCert] No access to mutex scope '{name}', trying next.");
95+
unauthorized = true;
96+
return false;
97+
}
98+
catch (Exception ex)
99+
{
100+
logVerbose.Invoke($"[PersistentCert] Lock failure '{name}': {ex.Message}");
101+
return false;
102+
}
103+
}
104+
105+
// Try Global\ first; only fallback to Local\ if Global\ is unauthorized
106+
if (TryScope(globalName, out var unauthorizedGlobal))
107+
{
108+
return true;
109+
}
110+
111+
// Fallback is only appropriate when Global\ is disallowed by ACLs.
112+
// If Global\ was just busy or the action failed, do not try Local
113+
if (unauthorizedGlobal)
114+
{
115+
if (TryScope(localName, out _))
116+
{
117+
return true;
118+
}
119+
}
120+
121+
return false;
122+
}
123+
124+
public static string GetMutexNameForAlias(string alias, bool preferGlobal = true)
125+
{
126+
string suffix = HashAlias(Canonicalize(alias));
127+
return (preferGlobal ? @"Global\" : @"Local\") + "MSAL_MI_P_" + suffix;
128+
}
129+
130+
private static string Canonicalize(string alias) =>
131+
(alias ?? string.Empty).Trim().ToUpperInvariant();
132+
133+
private static string HashAlias(string s)
134+
{
135+
try
136+
{
137+
var hex = new CommonCryptographyManager().CreateSha256HashHex(s);
138+
// Truncate to 32 chars to fit mutex name length limits
139+
return string.IsNullOrEmpty(hex)
140+
? "0"
141+
: (hex.Length > 32 ? hex.Substring(0, 32) : hex);
142+
}
143+
catch
144+
{
145+
return "0";
146+
}
147+
}
148+
}
149+
}

0 commit comments

Comments
 (0)