Skip to content

Commit 1624c15

Browse files
committed
Implement logProgress
1 parent 59612c3 commit 1624c15

File tree

6 files changed

+186
-0
lines changed

6 files changed

+186
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
5+
from graphdatascience.procedure_surface.api.base_result import BaseResult
6+
7+
8+
class SystemEndpoints(ABC):
9+
@abstractmethod
10+
def list_progress(
11+
self,
12+
job_id: str | None = None,
13+
show_completed: bool = False,
14+
) -> list[ProgressResult]:
15+
pass
16+
17+
18+
class ProgressResult(BaseResult):
19+
username: str
20+
job_id: str
21+
task_name: str
22+
progress: str
23+
progress_bar: str
24+
status: str
25+
time_started: str
26+
elapsed_time: str
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
4+
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize
5+
from graphdatascience.procedure_surface.api.system_endpoints import ProgressResult, SystemEndpoints
6+
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
7+
8+
9+
class SystemArrowEndpoints(SystemEndpoints):
10+
def __init__(self, arrow_client: AuthenticatedArrowClient):
11+
self._arrow_client = arrow_client
12+
13+
def list_progress(
14+
self,
15+
job_id: str | None = None,
16+
show_completed: bool = False,
17+
) -> list[ProgressResult]:
18+
config = ConfigConverter.convert_to_gds_config(
19+
job_id=job_id,
20+
show_completed=show_completed,
21+
)
22+
23+
rows = deserialize(self._arrow_client.do_action_with_retry("v2/listProgress", config))
24+
return [ProgressResult(**row) for row in rows]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import annotations
2+
3+
from graphdatascience.call_parameters import CallParameters
4+
from graphdatascience.procedure_surface.api.system_endpoints import ProgressResult, SystemEndpoints
5+
from graphdatascience.query_runner.query_runner import QueryRunner
6+
7+
8+
class SystemCypherEndpoints(SystemEndpoints):
9+
def __init__(self, query_runner: QueryRunner):
10+
self._query_runner = query_runner
11+
12+
def list_progress(
13+
self,
14+
job_id: str | None = None,
15+
show_completed: bool = False,
16+
) -> list[ProgressResult]:
17+
params = CallParameters(
18+
job_id=job_id if job_id else "",
19+
show_completed=True,
20+
)
21+
22+
result = self._query_runner.call_procedure(endpoint="gds.listProgress", params=params)
23+
return [ProgressResult(**row.to_dict()) for _, row in result.iterrows()]

graphdatascience/session/session_v2_endpoints.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
)
6565
from graphdatascience.procedure_surface.arrow.node_embedding.hashgnn_arrow_endpoints import HashGNNArrowEndpoints
6666
from graphdatascience.procedure_surface.arrow.node_embedding.node2vec_arrow_endpoints import Node2VecArrowEndpoints
67+
from graphdatascience.procedure_surface.arrow.system_arrow_endpoints import SystemArrowEndpoints
6768
from graphdatascience.query_runner.query_runner import QueryRunner
6869

6970

@@ -91,6 +92,10 @@ def graph(self) -> CatalogArrowEndpoints:
9192
def config(self) -> ConfigArrowEndpoints:
9293
return ConfigArrowEndpoints(self._arrow_client)
9394

95+
@property
96+
def system(self) -> SystemArrowEndpoints:
97+
return SystemArrowEndpoints(self._arrow_client)
98+
9499
## Algorithms
95100

96101
@property
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Generator
2+
3+
import pytest
4+
5+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
6+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
7+
from graphdatascience.procedure_surface.arrow.community.wcc_arrow_endpoints import WccArrowEndpoints
8+
from graphdatascience.procedure_surface.arrow.system_arrow_endpoints import SystemArrowEndpoints
9+
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph
10+
11+
graph = """
12+
CREATE
13+
(a1)-[:T]->(a2)
14+
"""
15+
16+
17+
@pytest.fixture(scope="class")
18+
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
19+
with create_graph(arrow_client, "g", graph) as G:
20+
yield G
21+
22+
23+
@pytest.fixture(scope="class")
24+
def job_id(arrow_client: AuthenticatedArrowClient, sample_graph: GraphV2) -> Generator[str, None, None]:
25+
job_id = "test_job_id"
26+
WccArrowEndpoints(arrow_client).stats(sample_graph, job_id=job_id)
27+
yield job_id
28+
29+
30+
@pytest.fixture
31+
def system_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[SystemArrowEndpoints, None, None]:
32+
yield SystemArrowEndpoints(arrow_client)
33+
34+
35+
def test_list_progress_job_id(system_endpoints: SystemArrowEndpoints, job_id: str) -> None:
36+
results = system_endpoints.list_progress(job_id=job_id, show_completed=True)
37+
38+
assert len(results) == 1
39+
40+
result = results[0]
41+
assert result.username == "neo4j"
42+
assert result.job_id == job_id
43+
assert "WCC" in result.task_name
44+
assert result.progress == "100%"
45+
assert "#" in result.progress_bar
46+
47+
48+
def test_list_nothing(system_endpoints: SystemArrowEndpoints) -> None:
49+
results = system_endpoints.list_progress()
50+
assert len(results) == 0
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Generator
2+
3+
import pytest
4+
5+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
6+
from graphdatascience.procedure_surface.cypher.community.wcc_cypher_endpoints import WccCypherEndpoints
7+
from graphdatascience.procedure_surface.cypher.system_cypher_endpoints import SystemCypherEndpoints
8+
from graphdatascience.query_runner.query_runner import QueryRunner
9+
from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph
10+
11+
12+
@pytest.fixture
13+
def system_endpoints(query_runner: QueryRunner) -> Generator[SystemCypherEndpoints, None, None]:
14+
yield SystemCypherEndpoints(query_runner)
15+
16+
17+
graph = """
18+
CREATE
19+
(a1)-[:T]->(a2)
20+
"""
21+
22+
projection_query = """
23+
MATCH (n)-[r]->(m)
24+
WITH gds.graph.project('g', n, m, {}) AS G
25+
RETURN G
26+
"""
27+
28+
29+
@pytest.fixture(scope="class")
30+
def sample_graph(query_runner: QueryRunner) -> Generator[GraphV2, None, None]:
31+
with create_graph(query_runner, "g", graph, projection_query) as G:
32+
yield G
33+
34+
35+
@pytest.fixture(scope="class")
36+
def job_id(query_runner: QueryRunner, sample_graph: GraphV2) -> Generator[str, None, None]:
37+
job_id = "test_job_id"
38+
WccCypherEndpoints(query_runner).mutate(sample_graph, job_id=job_id, log_progress=True, mutate_property="wcc")
39+
yield job_id
40+
41+
42+
@pytest.mark.skip(reason="Enable when we figure out how to retain jobs")
43+
def test_list_progress_job_id(system_endpoints: SystemCypherEndpoints, job_id: str) -> None:
44+
results = system_endpoints.list_progress(show_completed=True)
45+
46+
assert len(results) == 1
47+
48+
result = results[0]
49+
assert result.username == "neo4j"
50+
assert result.job_id == job_id
51+
assert "WCC" in result.task_name
52+
assert result.progress == "100%"
53+
assert "#" in result.progress_bar
54+
55+
56+
def test_list_nothing(system_endpoints: SystemCypherEndpoints) -> None:
57+
results = system_endpoints.list_progress()
58+
assert len(results) == 0

0 commit comments

Comments
 (0)