Skip to content

Commit 64be889

Browse files
committed
perf(eda): improve progress bar performance
And added option to disable progress bar.
1 parent 8ebe9cc commit 64be889

File tree

4 files changed

+100
-25
lines changed

4 files changed

+100
-25
lines changed

dataprep/eda/correlation/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def plot_correlation(
2222
*,
2323
value_range: Optional[Tuple[float, float]] = None,
2424
k: Optional[int] = None,
25+
progress: bool = True,
2526
) -> Report:
2627
"""
2728
This function is designed to calculate the correlation between columns
@@ -32,15 +33,17 @@ def plot_correlation(
3233
Parameters
3334
----------
3435
df
35-
The pandas data_frame for which plots are calculated for each column
36+
The pandas data_frame for which plots are calculated for each column.
3637
x
37-
A valid column name of the data frame
38+
A valid column name of the data frame.
3839
y
39-
A valid column name of the data frame
40+
A valid column name of the data frame.
4041
value_range
41-
Range of value
42+
Range of value.
4243
k
43-
Choose top-k element
44+
Choose top-k element.
45+
progress
46+
Enable the progress bar.
4447
4548
Examples
4649
--------
@@ -61,7 +64,7 @@ def plot_correlation(
6164
This function only supports numerical or categorical data,
6265
and it is better to drop None, NaN and Null values before using it
6366
"""
64-
with ProgressBar(minimum=1):
67+
with ProgressBar(minimum=1, disable=not progress):
6568
intermediate = compute_correlation(df, x=x, y=y, value_range=value_range, k=k)
6669
figure = render_correlation(intermediate)
6770

dataprep/eda/distribution/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def plot(
3838
yscale: str = "linear",
3939
tile_size: Optional[float] = None,
4040
dtype: Optional[DTypeDef] = None,
41+
progress: bool = True,
4142
) -> Union[Report, Container]:
4243
"""Generates plots for exploratory data analysis.
4344
@@ -133,6 +134,9 @@ def plot(
133134
E.g. dtype = {"a": Continuous, "b": "Nominal"} or
134135
dtype = {"a": Continuous(), "b": "nominal"}
135136
or dtype = Continuous() or dtype = "Continuous" or dtype = Continuous()
137+
progress
138+
Enable the progress bar.
139+
136140
Examples
137141
--------
138142
>>> import pandas as pd
@@ -144,7 +148,7 @@ def plot(
144148
"""
145149
# pylint: disable=too-many-locals,line-too-long
146150

