@@ -12,6 +12,7 @@ use oauth2::{
1212} ;
1313use reqwest:: { blocking:: Client , Url } ;
1414use serde:: Serialize ;
15+ use std:: cell:: Cell ;
1516use std:: collections:: HashMap ;
1617use std:: error:: Error as StdError ;
1718use std:: fs;
@@ -45,6 +46,7 @@ pub enum StatusType {
4546 Sending ,
4647 WaitingForResults ,
4748 Finished ,
49+ IntermediateStepFinished ,
4850}
4951
5052// compatible with anyhow
@@ -60,6 +62,8 @@ pub struct TmcCore {
6062 auth_url : String ,
6163 token : Option < Token > ,
6264 progress_report : Option < UpdateClosure > ,
65+ progress_steps_done : Cell < u32 > ,
66+ progress_steps_total : u32 ,
6367 client_name : String ,
6468 client_version : String ,
6569}
@@ -103,6 +107,8 @@ impl TmcCore {
103107 auth_url,
104108 token : None ,
105109 progress_report : None ,
110+ progress_steps_done : Cell :: new ( 0 ) ,
111+ progress_steps_total : 1 ,
106112 client_name,
107113 client_version,
108114 } )
@@ -139,12 +145,14 @@ impl TmcCore {
139145 self . progress_report = Some ( Box :: new ( progress_report) ) ;
140146 }
141147
142- pub fn report_progress (
143- & self ,
144- message : & ' static str ,
145- status_type : StatusType ,
146- percent_done : f64 ,
147- ) {
148+ pub fn increment_progress_steps ( & mut self ) {
149+ self . progress_steps_total += 1 ;
150+ }
151+
152+ fn report_progress ( & self , message : & ' static str , status_type : StatusType , percent_done : f64 ) {
153+ let from_prev_steps = self . progress_steps_done . get ( ) as f64 ;
154+ let percent_done = ( from_prev_steps + percent_done) / self . progress_steps_total as f64 ;
155+
148156 self . progress_report . as_ref ( ) . map ( |f| {
149157 f ( StatusUpdate {
150158 finished : false ,
@@ -155,15 +163,21 @@ impl TmcCore {
155163 } ) ;
156164 }
157165
158- pub fn report_complete ( & self , message : & ' static str ) {
159- self . progress_report . as_ref ( ) . map ( |f| {
160- f ( StatusUpdate {
161- finished : true ,
162- message,
163- percent_done : 1.0 ,
164- status_type : StatusType :: Finished ,
165- } )
166- } ) ;
166+ fn report_complete ( & self , message : & ' static str ) {
167+ self . progress_steps_done
168+ . set ( self . progress_steps_done . get ( ) + 1 ) ;
169+ if self . progress_steps_done . get ( ) == self . progress_steps_total {
170+ self . progress_report . as_ref ( ) . map ( |f| {
171+ f ( StatusUpdate {
172+ finished : true ,
173+ message,
174+ percent_done : 1.0 ,
175+ status_type : StatusType :: Finished ,
176+ } )
177+ } ) ;
178+ } else {
179+ self . report_progress ( message, StatusType :: IntermediateStepFinished , 0.0 ) ;
180+ }
167181 }
168182
169183 /// Attempts to log in with the given credentials, returns an error if an authentication token is already present.
@@ -1286,4 +1300,36 @@ mod test {
12861300 serde_json:: to_string( & f) . unwrap( )
12871301 ) ;
12881302 }
1303+
1304+ #[ test]
1305+ fn multi_step_progress ( ) {
1306+ use std:: sync:: { Arc , Mutex } ;
1307+
1308+ let ( mut core, _) = init ( ) ;
1309+ let report = Arc :: new ( Mutex :: default ( ) ) ;
1310+
1311+ let report_clone = Arc :: clone ( & report) ;
1312+ core. set_progress_report ( move |rep| {
1313+ log:: debug!( "got {:#?}" , rep) ;
1314+ let report = Arc :: clone ( & report_clone) ;
1315+ * report. lock ( ) . unwrap ( ) = Some ( rep) ;
1316+ Ok ( ( ) )
1317+ } ) ;
1318+ core. increment_progress_steps ( ) ;
1319+ core. increment_progress_steps ( ) ;
1320+
1321+ core. report_progress ( "msg" , StatusType :: Downloading , 0.2 ) ;
1322+ let err = f64:: EPSILON ;
1323+ assert ! ( ( report. lock( ) . unwrap( ) . as_ref( ) . unwrap( ) . percent_done - ( 0.2 / 3.0 ) ) . abs( ) < err) ;
1324+ core. report_progress ( "msg" , StatusType :: Downloading , 0.8 ) ;
1325+ assert ! ( ( report. lock( ) . unwrap( ) . as_ref( ) . unwrap( ) . percent_done - ( 0.8 / 3.0 ) ) . abs( ) < err) ;
1326+ core. report_complete ( "msg" ) ;
1327+ assert ! ( ( report. lock( ) . unwrap( ) . as_ref( ) . unwrap( ) . percent_done - ( 1.0 / 3.0 ) ) . abs( ) < err) ;
1328+ core. report_complete ( "msg" ) ;
1329+ assert ! ( ( report. lock( ) . unwrap( ) . as_ref( ) . unwrap( ) . percent_done - ( 2.0 / 3.0 ) ) . abs( ) < err) ;
1330+ core. report_progress ( "msg" , StatusType :: Downloading , 0.5 ) ;
1331+ assert ! ( ( report. lock( ) . unwrap( ) . as_ref( ) . unwrap( ) . percent_done - ( 2.5 / 3.0 ) ) . abs( ) < err) ;
1332+ core. report_complete ( "msg" ) ;
1333+ assert ! ( ( report. lock( ) . unwrap( ) . as_ref( ) . unwrap( ) . percent_done - 1.0 ) . abs( ) < err) ;
1334+ }
12891335}
0 commit comments