44package com .microsoft .aad .msal4j ;
55
66import com .nimbusds .oauth2 .sdk .util .URLUtils ;
7+ import labapi .App ;
78import org .junit .jupiter .api .Nested ;
89import org .junit .jupiter .api .Test ;
910import org .junit .jupiter .api .TestInstance ;
1011import org .junit .jupiter .api .extension .ExtendWith ;
1112import org .junit .jupiter .params .ParameterizedTest ;
1213import org .junit .jupiter .params .provider .MethodSource ;
1314import org .junit .jupiter .params .provider .ValueSource ;
15+ import org .mockito .ArgumentCaptor ;
1416import org .mockito .junit .jupiter .MockitoExtension ;
1517
1618import java .net .SocketException ;
1719import java .nio .file .Path ;
1820import java .nio .file .Paths ;
21+ import java .util .Collections ;
1922import java .util .HashMap ;
2023import java .util .List ;
2124import java .util .Map ;
@@ -78,56 +81,51 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
7881 Map <String , List <String >> queryParameters = new HashMap <>();
7982
8083 switch (source ) {
81- case APP_SERVICE : {
84+ case APP_SERVICE :
8285 endpoint = appServiceEndpoint ;
83-
84- queryParameters .put ("api-version" , singletonList ("2019-08-01" ));
85- queryParameters .put ("resource" , singletonList (resource ));
86-
86+ queryParameters .put ("api-version" , Collections .singletonList ("2019-08-01" ));
87+ queryParameters .put ("resource" , Collections .singletonList (resource ));
8788 headers .put ("X-IDENTITY-HEADER" , "secret" );
8889 break ;
89- }
90- case CLOUD_SHELL : {
90+ case CLOUD_SHELL :
9191 endpoint = cloudShellEndpoint ;
92-
9392 headers .put ("ContentType" , "application/x-www-form-urlencoded" );
9493 headers .put ("Metadata" , "true" );
95-
96- queryParameters .put ("resource" , singletonList (resource ));
94+ queryParameters .put ("resource" , Collections .singletonList (resource ));
9795 break ;
98- }
99- case IMDS : {
96+ case IMDS :
10097 endpoint = IMDS_ENDPOINT ;
101- queryParameters .put ("api-version" , singletonList ("2018-02-01" ));
102- queryParameters .put ("resource" , singletonList (resource ));
98+ queryParameters .put ("api-version" , Collections . singletonList ("2018-02-01" ));
99+ queryParameters .put ("resource" , Collections . singletonList (resource ));
103100 headers .put ("Metadata" , "true" );
104101 break ;
105- }
106- case AZURE_ARC : {
102+ case AZURE_ARC :
107103 endpoint = azureArcEndpoint ;
108-
109- queryParameters .put ("api-version" , singletonList ("2019-11-01" ));
110- queryParameters .put ("resource" , singletonList (resource ));
111-
104+ queryParameters .put ("api-version" , Collections .singletonList ("2019-11-01" ));
105+ queryParameters .put ("resource" , Collections .singletonList (resource ));
112106 headers .put ("Metadata" , "true" );
113107 break ;
114- }
115- case SERVICE_FABRIC : {
108+ case SERVICE_FABRIC :
116109 endpoint = serviceFabricEndpoint ;
117- queryParameters .put ("api-version" , singletonList ("2019-07-01-preview" ));
118- queryParameters .put ("resource" , singletonList (resource ));
119-
110+ queryParameters .put ("api-version" , Collections .singletonList ("2019-07-01-preview" ));
111+ queryParameters .put ("resource" , Collections .singletonList (resource ));
120112 headers .put ("secret" , "secret" );
121113 break ;
122- }
114+ case NONE :
115+ case DEFAULT_TO_IMDS :
116+ endpoint = IMDS_ENDPOINT ;
117+ queryParameters .put ("api-version" , Collections .singletonList ("2018-02-01" ));
118+ queryParameters .put ("resource" , Collections .singletonList (resource ));
119+ headers .put ("Metadata" , "true" );
120+ break ;
123121 }
124122
125123 switch (id .getIdType ()) {
126124 case CLIENT_ID :
127- queryParameters .put ("client_id" , singletonList (id .getUserAssignedId ()));
125+ queryParameters .put ("client_id" , Collections . singletonList (id .getUserAssignedId ()));
128126 break ;
129127 case RESOURCE_ID :
130- queryParameters .put ("mi_res_id" , singletonList (id .getUserAssignedId ()));
128+ queryParameters .put ("mi_res_id" , Collections . singletonList (id .getUserAssignedId ()));
131129 break ;
132130 case OBJECT_ID :
133131 queryParameters .put ("object_id" , singletonList (id .getUserAssignedId ()));
@@ -314,9 +312,10 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
314312 miApp .tokenCache ().accessTokens .clear ();
315313
316314 try {
317- IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
315+ miApp .acquireTokenForManagedIdentity (
318316 ManagedIdentityParameters .builder (resource )
319317 .build ()).get ();
318+ fail ("MsalServiceException is expected but not thrown." );
320319 } catch (Exception e ) {
321320 assertNotNull (e );
322321 assertNotNull (e .getCause ());
@@ -325,10 +324,7 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
325324 MsalServiceException msalMsiException = (MsalServiceException ) e .getCause ();
326325 assertEquals (source .name (), msalMsiException .managedIdentitySource ());
327326 assertEquals (MsalError .USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED , msalMsiException .errorCode ());
328- return ;
329327 }
330-
331- fail ("MsalServiceException is expected but not thrown." );
332328 }
333329
334330 @ ParameterizedTest
@@ -637,21 +633,116 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
637633
638634 miApp = ManagedIdentityApplication
639635 .builder (ManagedIdentityId .systemAssigned ())
640- .clientCapabilities (singletonList ("cp1" ))
641636 .httpClient (httpClientMock )
642637 .build ();
643638
644639 // Clear caching to avoid cross test pollution.
645640 miApp .tokenCache ().accessTokens .clear ();
646641
647642 String claimsJson = "{\" default\" :\" claim\" }" ;
643+
644+ // First call, get the token from the identity provider.
648645 IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
649646 ManagedIdentityParameters .builder (resource )
650- .claims (claimsJson )
651647 .build ()).get ();
652648
653649 assertNotNull (result .accessToken ());
654- verify (httpClientMock , times (1 )).send (any ());
650+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
651+
652+ // Second call, get the token from the cache without passing the claims.
653+ result = miApp .acquireTokenForManagedIdentity (
654+ ManagedIdentityParameters .builder (resource )
655+ .build ()).get ();
656+
657+ assertNotNull (result .accessToken ());
658+ assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
659+
660+ // Third call, when claims are passed bypass the cache.
661+ result = miApp .acquireTokenForManagedIdentity (
662+ ManagedIdentityParameters .builder (resource )
663+ .claims (claimsJson )
664+ .build ()).get ();
665+
666+ assertNotNull (result .accessToken ());
667+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
668+
669+ verify (httpClientMock , times (2 )).send (any ());
670+ }
671+
672+ @ ParameterizedTest
673+ @ MethodSource ("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError" )
674+ void managedIdentity_ClaimsAndCapabilities (ManagedIdentitySourceType source , String endpoint ) throws Exception {
675+ IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper (source , endpoint );
676+ ManagedIdentityApplication .setEnvironmentVariables (environmentVariables );
677+ DefaultHttpClient httpClientMock = mock (DefaultHttpClient .class );
678+ if (source == SERVICE_FABRIC ) {
679+ ServiceFabricManagedIdentitySource .setHttpClient (httpClientMock );
680+ }
681+
682+ when (httpClientMock .send (expectedRequest (source , resource ))).thenReturn (expectedResponse (200 , getSuccessfulResponse (resource )));
683+
684+ miApp = ManagedIdentityApplication
685+ .builder (ManagedIdentityId .systemAssigned ())
686+ .clientCapabilities (singletonList ("cp1" ))
687+ .httpClient (httpClientMock )
688+ .build ();
689+
690+ // Clear caching to avoid cross test pollution.
691+ miApp .tokenCache ().accessTokens .clear ();
692+
693+ String claimsJson = "{\" default\" :\" claim\" }" ;
694+ // First call, get the token from the identity provider.
695+ IAuthenticationResult result = miApp .acquireTokenForManagedIdentity (
696+ ManagedIdentityParameters .builder (resource )
697+ .build ()).get ();
698+
699+ assertNotNull (result .accessToken ());
700+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
701+
702+ // Second call, get the token from the cache without passing the claims.
703+ result = miApp .acquireTokenForManagedIdentity (
704+ ManagedIdentityParameters .builder (resource )
705+ .build ()).get ();
706+
707+ assertNotNull (result .accessToken ());
708+ assertEquals (TokenSource .CACHE , result .metadata ().tokenSource ());
709+
710+ // Third call, when claims are passed bypass the cache.
711+ result = miApp .acquireTokenForManagedIdentity (
712+ ManagedIdentityParameters .builder (resource )
713+ .claims (claimsJson )
714+ .build ()).get ();
715+
716+ assertNotNull (result .accessToken ());
717+ assertEquals (TokenSource .IDENTITY_PROVIDER , result .metadata ().tokenSource ());
718+
719+ verify (httpClientMock , times (2 )).send (any ());
720+ }
721+
722+ @ ParameterizedTest
723+ @ MethodSource ("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData" )
724+ void managedIdentity_InvalidClaims (String claimsJson ) throws Exception {
725+ IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper (APP_SERVICE , appServiceEndpoint );
726+ ManagedIdentityApplication .setEnvironmentVariables (environmentVariables );
727+ DefaultHttpClient httpClientMock = mock (DefaultHttpClient .class );
728+
729+ miApp = ManagedIdentityApplication
730+ .builder (ManagedIdentityId .systemAssigned ())
731+ .httpClient (httpClientMock )
732+ .build ();
733+
734+ CompletableFuture <IAuthenticationResult > future = miApp .acquireTokenForManagedIdentity (
735+ ManagedIdentityParameters .builder (resource )
736+ .claims (claimsJson )
737+ .build ());
738+
739+ ExecutionException ex = assertThrows (ExecutionException .class , future ::get );
740+ assertInstanceOf (MsalClientException .class , ex .getCause ());
741+ MsalClientException msalException = (MsalClientException ) ex .getCause ();
742+ assertEquals (AuthenticationErrorCode .INVALID_JSON , msalException .errorCode ());
743+
744+ // Verify no HTTP requests were made for invalid claims
745+ verify (httpClientMock , never ()).send (any ());
655746 }
656747
657748 @ Nested
0 commit comments