147-
with ProgressBar(minimum=1):
151+
with ProgressBar(minimum=1, disable=not progress):
148152
intermediate = compute(
149153
df,
150154
x=x,

dataprep/eda/missing/__init__.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def plot_missing(
2424
bins: int = 30,
2525
ndist_sample: int = 100,
2626
dtype: Optional[DTypeDef] = None,
27+
progress: bool = True,
2728
) -> Report:
2829
"""
2930
This function is designed to deal with missing values
@@ -33,20 +34,22 @@ def plot_missing(
3334
Parameters
3435
----------
3536
df
36-
the pandas data_frame for which plots are calculated for each column
37+
the pandas data_frame for which plots are calculated for each column.
3738
x
38-
a valid column name of the data frame
39+
a valid column name of the data frame.
3940
y
40-
a valid column name of the data frame
41+
a valid column name of the data frame.
4142
bins
42-
The number of rows in the figure
43+
The number of rows in the figure.
4344
ndist_sample
44-
The number of sample points
45+
The number of sample points.
4546
dtype: str or DType or dict of str or dict of DType, default None
4647
Specify Data Types for designated column or all columns.
4748
E.g. dtype = {"a": Continuous, "b": "Nominal"} or
4849
dtype = {"a": Continuous(), "b": "nominal"}
49-
or dtype = Continuous() or dtype = "Continuous" or dtype = Continuous()
50+
or dtype = Continuous() or dtype = "Continuous" or dtype = Continuous().
51+
progress
52+
Enable the progress bar.
5053
5154
Examples
5255
----------
@@ -57,7 +60,7 @@ def plot_missing(
5760
>>> plot_missing(df, "HDI_for_year", "population")
5861
"""
5962

60-
with ProgressBar(minimum=1):
63+
with ProgressBar(minimum=1, disable=not progress):
6164
itmdt = compute_missing(
6265
df, x, y, dtype=dtype, bins=bins, ndist_sample=ndist_sample
6366
)

dataprep/eda/progress_bar.py

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""ProgressBar shows how many dask tasks have finished/remain, using tqdm."""
22

3-
from typing import Any, Optional, Dict, Tuple, Union
3+
import sys
44
from time import time
5+
from typing import Any, Dict, Optional, Tuple, Union
56

67
from dask.callbacks import Callback
78

@@ -15,18 +16,21 @@
1516
# pylint: disable=method-hidden,too-many-instance-attributes
1617
class ProgressBar(Callback): # type: ignore
1718
"""A progress bar for DataPrep.EDA.
19+
Not thread safe.
1820
1921
Parameters
2022
----------
2123
minimum : int, optional
2224
Minimum time threshold in seconds before displaying a progress bar.
2325
Default is 0 (always display)
24-
_min_tasks : int, optional
26+
min_tasks : int, optional
2527
Minimum graph size to show a progress bar, default is 5
2628
width : int, optional
2729
Width of the bar. None means auto width.
2830
interval : float, optional
29-
Update resolution in seconds, default is 0.1 seconds
31+
Update resolution in seconds, default is 0.1 seconds.
32+
disable : bool, optional
33+
Disable the progress bar.
3034
"""
3135

3236
_minimum: float = 0
@@ -38,50 +42,75 @@ class ProgressBar(Callback): # type: ignore
3842
_state: Optional[Dict[str, Any]] = None
3943
_started: Optional[float] = None
4044
_last_task: Optional[str] = None # in case we initialize the pbar in _finish
45+
_pbar_runtime: float = 0
46+
_last_updated: Optional[float] = None
47+
_disable: bool = False
4148

42-
def __init__(
49+
def __init__( # pylint: disable=too-many-arguments
4350
self,
4451
minimum: float = 0,
4552
min_tasks: int = 5,
4653
width: Optional[int] = None,
4754
interval: float = 0.1,
55+
disable: bool = False,
4856
) -> None:
4957
super().__init__()
5058
self._minimum = minimum
5159
self._min_tasks = min_tasks
5260
self._width = width
5361
self._interval = interval
62+
self._disable = disable
5463

5564
def _start(self, _dsk: Any) -> None:
5665
"""A hook to start this callback."""
5766

5867
def _start_state(self, _dsk: Any, state: Dict[str, Any]) -> None:
5968
"""A hook called before every task gets executed."""
60-
self._started = time()
69+
if self._disable:
70+
return
71+
72+
then = time()
73+
74+
self._last_updated = self._started = time()
75+
6176
self._state = state
6277
_, ntasks = self._count_tasks()
6378

6479
if ntasks > self._min_tasks:
65-
self._init_bar()
80+
self._init_pbar()
81+
82+
self._pbar_runtime += time() - then
6683

6784
def _pretask(
6885
self, key: Union[str, Tuple[str, ...]], _dsk: Any, _state: Dict[str, Any]
6986
) -> None:
7087
"""A hook called before one task gets executed."""
88+
if self._disable:
89+
return
90+
91+
then = time()
92+
7193
if self._started is None:
7294
raise ValueError("ProgressBar not started properly")
7395

7496
if self._pbar is None and time() - self._started > self._minimum:
75-
self._init_bar()
97+
self._init_pbar()
7698

7799
if isinstance(key, tuple):
78100
key = key[0]
79101

80102
if self._pbar is not None:
81-
self._pbar.set_description(f"Computing {key}")
103+
if self._last_updated is None:
104+
raise ValueError("ProgressBar not started properly")
105+
106+
if time() - self._last_updated > self._interval:
107+
self._pbar.set_description(f"Computing {key}")
108+
self._last_updated = time()
82109
else:
83110
self._last_task = key
84111

112+
self._pbar_runtime += time() - then
113+
85114
def _posttask(
86115
self,
87116
_key: str,
@@ -92,21 +121,47 @@ def _posttask(
92121
) -> None:
93122
"""A hook called after one task gets executed."""
94123

124+
if self._disable:
125+
return
126+
127+
then = time()
128+
95129
if self._pbar is not None:
96-
self._update_bar()
130+
if self._last_updated is None:
131+
raise ValueError("ProgressBar not started properly")
132+
133+
if time() - self._last_updated > self._interval:
134+
self._update_bar()
135+
self._last_updated = time()
136+
137+
self._pbar_runtime += time() - then
97138

98139
def _finish(self, _dsk: Any, _state: Dict[str, Any], _errored: bool) -> None:
99140
"""A hook called after all tasks get executed."""
141+
if self._disable:
142+
return
143+
144+
then = time()
145+
100146
if self._started is None:
101147
raise ValueError("ProgressBar not started properly")
102148

103149
if self._pbar is None and time() - self._started > self._minimum:
104-
self._init_bar()
150+
self._init_pbar()
105151

106152
if self._pbar is not None:
107153
self._update_bar()
108154
self._pbar.close()
109155

156+
self._pbar_runtime += time() - then
157+
158+
if self._pbar_runtime / (time() - self._started) > 0.3:
159+
print(
160+
"[ProgressBar] ProgressBar takes additional 10%+ of the computation time,"
161+
" consider disable it by passing 'progress=False' to the plot function.",
162+
file=sys.stderr,
163+
)
164+
110165
self._state = None
111166
self._started = None
112167
self._pbar = None
@@ -118,7 +173,7 @@ def _update_bar(self) -> None:
118173

119174
self._pbar.update(max(0, ndone - self._pbar.n))
120175

121-
def _init_bar(self) -> None:
176+
def _init_pbar(self) -> None:
122177
if self._pbar is not None:
123178
raise ValueError("ProgressBar already initialized.")
124179
ndone, ntasks = self._count_tasks()
@@ -157,3 +212,13 @@ def _count_tasks(self) -> Tuple[int, int]:
157212
ntasks = sum(len(state[k]) for k in ["ready", "waiting", "running"]) + ndone
158213

159214
return ndone, ntasks
215+
216+
def register(self) -> None:
217+
raise ValueError(
218+
"ProgressBar is not thread safe thus cannot be regestered globally"
219+
)
220+
221+
def unregister(self) -> None:
222+
raise ValueError(
223+
"ProgressBar is not thread safe thus cannot be unregestered globally"
224+
)

0 commit comments

Comments
 (0)