1+ use std:: sync:: Arc ;
2+
13// Import the base64 crate Engine trait anonymously so we can
24// call its methods without adding to the namespace.
35use base64:: engine:: general_purpose:: STANDARD as BASE64 ;
46use base64:: engine:: Engine as _;
7+ use tracing:: Instrument ;
8+
9+ // Auths in those groups are independent of each other.
10+ // This lets us reduce mutex contention
11+ #[ derive( Hash , Eq , PartialEq , Clone ) ]
12+ struct AuthTimersGroupKey {
13+ url : String ,
14+ username : String ,
15+ }
516
6- lazy_static ! {
7- static ref AUTH : std:: sync:: Mutex <std:: collections:: HashMap <Handle , Header >> =
8- std:: sync:: Mutex :: new( std:: collections:: HashMap :: new( ) ) ;
9- static ref AUTH_TIMERS : std:: sync:: Mutex <AuthTimers > =
10- std:: sync:: Mutex :: new( std:: collections:: HashMap :: new( ) ) ;
17+ impl AuthTimersGroupKey {
18+ fn new ( url : & str , handle : & Handle ) -> Self {
19+ let ( username, _) = handle. parse ( ) . unwrap_or_default ( ) ;
20+
21+ Self {
22+ url : url. to_string ( ) ,
23+ username,
24+ }
25+ }
1126}
1227
13- type AuthTimers = std:: collections:: HashMap < ( String , Handle ) , std:: time:: Instant > ;
28+ // Within a group, we can hold the lock for longer to verify the auth with upstream
29+ type AuthTimersGroup = std:: collections:: HashMap < Handle , std:: time:: Instant > ;
30+ type AuthTimers =
31+ std:: collections:: HashMap < AuthTimersGroupKey , Arc < tokio:: sync:: Mutex < AuthTimersGroup > > > ;
32+
33+ lazy_static ! {
34+ // Note the use of std::sync::Mutex: access to those structures should only be performed
35+ // shortly, without blocking the async runtime for long time and without holding the
36+ // lock across an await point.
37+ static ref AUTH : std:: sync:: Mutex <std:: collections:: HashMap <Handle , Header >> = Default :: default ( ) ;
38+ static ref AUTH_TIMERS : std:: sync:: Mutex <AuthTimers > = Default :: default ( ) ;
39+ }
1440
1541// Wrapper struct for storing passwords to avoid having
1642// them output to traces by accident
17- #[ derive( Clone ) ]
43+ #[ derive( Clone , Default ) ]
1844struct Header {
1945 pub header : Option < hyper:: header:: HeaderValue > ,
2046}
2147
22- #[ derive( Clone , PartialEq , Eq , Hash , serde:: Serialize , serde:: Deserialize ) ]
48+ #[ derive( Clone , PartialEq , Eq , Hash , PartialOrd , Ord , serde:: Serialize , serde:: Deserialize ) ]
2349pub struct Handle {
24- pub hash : String ,
50+ pub hash : Option < String > ,
2551}
2652
2753impl std:: fmt:: Debug for Handle {
@@ -32,39 +58,50 @@ impl std::fmt::Debug for Handle {
3258
3359impl Handle {
3460 // Returns a pair: (username, password)
35- pub fn parse ( & self ) -> josh:: JoshResult < ( String , String ) > {
36- let line = josh:: some_or!(
37- AUTH . lock( )
61+ pub fn parse ( & self ) -> Option < ( String , String ) > {
62+ let get_result = || -> josh:: JoshResult < ( String , String ) > {
63+ let line = AUTH
64+ . lock ( )
3865 . unwrap ( )
3966 . get ( self )
4067 . and_then ( |h| h. header . as_ref ( ) )
41- . map( |h| h. as_bytes( ) . to_owned( ) ) ,
42- {
43- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
44- }
45- ) ;
68+ . map ( |h| h. as_bytes ( ) . to_owned ( ) )
69+ . ok_or_else ( || josh:: josh_error ( "no auth found" ) ) ?;
4670
47- let u = josh:: ok_or!( String :: from_utf8( line[ 6 ..] . to_vec( ) ) , {
48- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
49- } ) ;
50- let decoded = josh:: ok_or!( BASE64 . decode( u) , {
51- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
52- } ) ;
53- let s = josh:: ok_or!( String :: from_utf8( decoded) , {
54- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
55- } ) ;
56- let ( username, password) = s. as_str ( ) . split_once ( ':' ) . unwrap_or ( ( "" , "" ) ) ;
57- Ok ( ( username. to_string ( ) , password. to_string ( ) ) )
71+ let line = String :: from_utf8 ( line) ?;
72+ let ( _, token) = line
73+ . split_once ( ' ' )
74+ . ok_or_else ( || josh:: josh_error ( "Unsupported auth type" ) ) ?;
75+
76+ let decoded = BASE64 . decode ( token) ?;
77+ let decoded = String :: from_utf8 ( decoded) ?;
78+
79+ let ( username, password) = decoded
80+ . split_once ( ':' )
81+ . ok_or_else ( || josh:: josh_error ( "No password found" ) ) ?;
82+
83+ Ok ( ( username. to_string ( ) , password. to_string ( ) ) )
84+ } ;
85+
86+ match get_result ( ) {
87+ Ok ( pair) => Some ( pair) ,
88+ Err ( e) => {
89+ tracing:: trace!(
90+ handle = ?self ,
91+ "Falling back to default auth: {:?}" ,
92+ e
93+ ) ;
94+
95+ None
96+ }
97+ }
5898 }
5999}
60100
61101pub fn add_auth ( token : & str ) -> josh:: JoshResult < Handle > {
62102 let header = hyper:: header:: HeaderValue :: from_str ( & format ! ( "Basic {}" , BASE64 . encode( token) ) ) ?;
63103 let hp = Handle {
64- hash : format ! (
65- "{:?}" ,
66- git2:: Oid :: hash_object( git2:: ObjectType :: Blob , header. as_bytes( ) ) ?
67- ) ,
104+ hash : Some ( git2:: Oid :: hash_object ( git2:: ObjectType :: Blob , header. as_bytes ( ) ) ?. to_string ( ) ) ,
68105 } ;
69106 let p = Header {
70107 header : Some ( header) ,
@@ -73,65 +110,122 @@ pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
73110 Ok ( hp)
74111}
75112
76- pub async fn check_auth ( url : & str , auth : & Handle , required : bool ) -> josh:: JoshResult < bool > {
77- if required && auth. hash . is_empty ( ) {
78- return Ok ( false ) ;
79- }
113+ #[ tracing:: instrument( ) ]
114+ pub async fn check_http_auth ( url : & str , auth : & Handle , required : bool ) -> josh:: JoshResult < bool > {
115+ use opentelemetry_semantic_conventions:: trace:: HTTP_RESPONSE_STATUS_CODE ;
80116
81- if let Some ( last) = AUTH_TIMERS . lock ( ) ?. get ( & ( url. to_string ( ) , auth. clone ( ) ) ) {
82- let since = std:: time:: Instant :: now ( ) . duration_since ( * last) ;
83- tracing:: trace!( "last: {:?}, since: {:?}" , last, since) ;
84- if since < std:: time:: Duration :: from_secs ( 60 * 30 ) {
85- tracing:: trace!( "cached auth" ) ;
86- return Ok ( true ) ;
87- }
117+ if required && auth. hash . is_none ( ) {
118+ return Ok ( false ) ;
88119 }
89120
90- tracing:: trace!( "no cached auth {:?}" , * AUTH_TIMERS . lock( ) ?) ;
121+ let group_key = AuthTimersGroupKey :: new ( url, & auth) ;
122+ let auth_timers = AUTH_TIMERS
123+ . lock ( )
124+ . unwrap ( )
125+ . entry ( group_key. clone ( ) )
126+ . or_default ( )
127+ . clone ( ) ;
91128
92- let https = hyper_tls:: HttpsConnector :: new ( ) ;
93- let client = hyper:: Client :: builder ( ) . build :: < _ , hyper:: Body > ( https) ;
129+ let auth_header = AUTH . lock ( ) . unwrap ( ) . get ( auth) . cloned ( ) . unwrap_or_default ( ) ;
94130
95- let password = AUTH
96- . lock ( ) ?
97- . get ( auth)
98- . unwrap_or ( & Header { header : None } )
99- . to_owned ( ) ;
100131 let refs_url = format ! ( "{}/info/refs?service=git-upload-pack" , url) ;
132+ let do_request = || {
133+ let refs_url = refs_url. clone ( ) ;
134+ let do_request_span = tracing:: info_span!( "check_http_auth: make request" ) ;
101135
102- let builder = hyper :: Request :: builder ( )
103- . method ( hyper :: Method :: GET )
104- . uri ( & refs_url ) ;
136+ async move {
137+ let https = hyper_tls :: HttpsConnector :: new ( ) ;
138+ let client = hyper :: Client :: builder ( ) . build :: < _ , hyper :: Body > ( https ) ;
105139
106- let builder = if let Some ( value) = password. header {
107- builder. header ( hyper:: header:: AUTHORIZATION , value)
108- } else {
109- builder
140+ let builder = hyper:: Request :: builder ( )
141+ . method ( hyper:: Method :: GET )
142+ . uri ( & refs_url) ;
143+
144+ let builder = if let Some ( value) = auth_header. header {
145+ builder. header ( hyper:: header:: AUTHORIZATION , value)
146+ } else {
147+ builder
148+ } ;
149+
150+ let request = builder. body ( hyper:: Body :: empty ( ) ) ?;
151+ let resp = client. request ( request) . await ?;
152+
153+ Ok :: < _ , josh:: JoshError > ( resp)
154+ }
155+ . instrument ( do_request_span)
110156 } ;
111157
112- let request = builder. body ( hyper:: Body :: empty ( ) ) ?;
113- let resp = client. request ( request) . await ?;
158+ // Only lock the mutex if auth handle is not empty, because otherwise
159+ // for remotes that require auth, we could run into situation where
160+ // multiple requests are executed essentially sequentially because
161+ // remote always returns 401 for authenticated requests and we never
162+ // populate the auth_timers map
163+ let resp = if auth. hash . is_some ( ) {
164+ let mut auth_timers = auth_timers. lock ( ) . await ;
165+
166+ if let Some ( last) = auth_timers. get ( auth) {
167+ let since = std:: time:: Instant :: now ( ) . duration_since ( * last) ;
168+ let expired = since > std:: time:: Duration :: from_secs ( 60 * 30 ) ;
169+
170+ tracing:: info!(
171+ last = ?last,
172+ since = ?since,
173+ expired = %expired,
174+ "check_http_auth: found auth entry"
175+ ) ;
176+
177+ if !expired {
178+ return Ok ( true ) ;
179+ }
180+ }
114181
115- let status = resp. status ( ) ;
182+ tracing:: info!(
183+ auth_timers = ?auth_timers,
184+ "check_http_auth: no valid cached auth"
185+ ) ;
116186
117- tracing:: trace!( "http resp.status {:?}" , resp. status( ) ) ;
187+ let resp = do_request ( ) . await ?;
188+ if resp. status ( ) . is_success ( ) {
189+ auth_timers. insert ( auth. clone ( ) , std:: time:: Instant :: now ( ) ) ;
190+ }
191+
192+ resp
193+ } else {
194+ do_request ( ) . await ?
195+ } ;
196+
197+ let status = resp. status ( ) ;
118198
119- let err_msg = format ! ( "got http response: {} {:?}" , refs_url, resp) ;
199+ tracing:: event!(
200+ tracing:: Level :: INFO ,
201+ { HTTP_RESPONSE_STATUS_CODE } = status. as_u16( ) ,
202+ "check_http_auth: response"
203+ ) ;
120204
121205 if status == hyper:: StatusCode :: OK {
122- AUTH_TIMERS
123- . lock ( ) ?
124- . insert ( ( url. to_string ( ) , auth. clone ( ) ) , std:: time:: Instant :: now ( ) ) ;
125206 Ok ( true )
126207 } else if status == hyper:: StatusCode :: UNAUTHORIZED {
127- tracing:: warn! ( "resp.status == 401: {:?}" , & err_msg ) ;
128- tracing:: trace! (
129- "body: {:?}" ,
130- std :: str :: from_utf8 ( & hyper :: body :: to_bytes ( resp . into_body ( ) ) . await ? )
208+ tracing:: event! (
209+ tracing:: Level :: WARN ,
210+ { HTTP_RESPONSE_STATUS_CODE } = status . as_u16 ( ) ,
211+ "check_http_auth: unauthorized"
131212 ) ;
213+
214+ let response = hyper:: body:: to_bytes ( resp. into_body ( ) ) . await ?;
215+ let response = String :: from_utf8_lossy ( & response) ;
216+
217+ tracing:: event!(
218+ tracing:: Level :: TRACE ,
219+ "http.response.body" = %response,
220+ "check_http_auth: unauthorized" ,
221+ ) ;
222+
132223 Ok ( false )
133224 } else {
134- return Err ( josh:: josh_error ( & err_msg) ) ;
225+ return Err ( josh:: josh_error ( & format ! (
226+ "check_http_auth: got http response: {} {:?}" ,
227+ refs_url, resp
228+ ) ) ) ;
135229 }
136230}
137231
@@ -144,9 +238,8 @@ pub fn strip_auth(
144238
145239 if let Some ( header) = header {
146240 let hp = Handle {
147- hash : format ! (
148- "{:?}" ,
149- git2:: Oid :: hash_object( git2:: ObjectType :: Blob , header. as_bytes( ) ) ?
241+ hash : Some (
242+ git2:: Oid :: hash_object ( git2:: ObjectType :: Blob , header. as_bytes ( ) ) ?. to_string ( ) ,
150243 ) ,
151244 } ;
152245 let p = Header {
@@ -156,10 +249,5 @@ pub fn strip_auth(
156249 return Ok ( ( hp, req) ) ;
157250 }
158251
159- Ok ( (
160- Handle {
161- hash : "" . to_owned ( ) ,
162- } ,
163- req,
164- ) )
252+ Ok ( ( Handle { hash : None } , req) )
165253}
0 commit comments