@@ -67,7 +67,7 @@ use tonic::{
6767 body:: BoxBody ,
6868 client:: GrpcService ,
6969 codegen:: InterceptedService ,
70- metadata:: { MetadataKey , MetadataValue } ,
70+ metadata:: { MetadataKey , MetadataMap , MetadataValue } ,
7171 service:: Interceptor ,
7272 transport:: { Certificate , Channel , Endpoint , Identity } ,
7373 Code , Status ,
@@ -133,6 +133,15 @@ pub struct ClientOptions {
133133 /// If set (which it is by default), HTTP2 gRPC keep alive will be enabled.
134134 #[ builder( default = "Some(ClientKeepAliveConfig::default())" ) ]
135135 pub keep_alive : Option < ClientKeepAliveConfig > ,
136+
137+ /// HTTP headers to include on every RPC call.
138+ #[ builder( default ) ]
139+ pub headers : Option < HashMap < String , String > > ,
140+
141+ /// API key which is set as the "Authorization" header with "Bearer " prepended. This will only
142+ /// be applied if the headers don't already have an "Authorization" header.
143+ #[ builder( default ) ]
144+ pub api_key : Option < String > ,
136145}
137146
138147/// Configuration options for TLS
@@ -279,7 +288,7 @@ pub enum ClientInitError {
279288pub struct ConfiguredClient < C > {
280289 client : C ,
281290 options : Arc < ClientOptions > ,
282- headers : Arc < RwLock < HashMap < String , String > > > ,
291+ headers : Arc < RwLock < ClientHeaders > > ,
283292 /// Capabilities as read from the `get_system_info` RPC call made on client connection
284293 capabilities : Option < get_system_info_response:: Capabilities > ,
285294 workers : Arc < SlotManager > ,
@@ -288,8 +297,12 @@ pub struct ConfiguredClient<C> {
288297impl < C > ConfiguredClient < C > {
289298 /// Set HTTP request headers overwriting previous headers
290299 pub fn set_headers ( & self , headers : HashMap < String , String > ) {
291- let mut guard = self . headers . write ( ) ;
292- * guard = headers;
300+ self . headers . write ( ) . user_headers = headers;
301+ }
302+
303+ /// Set API key, overwriting previous
304+ pub fn set_api_key ( & self , api_key : Option < String > ) {
305+ self . headers . write ( ) . api_key = api_key;
293306 }
294307
295308 /// Returns the options the client is configured with
@@ -309,6 +322,34 @@ impl<C> ConfiguredClient<C> {
309322 }
310323}
311324
325+ #[ derive( Debug ) ]
326+ struct ClientHeaders {
327+ user_headers : HashMap < String , String > ,
328+ api_key : Option < String > ,
329+ }
330+
331+ impl ClientHeaders {
332+ fn apply_to_metadata ( & self , metadata : & mut MetadataMap ) {
333+ for ( key, val) in self . user_headers . iter ( ) {
334+ // Only if not already present
335+ if !metadata. contains_key ( key) {
336+ // Ignore invalid keys/values
337+ if let ( Ok ( key) , Ok ( val) ) = ( MetadataKey :: from_str ( key) , val. parse ( ) ) {
338+ metadata. insert ( key, val) ;
339+ }
340+ }
341+ }
342+ if let Some ( api_key) = & self . api_key {
343+ // Only if not already present
344+ if !metadata. contains_key ( "authorization" ) {
345+ if let Ok ( val) = format ! ( "Bearer {}" , api_key) . parse ( ) {
346+ metadata. insert ( "authorization" , val) ;
347+ }
348+ }
349+ }
350+ }
351+ }
352+
312353// The configured client is effectively a "smart" (dumb) pointer
313354impl < C > Deref for ConfiguredClient < C > {
314355 type Target = C ;
@@ -331,12 +372,8 @@ impl ClientOptions {
331372 & self ,
332373 namespace : impl Into < String > ,
333374 metrics_meter : Option < TemporalMeter > ,
334- headers : Option < Arc < RwLock < HashMap < String , String > > > > ,
335375 ) -> Result < RetryClient < Client > , ClientInitError > {
336- let client = self
337- . connect_no_namespace ( metrics_meter, headers)
338- . await ?
339- . into_inner ( ) ;
376+ let client = self . connect_no_namespace ( metrics_meter) . await ?. into_inner ( ) ;
340377 let client = Client :: new ( client, namespace. into ( ) ) ;
341378 let retry_client = RetryClient :: new ( client, self . retry_config . clone ( ) ) ;
342379 Ok ( retry_client)
@@ -349,7 +386,6 @@ impl ClientOptions {
349386 pub async fn connect_no_namespace (
350387 & self ,
351388 metrics_meter : Option < TemporalMeter > ,
352- headers : Option < Arc < RwLock < HashMap < String , String > > > > ,
353389 ) -> Result < RetryClient < ConfiguredClient < TemporalServiceClientWithMetrics > > , ClientInitError >
354390 {
355391 let channel = Channel :: from_shared ( self . target_url . to_string ( ) ) ?;
@@ -374,7 +410,10 @@ impl ClientOptions {
374410 metrics : metrics_meter. clone ( ) . map ( MetricsContext :: new) ,
375411 } )
376412 . service ( channel) ;
377- let headers = headers. unwrap_or_default ( ) ;
413+ let headers = Arc :: new ( RwLock :: new ( ClientHeaders {
414+ user_headers : self . headers . clone ( ) . unwrap_or_default ( ) ,
415+ api_key : self . api_key . clone ( ) ,
416+ } ) ) ;
378417 let interceptor = ServiceCallInterceptor {
379418 opts : self . clone ( ) ,
380419 headers : headers. clone ( ) ,
@@ -442,7 +481,7 @@ impl ClientOptions {
442481pub struct ServiceCallInterceptor {
443482 opts : ClientOptions ,
444483 /// Only accessed as a reader
445- headers : Arc < RwLock < HashMap < String , String > > > ,
484+ headers : Arc < RwLock < ClientHeaders > > ,
446485}
447486
448487impl Interceptor for ServiceCallInterceptor {
@@ -468,16 +507,7 @@ impl Interceptor for ServiceCallInterceptor {
468507 . unwrap_or_else ( |_| MetadataValue :: from_static ( "" ) ) ,
469508 ) ;
470509 }
471- let headers = & * self . headers . read ( ) ;
472- for ( k, v) in headers {
473- if metadata. contains_key ( k) {
474- // Don't overwrite per-request specified headers
475- continue ;
476- }
477- if let ( Ok ( k) , Ok ( v) ) = ( MetadataKey :: from_str ( k) , v. parse ( ) ) {
478- metadata. insert ( k, v) ;
479- }
480- }
510+ self . headers . read ( ) . apply_to_metadata ( metadata) ;
481511 if !metadata. contains_key ( "grpc-timeout" ) {
482512 request. set_timeout ( OTHER_CALL_TIMEOUT ) ;
483513 }
@@ -1559,7 +1589,7 @@ mod tests {
15591589 use super :: * ;
15601590
15611591 #[ test]
1562- fn respects_per_call_headers ( ) {
1592+ fn applies_headers ( ) {
15631593 let opts = ClientOptionsBuilder :: default ( )
15641594 . identity ( "enchicat" . to_string ( ) )
15651595 . target_url ( Url :: parse ( "https://smolkitty" ) . unwrap ( ) )
@@ -1568,16 +1598,55 @@ mod tests {
15681598 . build ( )
15691599 . unwrap ( ) ;
15701600
1571- let mut static_headers = HashMap :: new ( ) ;
1572- static_headers. insert ( "enchi" . to_string ( ) , "kitty" . to_string ( ) ) ;
1573- let mut iceptor = ServiceCallInterceptor {
1601+ // Initial header set
1602+ let headers = Arc :: new ( RwLock :: new ( ClientHeaders {
1603+ user_headers : HashMap :: new ( ) ,
1604+ api_key : Some ( "my-api-key" . to_owned ( ) ) ,
1605+ } ) ) ;
1606+ headers
1607+ . clone ( )
1608+ . write ( )
1609+ . user_headers
1610+ . insert ( "my-meta-key" . to_owned ( ) , "my-meta-val" . to_owned ( ) ) ;
1611+ let mut interceptor = ServiceCallInterceptor {
15741612 opts,
1575- headers : Arc :: new ( RwLock :: new ( static_headers ) ) ,
1613+ headers : headers . clone ( ) ,
15761614 } ;
1615+
1616+ // Confirm on metadata
1617+ let req = interceptor. call ( tonic:: Request :: new ( ( ) ) ) . unwrap ( ) ;
1618+ assert_eq ! ( req. metadata( ) . get( "my-meta-key" ) . unwrap( ) , "my-meta-val" ) ;
1619+ assert_eq ! (
1620+ req. metadata( ) . get( "authorization" ) . unwrap( ) ,
1621+ "Bearer my-api-key"
1622+ ) ;
1623+
1624+ // Overwrite at request time
15771625 let mut req = tonic:: Request :: new ( ( ) ) ;
1578- req. metadata_mut ( ) . insert ( "enchi" , "cat" . parse ( ) . unwrap ( ) ) ;
1579- let next_req = iceptor. call ( req) . unwrap ( ) ;
1580- assert_eq ! ( next_req. metadata( ) . get( "enchi" ) . unwrap( ) , "cat" ) ;
1626+ req. metadata_mut ( )
1627+ . insert ( "my-meta-key" , "my-meta-val2" . parse ( ) . unwrap ( ) ) ;
1628+ req. metadata_mut ( )
1629+ . insert ( "authorization" , "my-api-key2" . parse ( ) . unwrap ( ) ) ;
1630+ let req = interceptor. call ( req) . unwrap ( ) ;
1631+ assert_eq ! ( req. metadata( ) . get( "my-meta-key" ) . unwrap( ) , "my-meta-val2" ) ;
1632+ assert_eq ! ( req. metadata( ) . get( "authorization" ) . unwrap( ) , "my-api-key2" ) ;
1633+
1634+ // Overwrite auth on header
1635+ headers
1636+ . clone ( )
1637+ . write ( )
1638+ . user_headers
1639+ . insert ( "authorization" . to_owned ( ) , "my-api-key3" . to_owned ( ) ) ;
1640+ let req = interceptor. call ( tonic:: Request :: new ( ( ) ) ) . unwrap ( ) ;
1641+ assert_eq ! ( req. metadata( ) . get( "my-meta-key" ) . unwrap( ) , "my-meta-val" ) ;
1642+ assert_eq ! ( req. metadata( ) . get( "authorization" ) . unwrap( ) , "my-api-key3" ) ;
1643+
1644+ // Remove headers and auth and confirm gone
1645+ headers. clone ( ) . write ( ) . user_headers . clear ( ) ;
1646+ headers. clone ( ) . write ( ) . api_key . take ( ) ;
1647+ let req = interceptor. call ( tonic:: Request :: new ( ( ) ) ) . unwrap ( ) ;
1648+ assert ! ( !req. metadata( ) . contains_key( "my-meta-key" ) ) ;
1649+ assert ! ( !req. metadata( ) . contains_key( "authorization" ) ) ;
15811650 }
15821651
15831652 #[ test]
0 commit comments