Skip to content

Commit 6998b75

Browse files
committed
Display running subtask in progress bar
1 parent 87bc4b7 commit 6998b75

File tree

5 files changed

+61
-19
lines changed

5 files changed

+61
-19
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
## Improvements
1818

1919
* Display progress bar for remote projection and open-ended tasks.
20+
* Improve progress bar by showing the currently running task.
2021
* Allow passing the optional graph filter also as type `str` to `gds.graph.list()` instead of only `Graph`.
2122

2223

graphdatascience/query_runner/progress/progress_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class TaskWithProgress:
88
task_name: str
99
progress_percent: str
1010
status: str
11+
sub_tasks_description: Optional[str] = None
1112

1213

1314
class ProgressProvider(ABC):

graphdatascience/query_runner/progress/query_progress_logger.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def _log(
7070
if has_relative_progress:
7171
pbar = tqdm(total=100, unit="%", desc=root_task_name, maxinterval=self._LOG_POLLING_INTERVAL)
7272
else:
73-
# TODO add {n_fmt} once task_with_progress provides the absolute progress
7473
pbar = tqdm(
7574
total=None,
7675
unit="",
@@ -79,7 +78,9 @@ def _log(
7978
bar_format="{desc} [elapsed: {elapsed} {postfix}]",
8079
)
8180

82-
pbar.set_postfix_str(f"status: {task_with_progress.status}")
81+
pbar.set_postfix_str(
82+
f"status: {task_with_progress.status}, task: {task_with_progress.sub_tasks_description}"
83+
)
8384
if has_relative_progress:
8485
parsed_progress = float(progress_percent[:-1])
8586
new_progress = parsed_progress - pbar.n

graphdatascience/query_runner/progress/query_progress_provider.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,33 @@ def __init__(
2121

2222
def root_task_with_progress(self, job_id: str, database: Optional[str] = None) -> TaskWithProgress:
2323
tier = "beta." if self._server_version_func() < ServerVersion(2, 5, 0) else ""
24-
# we only retrieve the progress of the root task
24+
25+
# expect at least one row: the root task, optionally followed by its subtasks (query will fail if not existing)
2526
progress = self._run_cypher_func(
2627
f"CALL gds.{tier}listProgress('{job_id}')"
2728
+ " YIELD taskName, progress, status"
28-
+ " RETURN taskName, progress, status"
29-
+ " LIMIT 1",
29+
+ " RETURN taskName, progress, status",
3030
database,
31-
).squeeze() # expect at exactly one row (query will fail if not existing)
32-
33-
progress_percent = progress["progress"]
34-
root_task_name = progress["taskName"].split("|--")[-1][1:]
35-
36-
return TaskWithProgress(root_task_name, progress_percent, progress["status"])
31+
)
32+
33+
# compute depth of each subtask
34+
progress["trimmedName"] = progress["taskName"].str.lstrip()
35+
progress["depth"] = progress["taskName"].str.len() - progress["trimmedName"].str.len()
36+
progress.sort_values("depth", ascending=True, inplace=True)
37+
38+
root_task = progress.iloc[0]
39+
root_progress_percent = root_task["progress"]
40+
root_task_name = root_task["trimmedName"].replace("|--", "")
41+
root_status = root_task["status"]
42+
43+
subtask_descriptions = None
44+
running_tasks = progress[progress["status"] == "RUNNING"]
45+
if running_tasks["taskName"].size > 1: # at least one subtask
46+
subtasks = running_tasks[1:] # remove root task
47+
subtask_descriptions = "::".join(
48+
list(subtasks["taskName"].apply(lambda name: name.split("|--")[-1].strip()))
49+
)
50+
51+
return TaskWithProgress(
52+
root_task_name, root_progress_percent, root_status, sub_tasks_description=subtask_descriptions
53+
)

graphdatascience/tests/unit/query_runner/progress/test_query_progress_logger.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212

1313
def test_call_through_functions() -> None:
1414
def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
15-
assert (
16-
query
17-
== "CALL gds.listProgress('foo') YIELD taskName, progress, status RETURN taskName, progress, status LIMIT 1"
18-
)
15+
assert "CALL gds.listProgress('foo')" in query
1916
assert database == "database"
2017

2118
return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])
@@ -46,10 +43,7 @@ def fake_query() -> DataFrame:
4643

4744
def test_uses_beta_endpoint() -> None:
4845
def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
49-
assert (
50-
query
51-
== "CALL gds.beta.listProgress('foo') YIELD taskName, progress, status RETURN taskName, progress, status LIMIT 1"
52-
)
46+
assert "CALL gds.beta.listProgress('foo')" in query
5347
assert database == "database"
5448

5549
return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])
@@ -76,6 +70,33 @@ def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
7670
assert isinstance(progress_provider, QueryProgressProvider)
7771

7872

73+
def test_uses_query_provider_with_task_description() -> None:
74+
server_version = ServerVersion(3, 0, 0)
75+
detailed_progress = DataFrame(
76+
[
77+
{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"},
78+
{"progress": "n/a", "taskName": " |-- root 1/1", "status": "RUNNING"},
79+
{"progress": "n/a", "taskName": " |-- leaf", "status": "RUNNING"},
80+
{"progress": "n/a", "taskName": "finished task", "status": "FINISHED"},
81+
{"progress": "n/a", "taskName": "pending task", "status": "PENDING"},
82+
]
83+
)
84+
85+
query_runner = CollectingQueryRunner(server_version, result_mock={"gds.listProgress": detailed_progress})
86+
87+
def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
88+
return query_runner.run_cypher(query, db=database)
89+
90+
qpl = QueryProgressLogger(simple_run_cypher, lambda: server_version)
91+
progress_provider = qpl._select_progress_provider("test-job")
92+
assert isinstance(progress_provider, QueryProgressProvider)
93+
94+
progress = progress_provider.root_task_with_progress("test-job", "database")
95+
96+
assert progress.sub_tasks_description == "root 1/1::leaf"
97+
assert progress.task_name == "Test task"
98+
99+
79100
def test_uses_static_store() -> None:
80101
def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
81102
return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])
@@ -88,3 +109,4 @@ def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
88109
task_with_volume = progress_provider.root_task_with_progress("test-job")
89110
assert task_with_volume.task_name == "Test task"
90111
assert task_with_volume.progress_percent == "n/a"
112+
assert task_with_volume.sub_tasks_description is None

0 commit comments

Comments
 (0)