@@ -70,16 +70,34 @@ private String getMsiErrorResponseNoRetry() {
7070 return "{\" statusCode\" :\" 123\" ,\" message\" :\" Not one of the retryable error responses\" ,\" correlationId\" :\" 7d0c9763-ff1d-4842-a3f3-6d49e64f4513\" }" ;
7171 }
7272
73+ private HttpRequest expectedRequest (ManagedIdentitySourceType source , String resource , boolean hasClaims , boolean hasCapabilities , String expectedTokenHash ) {
74+ return expectedRequest (source , resource , ManagedIdentityId .systemAssigned (), hasClaims , hasCapabilities , expectedTokenHash );
75+ }
76+
77+ private HttpRequest expectedRequest (ManagedIdentitySourceType source , String resource , ManagedIdentityId id ) {
78+ return expectedRequest (source , resource , id , false , false , null );
79+ }
80+
7381 private HttpRequest expectedRequest (ManagedIdentitySourceType source , String resource ) {
74- return expectedRequest (source , resource , ManagedIdentityId .systemAssigned ());
82+ return expectedRequest (source , resource , ManagedIdentityId .systemAssigned (), false , false , null );
7583 }
7684
7785 private HttpRequest expectedRequest (ManagedIdentitySourceType source , String resource ,
78- ManagedIdentityId id ) {
86+ ManagedIdentityId id , boolean hasClaims , boolean hasCapabilities , String expectedTokenHash ) {
7987 String endpoint = null ;
8088 Map <String , String > headers = new HashMap <>();
8189 Map <String , List <String >> queryParameters = new HashMap <>();
8290
91+ if (Constants .TOKEN_REVOCATION_SUPPORTED_ENVIRONMENTS .contains (source )) {
92+ if (hasCapabilities ) {
93+ queryParameters .put (Constants .CLIENT_CAPABILITY_REQUEST_PARAM , Collections .singletonList ("cp1" ));
94+ }
95+
96+ if (hasClaims ) {
97+ queryParameters .put (Constants .TOKEN_HASH_CLAIM , Collections .singletonList (expectedTokenHash ));
98+ }
99+ }
100+
83101 switch (source ) {
84102 case APP_SERVICE :
85103 endpoint = appServiceEndpoint ;
@@ -93,12 +111,6 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
93111 headers .put ("Metadata" , "true" );
94112 queryParameters .put ("resource" , Collections .singletonList (resource ));
95113 break ;
96- case IMDS :
97- endpoint = IMDS_ENDPOINT ;
98- queryParameters .put ("api-version" , Collections .singletonList ("2018-02-01" ));
99- queryParameters .put ("resource" , Collections .singletonList (resource ));
100- headers .put ("Metadata" , "true" );
101- break ;
102114 case AZURE_ARC :
103115 endpoint = azureArcEndpoint ;
104116 queryParameters .put ("api-version" , Collections .singletonList ("2019-11-01" ));
@@ -111,6 +123,7 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
111123 queryParameters .put ("resource" , Collections .singletonList (resource ));
112124 headers .put ("secret" , "secret" );
113125 break ;
126+ case IMDS :
114127 case NONE :
115128 case DEFAULT_TO_IMDS :
116129 endpoint = IMDS_ENDPOINT ;
@@ -657,6 +670,9 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
657670 assertNotNull (result .accessToken ());
658671 assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
659672
673+ String expectedTokenHash = StringHelper .createSha256HashHexString (result .accessToken ());
674+ when (httpClientMock .send (expectedRequest (source , resource , true , false , expectedTokenHash ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
675+
660676 // Third call, when claims are passed bypass the cache.
661677 result = miApp .acquireTokenForManagedIdentity (
662678 ManagedIdentityParameters .builder (resource )
@@ -669,6 +685,46 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
669685 verify (httpClientMock , times (2 )).send (any ());
670686 }
671687
688+ @ ParameterizedTest
689+ @ MethodSource ("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError" )
690+ void managedIdentityTest_WithCapabilitiesOnly (ManagedIdentitySourceType source , String endpoint ) throws Exception {
691+ IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper (source , endpoint );
692+ ManagedIdentityApplication .setEnvironmentVariables (environmentVariables );
693+ DefaultHttpClient httpClientMock = mock (DefaultHttpClient .class );
694+ if (source == SERVICE_FABRIC ) {
695+ ServiceFabricManagedIdentitySource .setHttpClient (httpClientMock );
696+ }
697+
698+ when (httpClientMock .send (expectedRequest (source , resource , false , true , null ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
699+
700+ miApp = ManagedIdentityApplication
701+ .builder (ManagedIdentityId .systemAssigned ())
702+ .httpClient (httpClientMock )
703+ .clientCapabilities (singletonList ("cp1" ))
704+ .build ();
705+
706+ // Clear caching to avoid cross test pollution.
707+ miApp .tokenCache ().accessTokens .clear ();
708+
709+ // First call, get the token from the identity provider.
710+ IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
711+ ManagedIdentityParameters .builder (resource )
712+ .build ()).get ();
713+
714+ assertNotNull (result .accessToken ());
715+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
716+
717+ // Second call, get the token from the cache without passing the claims.
718+ result = miApp .acquireTokenForManagedIdentity (
719+ ManagedIdentityParameters .builder (resource )
720+ .build ()).get ();
721+
722+ assertNotNull (result .accessToken ());
723+ assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
724+
725+ verify (httpClientMock , times (1 )).send (any ());
726+ }
727+
672728 @ ParameterizedTest
673729 @ MethodSource ("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError" )
674730 void managedIdentity_ClaimsAndCapabilities (ManagedIdentitySourceType source , String endpoint ) throws Exception {
@@ -679,7 +735,7 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
679735 ServiceFabricManagedIdentitySource .setHttpClient (httpClientMock );
680736 }
681737
682- when (httpClientMock .send (expectedRequest (source , resource ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
738+ when (httpClientMock .send (expectedRequest (source , resource , false , true , null ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
683739
684740 miApp = ManagedIdentityApplication
685741 .builder (ManagedIdentityId .systemAssigned ())
@@ -707,6 +763,9 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
707763 assertNotNull (result .accessToken ());
708764 assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
709765
766+ String expectedTokenHash = StringHelper .createSha256HashHexString (result .accessToken ());
767+ when (httpClientMock .send (expectedRequest (source , resource , true , true , expectedTokenHash ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
768+
710769 // Third call, when claims are passed bypass the cache.
711770 result = miApp .acquireTokenForManagedIdentity (
712771 ManagedIdentityParameters .builder (resource )
@@ -715,8 +774,6 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
715774
716775 assertNotNull (result .accessToken ());
717776 assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
718-
719- verify (httpClientMock , times (2 )).send (any ());
720777 }
721778
722779 @ ParameterizedTest
0 commit comments