Skip to content

Commit 37e83bb

Browse files
committed
Test progress bar output
Also fixing some inconsistencies on the way. Such as only update if needed and capitalize status
1 parent 6998b75 commit 37e83bb

File tree

2 files changed

+120
-32
lines changed

2 files changed

+120
-32
lines changed

graphdatascience/query_runner/progress/query_progress_logger.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,27 @@
66
from tqdm.auto import tqdm
77

88
from ...server_version.server_version import ServerVersion
9-
from .progress_provider import ProgressProvider
9+
from .progress_provider import ProgressProvider, TaskWithProgress
1010
from .query_progress_provider import CypherQueryFunction, QueryProgressProvider, ServerVersionFunction
1111
from .static_progress_provider import StaticProgressProvider, StaticProgressStore
1212

1313
DataFrameProducer = Callable[[], DataFrame]
1414

1515

1616
class QueryProgressLogger:
17-
_LOG_POLLING_INTERVAL = 0.5
18-
1917
def __init__(
2018
self,
2119
run_cypher_func: CypherQueryFunction,
2220
server_version_func: ServerVersionFunction,
21+
polling_interval: float = 0.5,
22+
progress_bar_options: dict[str, Any] = {},
2323
):
2424
self._run_cypher_func = run_cypher_func
2525
self._server_version_func = server_version_func
2626
self._static_progress_provider = StaticProgressProvider()
2727
self._query_progress_provider = QueryProgressProvider(run_cypher_func, server_version_func)
28+
self._polling_interval = polling_interval
29+
self._progress_bar_options = progress_bar_options
2830

2931
def run_with_progress_logging(
3032
self, runnable: DataFrameProducer, job_id: str, database: Optional[str] = None
@@ -54,39 +56,18 @@ def _select_progress_provider(self, job_id: str) -> ProgressProvider:
5456
)
5557

5658
def _log(
57-
self, future: "Future[Any]", job_id: str, progress_provider: ProgressProvider, database: Optional[str] = None
59+
self, future: Future[Any], job_id: str, progress_provider: ProgressProvider, database: Optional[str] = None
5860
) -> None:
5961
pbar: Optional[tqdm[NoReturn]] = None
6062
warn_if_failure = True
6163

