11"""ProgressBar shows the how many dask tasks finished/remains using tqdm."""
22
3- from typing import Any , Optional , Dict , Tuple , Union
3+ import sys
44from time import time
5+ from typing import Any , Dict , Optional , Tuple , Union
56
67from dask .callbacks import Callback
78
1516# pylint: disable=method-hidden,too-many-instance-attributes
1617class ProgressBar (Callback ): # type: ignore
1718 """A progress bar for DataPrep.EDA.
19+ Not thread safe.
1820
1921 Parameters
2022 ----------
2123 minimum : int, optional
2224 Minimum time threshold in seconds before displaying a progress bar.
2325 Default is 0 (always display)
24- _min_tasks : int, optional
26+ min_tasks : int, optional
2527 Minimum graph size to show a progress bar, default is 5
2628 width : int, optional
2729 Width of the bar. None means auto width.
2830 interval : float, optional
29- Update resolution in seconds, default is 0.1 seconds
31+ Update resolution in seconds, default is 0.1 seconds.
32+ disable : bool, optional
33+ Disable the progress bar.
3034 """
3135
3236 _minimum : float = 0
@@ -38,50 +42,75 @@ class ProgressBar(Callback): # type: ignore
3842 _state : Optional [Dict [str , Any ]] = None
3943 _started : Optional [float ] = None
4044 _last_task : Optional [str ] = None # in case we initialize the pbar in _finish
45+ _pbar_runtime : float = 0
46+ _last_updated : Optional [float ] = None
47+ _disable : bool = False
4148
42- def __init__ (
49+ def __init__ ( # pylint: disable=too-many-arguments
4350 self ,
4451 minimum : float = 0 ,
4552 min_tasks : int = 5 ,
4653 width : Optional [int ] = None ,
4754 interval : float = 0.1 ,
55+ disable : bool = False ,
4856 ) -> None :
4957 super ().__init__ ()
5058 self ._minimum = minimum
5159 self ._min_tasks = min_tasks
5260 self ._width = width
5361 self ._interval = interval
62+ self ._disable = disable
5463
5564 def _start (self , _dsk : Any ) -> None :
5665 """A hook to start this callback."""
5766
5867 def _start_state (self , _dsk : Any , state : Dict [str , Any ]) -> None :
5968 """A hook called before every task gets executed."""
60- self ._started = time ()
69+ if self ._disable :
70+ return
71+
72+ then = time ()
73+
74+ self ._last_updated = self ._started = time ()
75+
6176 self ._state = state
6277 _ , ntasks = self ._count_tasks ()
6378
6479 if ntasks > self ._min_tasks :
65- self ._init_bar ()
80+ self ._init_pbar ()
81+
82+ self ._pbar_runtime += time () - then
6683
6784 def _pretask (
6885 self , key : Union [str , Tuple [str , ...]], _dsk : Any , _state : Dict [str , Any ]
6986 ) -> None :
7087 """A hook called before one task gets executed."""
88+ if self ._disable :
89+ return
90+
91+ then = time ()
92+
7193 if self ._started is None :
7294 raise ValueError ("ProgressBar not started properly" )
7395
7496 if self ._pbar is None and time () - self ._started > self ._minimum :
75- self ._init_bar ()
97+ self ._init_pbar ()
7698
7799 if isinstance (key , tuple ):
78100 key = key [0 ]
79101
80102 if self ._pbar is not None :
81- self ._pbar .set_description (f"Computing { key } " )
103+ if self ._last_updated is None :
104+ raise ValueError ("ProgressBar not started properly" )
105+
106+ if time () - self ._last_updated > self ._interval :
107+ self ._pbar .set_description (f"Computing { key } " )
108+ self ._last_updated = time ()
82109 else :
83110 self ._last_task = key
84111
112+ self ._pbar_runtime += time () - then
113+
85114 def _posttask (
86115 self ,
87116 _key : str ,
@@ -92,21 +121,47 @@ def _posttask(
92121 ) -> None :
93122 """A hook called after one task gets executed."""
94123
124+ if self ._disable :
125+ return
126+
127+ then = time ()
128+
95129 if self ._pbar is not None :
96- self ._update_bar ()
130+ if self ._last_updated is None :
131+ raise ValueError ("ProgressBar not started properly" )
132+
133+ if time () - self ._last_updated > self ._interval :
134+ self ._update_bar ()
135+ self ._last_updated = time ()
136+
137+ self ._pbar_runtime += time () - then
97138
98139 def _finish (self , _dsk : Any , _state : Dict [str , Any ], _errored : bool ) -> None :
99140 """A hook called after all tasks get executed."""
141+ if self ._disable :
142+ return
143+
144+ then = time ()
145+
100146 if self ._started is None :
101147 raise ValueError ("ProgressBar not started properly" )
102148
103149 if self ._pbar is None and time () - self ._started > self ._minimum :
104- self ._init_bar ()
150+ self ._init_pbar ()
105151
106152 if self ._pbar is not None :
107153 self ._update_bar ()
108154 self ._pbar .close ()
109155
156+ self ._pbar_runtime += time () - then
157+
158+ if self ._pbar_runtime / (time () - self ._started ) > 0.3 :
159+ print (
160+ "[ProgressBar] ProgressBar takes additional 10%+ of the computation time,"
161+ " consider disable it by passing 'progress=False' to the plot function." ,
162+ file = sys .stderr ,
163+ )
164+
110165 self ._state = None
111166 self ._started = None
112167 self ._pbar = None
@@ -118,7 +173,7 @@ def _update_bar(self) -> None:
118173
119174 self ._pbar .update (max (0 , ndone - self ._pbar .n ))
120175
121- def _init_bar (self ) -> None :
176+ def _init_pbar (self ) -> None :
122177 if self ._pbar is not None :
123178 raise ValueError ("ProgressBar already initialized." )
124179 ndone , ntasks = self ._count_tasks ()
@@ -157,3 +212,13 @@ def _count_tasks(self) -> Tuple[int, int]:
157212 ntasks = sum (len (state [k ]) for k in ["ready" , "waiting" , "running" ]) + ndone
158213
159214 return ndone , ntasks
215+
216+ def register (self ) -> None :
217+ raise ValueError (
218+ "ProgressBar is not thread safe thus cannot be regestered globally"
219+ )
220+
221+ def unregister (self ) -> None :
222+ raise ValueError (
223+ "ProgressBar is not thread safe thus cannot be unregestered globally"
224+ )
0 commit comments