@@ -34,12 +34,15 @@ type FetchTimers = HashMap<String, std::time::Instant>;
3434type Polls =
3535 Arc < std:: sync:: Mutex < std:: collections:: HashSet < ( String , josh_proxy:: auth:: Handle , String ) > > > ;
3636
37+ type HeadsMap = Arc < std:: sync:: RwLock < std:: collections:: HashMap < String , String > > > ;
38+
3739#[ derive( Clone ) ]
3840struct JoshProxyService {
3941 port : String ,
4042 repo_path : std:: path:: PathBuf ,
4143 upstream_url : String ,
4244 fetch_timers : Arc < RwLock < FetchTimers > > ,
45+ heads_map : HeadsMap ,
4346 fetch_permits : Arc < tokio:: sync:: Semaphore > ,
4447 filter_permits : Arc < tokio:: sync:: Semaphore > ,
4548 poll : Polls ,
@@ -66,10 +69,10 @@ async fn fetch_upstream(
6669 let auth = auth. clone ( ) ;
6770 let key = remote_url. clone ( ) ;
6871
69- let refs_to_fetch = if ! headref. is_empty ( ) && !headref. starts_with ( "refs/heads/" ) {
70- vec ! [ "refs/heads/*" , "refs/tags/*" , headref]
72+ let refs_to_fetch = if headref != "HEAD" && !headref. starts_with ( "refs/heads/" ) {
73+ vec ! [ "HEAD*" , " refs/heads/*", "refs/tags/*" , headref]
7174 } else {
72- vec ! [ "refs/heads/*" , "refs/tags/*" ]
75+ vec ! [ "HEAD*" , " refs/heads/*", "refs/tags/*" ]
7376 } ;
7477
7578 let refs_to_fetch: Vec < _ > = refs_to_fetch. iter ( ) . map ( |x| x. to_string ( ) ) . collect ( ) ;
@@ -113,6 +116,7 @@ async fn fetch_upstream(
113116 }
114117
115118 let fetch_timers = service. fetch_timers . clone ( ) ;
119+ let heads_map = service. heads_map . clone ( ) ;
116120 let br_path = service. repo_path . clone ( ) ;
117121
118122 let s = tracing:: span!( tracing:: Level :: TRACE , "fetch worker" ) ;
@@ -126,6 +130,21 @@ async fn fetch_upstream(
126130 } )
127131 . await ?;
128132
133+ let us = upstream_repo. clone ( ) ;
134+ let s = tracing:: span!( tracing:: Level :: TRACE , "get_head worker" ) ;
135+ let br_path = service. repo_path . clone ( ) ;
136+ let ru = remote_url. clone ( ) ;
137+ let a = auth. clone ( ) ;
138+ let hres = tokio:: task:: spawn_blocking ( move || {
139+ let _e = s. enter ( ) ;
140+ josh_proxy:: get_head ( & br_path, & ru, & a)
141+ } )
142+ . await ?;
143+
144+ if let Ok ( hres) = hres {
145+ heads_map. write ( ) ?. insert ( us, hres) ;
146+ }
147+
129148 std:: mem:: drop ( permit) ;
130149
131150 if let Ok ( res) = res {
@@ -227,6 +246,7 @@ async fn do_filter(
227246 headref : String ,
228247) -> josh:: JoshResult < ( ) > {
229248 let permit = service. filter_permits . acquire ( ) . await ;
249+ let heads_map = service. heads_map . clone ( ) ;
230250
231251 let s = tracing:: span!( tracing:: Level :: TRACE , "do_filter worker" ) ;
232252 let r = tokio:: task:: spawn_blocking ( move || {
@@ -273,13 +293,25 @@ async fn do_filter(
273293 temp_ns. reference ( & headref) ,
274294 ) ) ;
275295
296+ let mut headref = headref;
297+
276298 josh:: filter_refs ( & transaction, filter, & from_to) ?;
277- transaction. repo ( ) . reference_symbolic (
278- & temp_ns. reference ( "HEAD" ) ,
279- & temp_ns. reference ( & headref) ,
280- true ,
281- "" ,
282- ) ?;
299+ if headref == "HEAD" {
300+ headref = heads_map
301+ . read ( ) ?
302+ . get ( & upstream_repo)
303+ . unwrap_or ( & "invalid" . to_string ( ) )
304+ . clone ( ) ;
305+ }
306+ transaction
307+ . repo ( )
308+ . reference_symbolic (
309+ & temp_ns. reference ( "HEAD" ) ,
310+ & temp_ns. reference ( & headref) ,
311+ true ,
312+ "" ,
313+ )
314+ . ok ( ) ;
283315 Ok ( ( ) )
284316 } )
285317 . await ?;
@@ -368,17 +400,14 @@ async fn call_service(
368400 } else {
369401 return Ok ( Response :: builder ( )
370402 . status ( 302 )
371- . header (
372- "Location" ,
373- format ! ( "/~/browse{}@refs/heads/master(:/)/()" , path) ,
374- )
403+ . header ( "Location" , format ! ( "/~/browse{}@HEAD(:/)/()" , path) )
375404 . body ( hyper:: Body :: empty ( ) ) ?) ;
376405 }
377406 } ;
378407
379408 let mut headref = parsed_url. headref . trim_start_matches ( '@' ) . to_owned ( ) ;
380409 if headref. is_empty ( ) {
381- headref = "refs/heads/master " . to_string ( ) ;
410+ headref = "HEAD " . to_string ( ) ;
382411 }
383412
384413 let remote_url = [
@@ -611,6 +640,7 @@ async fn run_proxy() -> josh::JoshResult<i32> {
611640 repo_path : local. to_owned ( ) ,
612641 upstream_url : remote. to_owned ( ) ,
613642 fetch_timers : Arc :: new ( RwLock :: new ( FetchTimers :: new ( ) ) ) ,
643+ heads_map : Arc :: new ( RwLock :: new ( std:: collections:: HashMap :: new ( ) ) ) ,
614644 poll : Arc :: new ( std:: sync:: Mutex :: new ( std:: collections:: HashSet :: new ( ) ) ) ,
615645 fetch_permits : Arc :: new ( tokio:: sync:: Semaphore :: new (
616646 ARGS . value_of ( "n" ) . unwrap_or ( "1" ) . parse ( ) ?,
0 commit comments