22
33use base64:: { engine:: general_purpose, Engine as _} ;
44use reqwest:: {
5- header:: { HeaderMap , HeaderValue , ACCEPT , USER_AGENT } ,
5+ header:: { HeaderMap , HeaderValue , ACCEPT , AUTHORIZATION , USER_AGENT } ,
66 Client ,
77} ;
88use serde:: Deserialize ;
@@ -18,6 +18,10 @@ use std::fs;
1818use std:: path:: PathBuf ;
1919use std:: sync:: RwLock ;
2020
21+ // Azure Identity imports for MSI authentication
22+ use azure_core:: credentials:: TokenCredential ;
23+ use azure_identity:: { ManagedIdentityCredential , ManagedIdentityCredentialOptions , UserAssignedId } ;
24+
2125/// Authentication methods for the Geneva Config Client.
2226///
2327/// The client supports two authentication methods:
@@ -53,25 +57,29 @@ pub enum AuthMethod {
5357 /// * `path` - Path to the PKCS#12 (.p12) certificate file
5458 /// * `password` - Password to decrypt the PKCS#12 file
5559 Certificate { path : PathBuf , password : String } ,
56- /// Azure Managed Identity authentication
57- ///
58- /// Note(TODO): This is not yet implemented.
59- ManagedIdentity ,
60+ /// System-assigned managed identity (auto-detected)
61+ SystemManagedIdentity ,
62+ /// User-assigned managed identity by client ID
63+ UserManagedIdentity { client_id : String } ,
64+ /// User-assigned managed identity by object ID
65+ UserManagedIdentityByObjectId { object_id : String } ,
66+ /// User-assigned managed identity by resource ID
67+ UserManagedIdentityByResourceId { resource_id : String } ,
6068 #[ cfg( feature = "mock_auth" ) ]
6169 MockAuth , // No authentication, used for testing purposes
6270}
6371
6472#[ derive( Debug , Error ) ]
6573pub ( crate ) enum GenevaConfigClientError {
6674 // Authentication-related errors
67- #[ error( "Authentication method not implemented: {0}" ) ]
68- AuthMethodNotImplemented ( String ) ,
6975 #[ error( "Missing Auth Info: {0}" ) ]
7076 AuthInfoNotFound ( String ) ,
7177 #[ error( "Invalid or malformed JWT token: {0}" ) ]
7278 JwtTokenError ( String ) ,
7379 #[ error( "Certificate error: {0}" ) ]
7480 Certificate ( String ) ,
81+ #[ error( "MSI authentication error: {0}" ) ]
82+ MsiAuth ( String ) ,
7583
7684 // Networking / HTTP / TLS
7785 #[ error( "HTTP error: {0}" ) ]
@@ -129,6 +137,7 @@ pub(crate) struct GenevaConfigClientConfig {
129137 pub ( crate ) region : String ,
130138 pub ( crate ) config_major_version : u32 ,
131139 pub ( crate ) auth_method : AuthMethod , // agent_identity and agent_version are hardcoded for now
140+ pub ( crate ) msi_resource : Option < String > , // Required when using any Managed Identity variant
132141}
133142
134143#[ allow( dead_code) ]
@@ -246,10 +255,10 @@ impl GenevaConfigClient {
246255 . map_err ( |e| GenevaConfigClientError :: Certificate ( e. to_string ( ) ) ) ?;
247256 client_builder = client_builder. use_preconfigured_tls ( tls_connector) ;
248257 }
249- AuthMethod :: ManagedIdentity => {
250- return Err ( GenevaConfigClientError :: AuthMethodNotImplemented (
251- "Managed Identity authentication is not implemented yet" . into ( ) ,
252- ) ) ;
258+ AuthMethod :: SystemManagedIdentity
259+ | AuthMethod :: UserManagedIdentity { .. }
260+ | AuthMethod :: UserManagedIdentityByObjectId { .. }
261+ | AuthMethod :: UserManagedIdentityByResourceId { .. } => { /* no special HTTP client changes needed */
253262 }
254263 #[ cfg( feature = "mock_auth" ) ]
255264 AuthMethod :: MockAuth => {
@@ -268,11 +277,24 @@ impl GenevaConfigClient {
268277 let encoded_identity = general_purpose:: STANDARD . encode ( & identity) ;
269278 let version_str = format ! ( "Ver{0}v0" , config. config_major_version) ;
270279
280+ // Use different API endpoints based on authentication method
281+ // Certificate auth uses "api", MSI auth uses "userapi"
282+ let api_path = match & config. auth_method {
283+ AuthMethod :: Certificate { .. } => "api" ,
284+ AuthMethod :: SystemManagedIdentity
285+ | AuthMethod :: UserManagedIdentity { .. }
286+ | AuthMethod :: UserManagedIdentityByObjectId { .. }
287+ | AuthMethod :: UserManagedIdentityByResourceId { .. } => "userapi" ,
288+ #[ cfg( feature = "mock_auth" ) ]
289+ AuthMethod :: MockAuth => "api" , // treat mock like certificate path for URL shape
290+ } ;
291+
271292 let mut pre_url = String :: with_capacity ( config. endpoint . len ( ) + 200 ) ;
272293 write ! (
273294 & mut pre_url,
274- "{}/api /agent/v3/{}/{}/MonitoringStorageKeys/?Namespace={}&Region={}&Identity={}&OSType={}&ConfigMajorVersion={}" ,
295+ "{}/{} /agent/v3/{}/{}/MonitoringStorageKeys/?Namespace={}&Region={}&Identity={}&OSType={}&ConfigMajorVersion={}" ,
275296 config. endpoint. trim_end_matches( '/' ) ,
297+ api_path,
276298 config. environment,
277299 config. account,
278300 config. namespace,
@@ -310,6 +332,66 @@ impl GenevaConfigClient {
310332 headers
311333 }
312334
335+ /// Get MSI token for GCS authentication
336+ async fn get_msi_token ( & self ) -> Result < String > {
337+ let resource = self . config . msi_resource . as_ref ( ) . ok_or_else ( || {
338+ GenevaConfigClientError :: MsiAuth (
339+ "msi_resource not set in config (required for Managed Identity auth)" . to_string ( ) ,
340+ )
341+ } ) ?;
342+
343+ // Normalize resource (strip trailing "/.default" if provided by user)
344+ let base = resource. trim_end_matches ( "/.default" ) . trim_end_matches ( '/' ) ;
345+
346+ // Candidate scopes tried with Azure Identity
347+ let mut scope_candidates: Vec < String > = vec ! [ format!( "{base}/.default" ) , base. to_string( ) ] ;
348+ // Add variant with trailing slash if not already present
349+ if !base. ends_with ( '/' ) {
350+ scope_candidates. push ( format ! ( "{base}/" ) ) ;
351+ }
352+
353+ // Build credential based on selector
354+ let user_assigned_id = match & self . config . auth_method {
355+ AuthMethod :: SystemManagedIdentity => None ,
356+ AuthMethod :: UserManagedIdentity { client_id } => {
357+ Some ( UserAssignedId :: ClientId ( client_id. clone ( ) ) )
358+ }
359+ AuthMethod :: UserManagedIdentityByObjectId { object_id } => {
360+ Some ( UserAssignedId :: ObjectId ( object_id. clone ( ) ) )
361+ }
362+ AuthMethod :: UserManagedIdentityByResourceId { resource_id } => {
363+ Some ( UserAssignedId :: ResourceId ( resource_id. clone ( ) ) )
364+ }
365+ _ => {
366+ return Err ( GenevaConfigClientError :: MsiAuth (
367+ "get_msi_token called but auth method is not a managed identity variant"
368+ . to_string ( ) ,
369+ ) )
370+ }
371+ } ;
372+
373+ let options = ManagedIdentityCredentialOptions {
374+ user_assigned_id,
375+ ..Default :: default ( )
376+ } ;
377+ let credential = ManagedIdentityCredential :: new ( Some ( options) ) . map_err ( |e| {
378+ GenevaConfigClientError :: MsiAuth ( format ! ( "Failed to create MSI credential: {e}" ) )
379+ } ) ?;
380+
381+ let mut last_err: Option < String > = None ;
382+ for scope in & scope_candidates {
383+ match credential. get_token ( & [ scope. as_str ( ) ] , None ) . await {
384+ Ok ( token) => return Ok ( token. token . secret ( ) . to_string ( ) ) ,
385+ Err ( e) => last_err = Some ( e. to_string ( ) ) ,
386+ }
387+ }
388+ let detail = last_err. unwrap_or_else ( || "no error detail" . into ( ) ) ;
389+ Err ( GenevaConfigClientError :: MsiAuth ( format ! (
390+ "Managed Identity token acquisition failed. Scopes tried: {scopes}. Last error: {detail}. IMDS fallback intentionally disabled." ,
391+ scopes = scope_candidates. join( ", " )
392+ ) ) )
393+ }
394+
313395 /// Retrieves ingestion gateway information from the Geneva Config Service.
314396 ///
315397 /// # HTTP API Details
@@ -381,7 +463,16 @@ impl GenevaConfigClient {
381463 GenevaConfigClientError :: InternalError ( "Failed to parse token expiry" . into ( ) )
382464 } ) ?;
383465
384- let token_endpoint = extract_endpoint_from_token ( & fresh_ingestion_gateway_info. auth_token ) ?;
466+ let token_endpoint =
467+ match extract_endpoint_from_token ( & fresh_ingestion_gateway_info. auth_token ) {
468+ Ok ( ep) => ep,
469+ Err ( err) => {
470+ // Fallback: some tokens legitimately omit the Endpoint claim; use server endpoint.
471+ #[ cfg( debug_assertions) ]
472+ eprintln ! ( "[geneva][debug] token Endpoint claim missing or unparsable: {err}" ) ;
473+ fresh_ingestion_gateway_info. endpoint . clone ( )
474+ }
475+ } ;
385476
386477 // Now update the cache with exclusive write access
387478 let mut guard = self
@@ -432,10 +523,29 @@ impl GenevaConfigClient {
432523 . headers ( self . static_headers . clone ( ) ) ; // Clone only cheap references
433524
434525 request = request. header ( "x-ms-client-request-id" , req_id) ;
435- let response = request
436- . send ( )
437- . await
438- . map_err ( GenevaConfigClientError :: Http ) ?;
526+
527+ // Add MSI authentication for managed identity auth method
528+ match & self . config . auth_method {
529+ AuthMethod :: SystemManagedIdentity
530+ | AuthMethod :: UserManagedIdentity { .. }
531+ | AuthMethod :: UserManagedIdentityByObjectId { .. }
532+ | AuthMethod :: UserManagedIdentityByResourceId { .. } => {
533+ let msi_token = self . get_msi_token ( ) . await ?;
534+ request = request. header ( AUTHORIZATION , format ! ( "Bearer {}" , msi_token) ) ;
535+ }
536+ AuthMethod :: Certificate { .. } => { /* mTLS only */ }
537+ #[ cfg( feature = "mock_auth" ) ]
538+ AuthMethod :: MockAuth => { /* no auth header */ }
539+ }
540+
541+ // Log the request details for debugging
542+ let response = match request. send ( ) . await {
543+ Ok ( response) => response,
544+ Err ( e) => {
545+ return Err ( GenevaConfigClientError :: Http ( e) ) ;
546+ }
547+ } ;
548+
439549 // Check if the response is successful
440550 let status = response. status ( ) ;
441551 let body = response. text ( ) . await ?;
@@ -506,12 +616,18 @@ fn extract_endpoint_from_token(token: &str) -> Result<String> {
506616 _ => payload. to_string ( ) ,
507617 } ;
508618
509- // Decode the Base64-encoded payload into raw bytes
510- let decoded = general_purpose:: URL_SAFE_NO_PAD
511- . decode ( payload)
512- . map_err ( |e| {
513- GenevaConfigClientError :: JwtTokenError ( format ! ( "Failed to decode JWT: {e}" ) )
514- } ) ?;
619+ // Decode the Base64-encoded payload into raw bytes with a more tolerant approach.
620+ let decoded = match general_purpose:: URL_SAFE_NO_PAD . decode ( & payload) {
621+ Ok ( b) => b,
622+ Err ( e_url) => match general_purpose:: STANDARD . decode ( & payload) {
623+ Ok ( b) => b,
624+ Err ( e_std) => {
625+ return Err ( GenevaConfigClientError :: JwtTokenError ( format ! (
626+ "Failed to decode JWT (url_safe and standard): url_err={e_url}; std_err={e_std}"
627+ ) ) )
628+ }
629+ } ,
630+ } ;
515631
516632 // Convert the raw bytes into a UTF-8 string
517633 let decoded_str = String :: from_utf8 ( decoded) . map_err ( |e| {
@@ -522,15 +638,12 @@ fn extract_endpoint_from_token(token: &str) -> Result<String> {
522638 let payload_json: serde_json:: Value =
523639 serde_json:: from_str ( & decoded_str) . map_err ( GenevaConfigClientError :: SerdeJson ) ?;
524640
525- // Extract "Endpoint" from JWT payload as a string, or fail if missing or invalid.
526- let endpoint = payload_json[ "Endpoint" ]
527- . as_str ( )
528- . ok_or_else ( || {
529- GenevaConfigClientError :: JwtTokenError ( "No Endpoint claim in JWT token" . to_string ( ) )
530- } ) ?
531- . to_string ( ) ;
532-
533- Ok ( endpoint)
641+ if let Some ( ep) = payload_json[ "Endpoint" ] . as_str ( ) {
642+ return Ok ( ep. to_string ( ) ) ;
643+ }
644+ Err ( GenevaConfigClientError :: JwtTokenError (
645+ "No Endpoint claim in JWT token" . to_string ( ) ,
646+ ) )
534647}
535648
536649#[ cfg( feature = "self_signed_certs" ) ]
0 commit comments