62-
while wait([future], timeout=self._LOG_POLLING_INTERVAL).not_done:
64+
while wait([future], timeout=self._polling_interval).not_done:
6365
try:
6466
task_with_progress = progress_provider.root_task_with_progress(job_id, database)
65-
root_task_name = task_with_progress.task_name
66-
progress_percent = task_with_progress.progress_percent
67-
68-
has_relative_progress = progress_percent != "n/a"
6967
if pbar is None:
70-
if has_relative_progress:
71-
pbar = tqdm(total=100, unit="%", desc=root_task_name, maxinterval=self._LOG_POLLING_INTERVAL)
72-
else:
73-
pbar = tqdm(
74-
total=None,
75-
unit="",
76-
desc=root_task_name,
77-
maxinterval=self._LOG_POLLING_INTERVAL,
78-
bar_format="{desc} [elapsed: {elapsed} {postfix}]",
79-
)
80-
81-
pbar.set_postfix_str(
82-
f"status: {task_with_progress.status}, task: {task_with_progress.sub_tasks_description}"
83-
)
84-
if has_relative_progress:
85-
parsed_progress = float(progress_percent[:-1])
86-
new_progress = parsed_progress - pbar.n
87-
pbar.update(new_progress)
88-
else:
89-
pbar.refresh() # show latest elapsed time + postfix
68+
pbar = self._init_pbar(task_with_progress)
69+
70+
self._update_pbar(pbar, task_with_progress)
9071
except Exception as e:
9172
# Do nothing if the procedure either:
9273
# * has not started yet,
@@ -100,7 +81,51 @@ def _log(
10081
continue
10182

10283
if pbar is not None:
103-
if pbar.total is not None:
104-
pbar.update(pbar.total - pbar.n)
105-
pbar.set_postfix_str("status: finished")
84+
self._finish_pbar(pbar)
85+
86+
def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm: # type: ignore
87+
root_task_name = task_with_progress.task_name
88+
parsed_progress = QueryProgressLogger._relative_progress(task_with_progress)
89+
if parsed_progress is None: # Qualitative progress report
90+
return tqdm(
91+
total=None,
92+
unit="",
93+
desc=root_task_name,
94+
maxinterval=self._polling_interval,
95+
bar_format="{desc} [elapsed: {elapsed} {postfix}]",
96+
**self._progress_bar_options,
97+
)
98+
else:
99+
return tqdm(
100+
total=100,
101+
unit="%",
102+
desc=root_task_name,
103+
maxinterval=self._polling_interval,
104+
**self._progress_bar_options,
105+
)
106+
107+
def _update_pbar(self, pbar: tqdm, task_with_progress: TaskWithProgress) -> None: # type: ignore
108+
parsed_progress = QueryProgressLogger._relative_progress(task_with_progress)
109+
postfix = (
110+
f"status: {task_with_progress.status}, task: {task_with_progress.sub_tasks_description}"
111+
if task_with_progress.sub_tasks_description
112+
else f"status: {task_with_progress.status}"
113+
)
114+
pbar.set_postfix_str(postfix, refresh=False)
115+
if parsed_progress is not None:
116+
new_progress = parsed_progress - pbar.n
117+
pbar.update(new_progress)
118+
else:
106119
pbar.refresh()
120+
121+
def _finish_pbar(self, pbar: tqdm) -> None: # type: ignore
122+
if pbar.total is not None:
123+
pbar.update(pbar.total - pbar.n)
124+
pbar.set_postfix_str("status: FINISHED", refresh=True)
125+
126+
@staticmethod
127+
def _relative_progress(task: TaskWithProgress) -> Optional[float]:
128+
try:
129+
return float(task.progress_percent.removesuffix("%"))
130+
except ValueError:
131+
return None

graphdatascience/tests/unit/query_runner/progress/test_query_progress_logger.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import re
12
import time
3+
from io import StringIO
24
from typing import Optional
35

46
from pandas import DataFrame
57

68
from graphdatascience import ServerVersion
9+
from graphdatascience.query_runner.progress.progress_provider import TaskWithProgress
710
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
811
from graphdatascience.query_runner.progress.query_progress_provider import QueryProgressProvider
912
from graphdatascience.query_runner.progress.static_progress_provider import StaticProgressProvider, StaticProgressStore
@@ -97,6 +100,66 @@ def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
97100
assert progress.task_name == "Test task"
98101

99102

103+
def test_progress_bar_quantitive_output() -> None:
104+
def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
105+
raise NotImplementedError("Should not be called!")
106+
107+
with StringIO() as pbarOutputStream:
108+
qpl = QueryProgressLogger(
109+
simple_run_cypher,
110+
lambda: ServerVersion(3, 0, 0),
111+
progress_bar_options={"file": pbarOutputStream, "mininterval": 0},
112+
)
113+
114+
pbar = qpl._init_pbar(TaskWithProgress("test task", "0%", "PENDING", ""))
115+
assert pbarOutputStream.getvalue().split("\r")[-1] == "test task: 0%| | 0/100 [00:00<?, ?%/s]"
116+
117+
qpl._update_pbar(pbar, TaskWithProgress("test task", "0%", "PENDING", ""))
118+
assert (
119+
pbarOutputStream.getvalue().split("\r")[-1]
120+
== "test task: 0%| | 0.0/100 [00:00<?, ?%/s, status: PENDING]"
121+
)
122+
qpl._update_pbar(pbar, TaskWithProgress("test task", "42%", "RUNNING", "root::1/1::leaf"))
123+
124+
running_output = pbarOutputStream.getvalue().split("\r")[-1]
125+
assert re.match(
126+
r"test task: 42%\|####2 \| 42.0/100 \[00:00<00:00, \d+.\d*%/s, status: RUNNING, task: root::1/1::leaf\]",
127+
running_output,
128+
), running_output
129+
130+
qpl._finish_pbar(pbar)
131+
finished_output = pbarOutputStream.getvalue().split("\r")[-1]
132+
assert re.match(
133+
r"test task: 100%\|##########\| 100.0/100 \[00:00<00:00, \d+.\d+%/s, status: FINISHED\]", finished_output
134+
), finished_output
135+
136+
137+
def test_progress_bar_qualitative_output() -> None:
138+
def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
139+
raise NotImplementedError("Should not be called!")
140+
141+
with StringIO() as pbarOutputStream:
142+
qpl = QueryProgressLogger(
143+
simple_run_cypher,
144+
lambda: ServerVersion(3, 0, 0),
145+
progress_bar_options={"file": pbarOutputStream, "mininterval": 100},
146+
)
147+
148+
pbar = qpl._init_pbar(TaskWithProgress("test task", "n/a", "PENDING", ""))
149+
150+
qpl._update_pbar(pbar, TaskWithProgress("test task", "n/a", "PENDING", ""))
151+
qpl._update_pbar(pbar, TaskWithProgress("test task", "", "RUNNING", "root 1/1::leaf"))
152+
qpl._finish_pbar(pbar)
153+
154+
assert pbarOutputStream.getvalue().rstrip() == "".join(
155+
[
156+
"\rtest task [elapsed: 00:00 ]\rtest task [elapsed: 00:00 , status: PENDING]",
157+
"\rtest task [elapsed: 00:00 , status: RUNNING, task: root 1/1::leaf]",
158+
"\rtest task [elapsed: 00:00 , status: FINISHED]",
159+
]
160+
)
161+
162+
100163
def test_uses_static_store() -> None:
101164
def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
102165
return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])

0 commit comments

Comments
 (0)