Skip to content

Commit a0fd880

Browse files
committed
Fix write-back progress bar for write-back to self-managed dbs
1 parent 8639dfb commit a0fd880

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

graphdatascience/query_runner/protocol/write_protocols.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def run_write_back(
149149
terminationFlag: TerminationFlag,
150150
) -> DataFrame:
151151
def is_not_completed(result: DataFrame) -> bool:
152-
status: str = result.squeeze()["status"]
152+
status: str = result.iloc[0]["status"]
153153
return status != Status.COMPLETED.name
154154

155155
logger = logging.getLogger()
@@ -175,9 +175,12 @@ def write_fn(progress_bar: Optional[TqdmProgressBar]) -> DataFrame:
175175
mode=QueryMode.WRITE,
176176
custom_error=False,
177177
)
178+
result_row = result.iloc[0].to_dict()
179+
# for self-managed dbs the endpoint doesn't return the progress yet
180+
progress = result_row.get("progress", 0.0) * 100
178181

179182
if progress_bar:
180-
progress_bar.update(status=result.squeeze()["status"], progress=result.squeeze()["progress"] * 100)
183+
progress_bar.update(status=result_row["status"], progress=progress)
181184

182185
return result
183186

graphdatascience/tests/unit/query_runner/test_write_protocols.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,30 @@ def test_write_back_v3_progress_logging() -> None:
3434
assert any(["Write-Back (graph: myGraph): 100%|##########| 100.0/100" in line for line in bar_output])
3535

3636

37+
def test_write_back_v3_progress_logging_without_bar() -> None:
38+
# for self-managed dbs the endpoint doesnt return the progress yet
39+
with StringIO() as pbarOutputStream:
40+
qr = CollectingQueryRunner(ServerVersion(0, 0, 0))
41+
qr.add__mock_result("gds.arrow.write.v3", DataFrame([{"status": Status.COMPLETED.name}]))
42+
43+
wp = RemoteWriteBackV3(progress_bar_options={"file": pbarOutputStream, "mininterval": 0})
44+
45+
wp.run_write_back(
46+
query_runner=qr,
47+
parameters=CallParameters(graphName="myGraph", jobId="myJob"),
48+
log_progress=True,
49+
terminationFlag=TerminationFlagNoop(),
50+
yields=None,
51+
)
52+
53+
bar_output = pbarOutputStream.getvalue().split("\r")
54+
55+
assert any(
56+
["Write-Back (graph: myGraph): 0%| | 0.0/100 [00:00<?, ?%/s]" in line for line in bar_output]
57+
), bar_output
58+
assert any(["Write-Back (graph: myGraph): 100%|##########| 100.0/100" in line for line in bar_output])
59+
60+
3761
def test_write_back_v3_progress_logging_aborted() -> None:
3862
with StringIO() as pbarOutputStream:
3963
qr = CollectingQueryRunner(ServerVersion(0, 0, 0))

0 commit comments

Comments
 (0)