Skip to content

Commit 9520e84

Browse files
committed
Copy from branch
1 parent d525f1d commit 9520e84

File tree

3 files changed

+876
-1
lines changed

3 files changed

+876
-1
lines changed

examples/FastPathExamples.ipynb

Lines changed: 760 additions & 0 deletions
Large diffs are not rendered by default.

graphdatascience/graph_data_science.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
None if arrow is True else arrow,
8383
)
8484

85-
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
85+
super().__init__(self._query_runner, "gds", self._server_version)
8686

8787
@property
8888
def graph(self) -> GraphProcRunner:
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import logging
2+
import os
3+
import time
4+
from typing import Any, Dict, Optional
5+
6+
import requests
7+
from pandas import Series
8+
9+
from ..error.client_only_endpoint import client_only_endpoint
10+
from ..error.illegal_attr_checker import IllegalAttrChecker
11+
from ..error.uncallable_namespace import UncallableNamespace
12+
from ..graph.graph_object import Graph
13+
from ..query_runner.query_runner import QueryRunner
14+
from ..server_version.compatible_with import compatible_with
15+
from ..server_version.server_version import ServerVersion
16+
17+
# Configure the root logger at import time so that the INFO-level job progress
# messages logged by this module (via ``logging.info``) are emitted.
# NOTE(review): this is a process-wide side effect of importing a library
# module; confirm that overriding the host application's logging setup is intended.
logging.basicConfig(level=logging.INFO)
18+
19+
20+
class FastPathRunner(UncallableNamespace, IllegalAttrChecker):
    """Client-side runner that submits FastPath jobs to a compute cluster over HTTP.

    The runner posts a job description to the cluster's web endpoint, then polls
    the job status endpoint once per second until the job has either exited or
    failed.
    """

    def __init__(
        self,
        query_runner: QueryRunner,
        namespace: str,
        server_version: ServerVersion,
        compute_cluster_ip: str,
        encrypted_db_password: str,
        arrow_uri: str,
    ) -> None:
        """Store the connection details needed to talk to the compute cluster.

        Args:
            query_runner: Query runner kept on the instance as ``_query_runner``.
            namespace: Namespace of this runner, kept as ``_namespace``.
            server_version: Server version kept as ``_server_version``.
            compute_cluster_ip: Host of the compute cluster. It is combined with
                fixed ports to form the web (5005), Arrow (8815) and MLflow (8080)
                URIs.
            encrypted_db_password: Sent verbatim in every job request as
                ``encrypted_db_password``.
            arrow_uri: Sent verbatim in every job request as ``graph_arrow_uri``.
        """
        self._query_runner = query_runner
        self._namespace = namespace
        self._server_version = server_version
        self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
        self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
        self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
        self._encrypted_db_password = encrypted_db_password
        self._arrow_uri = arrow_uri

    # The operation name must match the method name so that an incompatible
    # server version is reported against the endpoint the user actually called.
    @compatible_with("mutate", min_inclusive=ServerVersion(2, 5, 0))
    @client_only_endpoint("gds.fastpath")
    def mutate(
        self,
        G: Graph,
        graph_filter: Optional[Dict[str, Any]] = None,
        mlflow_experiment_name: Optional[str] = None,
        **algo_config: Any,
    ) -> Series:
        """Start a FastPath job for ``G`` and block until it has finished.

        Args:
            G: Graph whose name (and, without a filter, node and relationship
                properties) is sent to the compute cluster.
            graph_filter: Entries merged into the graph configuration. When
                ``None``, a filter covering all node and relationship properties
                of ``G`` is built instead.
            mlflow_experiment_name: When given, an ``mlflow`` section with the
                cluster's tracking URI and this experiment name is added to the
                task configuration.
            **algo_config: Algorithm configuration forwarded as the inner
                ``task_config`` of the request.

        Returns:
            A ``Series`` with the single entry ``status == "finished"``.

        Raises:
            requests.HTTPError: If starting the job returns an error status.
            ValueError: If the job failed and the status response code was 400.
            RuntimeError: If the job failed with any other response code.
        """
        if graph_filter is None:
            # Take full graph if no filter provided
            node_filter = G.node_properties().to_dict()
            rel_filter = G.relationship_properties().to_dict()
            graph_filter = {"node_filter": node_filter, "rel_filter": rel_filter}

        # Annotated as Dict[str, Any]: the filter entries are not strings, and
        # the nested "task_config" entry is extended below.
        graph_config: Dict[str, Any] = {"name": G.name()}
        graph_config.update(graph_filter)

        config: Dict[str, Any] = {
            # NOTE(review): placeholder user name is hard-coded in the request.
            "user_name": "DUMMY_USER",
            "task": "FASTPATH",
            "task_config": {
                "graph_config": graph_config,
                "task_config": algo_config,
                "stream_node_results": True,
            },
            "encrypted_db_password": self._encrypted_db_password,
            "graph_arrow_uri": self._arrow_uri,
        }

        if mlflow_experiment_name is not None:
            config["task_config"]["mlflow"] = {
                "config": {
                    "tracking_uri": self._compute_cluster_mlflow_uri,
                    "experiment_name": mlflow_experiment_name,
                }
            }

        job_id = self._start_job(config)

        self._wait_for_job(job_id)

        return Series({"status": "finished"})

        # return self._stream_results(job_id)

    def _start_job(self, config: Dict[str, Any]) -> str:
        """Post ``config`` to the cluster's start endpoint and return the job id.

        Raises:
            requests.HTTPError: If the response carries an error status code.
        """
        # NOTE(review): no request timeout is set, so this call can block
        # indefinitely if the cluster does not answer — confirm this is acceptable.
        res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
        res.raise_for_status()
        job_id = res.json()["job_id"]
        logging.info(f"Job with ID '{job_id}' started")

        return job_id

    def _wait_for_job(self, job_id: str) -> None:
        """Poll the status endpoint once per second until the job exits or fails.

        Any ``job_status`` other than ``"exited"`` or ``"failed"`` keeps the loop
        polling; there is no upper bound on the number of polls.

        Raises:
            ValueError: If the job failed and the status response code was 400.
            RuntimeError: If the job failed with any other response code.
        """
        while True:
            time.sleep(1)

            res = requests.get(f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}")

            # The status code is deliberately not checked before parsing: a
            # failed job is reported through the JSON body, and the code only
            # selects which exception type is raised below.
            res_json = res.json()
            if res_json["job_status"] == "exited":
                logging.info("FastPath job completed!")
                return
            elif res_json["job_status"] == "failed":
                error = f"FastPath job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}"
                if res.status_code == 400:
                    raise ValueError(error)
                else:
                    raise RuntimeError(error)
106+
107+
# def _stream_results(self, job_id: str) -> DataFrame:
108+
# client = pa.flight.connect(self._compute_cluster_arrow_uri)
109+
110+
# upload_descriptor = pa.flight.FlightDescriptor.for_path(f"{job_id}.nodes")
111+
# flight = client.get_flight_info(upload_descriptor)
112+
# reader = client.do_get(flight.endpoints[0].ticket)
113+
# read_table = reader.read_all()
114+
115+
# return read_table.to_pandas()

0 commit comments

Comments
 (0)