Skip to content

Commit b304167

Browse files
committed
Copy from branch
1 parent 3e2d7e4 commit b304167

File tree

3 files changed

+919
-0
lines changed

3 files changed

+919
-0
lines changed

examples/FastPathExamples.ipynb

Lines changed: 760 additions & 0 deletions
Large diffs are not rendered by default.

graphdatascience/graph_data_science.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from __future__ import annotations
22

3+
import pathlib
4+
import sys
35
from typing import Any, Dict, Optional, Tuple, Type, Union
46

7+
import rsa
58
from neo4j import Driver
69
from pandas import DataFrame
710

811
from .call_builder import IndirectCallBuilder
912
from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
1013
from .error.uncallable_namespace import UncallableNamespace
14+
from .model.fastpath_runner import FastPathRunner
1115
from .query_runner.arrow_query_runner import ArrowQueryRunner
1216
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
1317
from .query_runner.query_runner import QueryRunner
@@ -81,8 +85,31 @@ def __init__(
8185
None if arrow is True else arrow,
8286
)
8387

88+
if auth is not None:
89+
with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
90+
pub_key = rsa.PublicKey.load_pkcs1(f.read())
91+
self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
92+
93+
self._compute_cluster_ip = None
94+
8495
super().__init__(self._query_runner, "gds", self._server_version)
8596

97+
def set_compute_cluster_ip(self, ip: str) -> None:
98+
self._compute_cluster_ip = ip
99+
100+
@staticmethod
101+
def _path(package: str, resource: str) -> pathlib.Path:
102+
if sys.version_info >= (3, 9):
103+
from importlib.resources import files
104+
105+
# files() returns a Traversable, but usages require a Path object
106+
return pathlib.Path(str(files(package) / resource))
107+
else:
108+
from importlib.resources import path
109+
110+
# we dont want to use a context manager here, so we need to call __enter__ manually
111+
return path(package, resource).__enter__()
112+
86113
@property
87114
def graph(self) -> GraphProcRunner:
88115
return GraphProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
@@ -95,6 +122,23 @@ def alpha(self) -> AlphaEndpoints:
95122
def beta(self) -> BetaEndpoints:
96123
return BetaEndpoints(self._query_runner, "gds.beta", self._server_version)
97124

125+
@property
126+
def fastpath(self) -> FastPathRunner:
127+
if not isinstance(self._query_runner, ArrowQueryRunner):
128+
raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
129+
if self._compute_cluster_ip is None:
130+
raise ValueError(
131+
"You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature"
132+
)
133+
return FastPathRunner(
134+
self._query_runner,
135+
"gds.fastpath",
136+
self._server_version,
137+
self._compute_cluster_ip,
138+
self._encrypted_db_password,
139+
self._query_runner.uri,
140+
)
141+
98142
def __getattr__(self, attr: str) -> IndirectCallBuilder:
99143
return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)
100144

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import logging
2+
import os
3+
import time
4+
from typing import Any, Dict, Optional
5+
6+
import requests
7+
from pandas import Series
8+
9+
from ..error.client_only_endpoint import client_only_endpoint
10+
from ..error.illegal_attr_checker import IllegalAttrChecker
11+
from ..error.uncallable_namespace import UncallableNamespace
12+
from ..graph.graph_object import Graph
13+
from ..query_runner.query_runner import QueryRunner
14+
from ..server_version.compatible_with import compatible_with
15+
from ..server_version.server_version import ServerVersion
16+
17+
logging.basicConfig(level=logging.INFO)
18+
19+
20+
class FastPathRunner(UncallableNamespace, IllegalAttrChecker):
21+
def __init__(
22+
self,
23+
query_runner: QueryRunner,
24+
namespace: str,
25+
server_version: ServerVersion,
26+
compute_cluster_ip: str,
27+
encrypted_db_password: str,
28+
arrow_uri: str,
29+
):
30+
self._query_runner = query_runner
31+
self._namespace = namespace
32+
self._server_version = server_version
33+
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
34+
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
35+
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
36+
self._encrypted_db_password = encrypted_db_password
37+
self._arrow_uri = arrow_uri
38+
39+
@compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
40+
@client_only_endpoint("gds.fastpath")
41+
def mutate(
42+
self,
43+
G: Graph,
44+
graph_filter: Optional[Dict[str, Any]] = None,
45+
mlflow_experiment_name: Optional[str] = None,
46+
**algo_config: Any,
47+
) -> Series:
48+
if graph_filter is None:
49+
# Take full graph if no filter provided
50+
node_filter = G.node_properties().to_dict()
51+
rel_filter = G.relationship_properties().to_dict()
52+
graph_filter = {"node_filter": node_filter, "rel_filter": rel_filter}
53+
54+
graph_config = {"name": G.name()}
55+
graph_config.update(graph_filter)
56+
57+
config = {
58+
"user_name": "DUMMY_USER",
59+
"task": "FASTPATH",
60+
"task_config": {
61+
"graph_config": graph_config,
62+
"task_config": algo_config,
63+
"stream_node_results": True,
64+
},
65+
"encrypted_db_password": self._encrypted_db_password,
66+
"graph_arrow_uri": self._arrow_uri,
67+
}
68+
69+
if mlflow_experiment_name is not None:
70+
config["task_config"]["mlflow"] = {
71+
"config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name}
72+
}
73+
74+
job_id = self._start_job(config)
75+
76+
self._wait_for_job(job_id)
77+
78+
return Series({"status": "finished"})
79+
80+
# return self._stream_results(job_id)
81+
82+
def _start_job(self, config: Dict[str, Any]) -> str:
83+
res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
84+
res.raise_for_status()
85+
job_id = res.json()["job_id"]
86+
logging.info(f"Job with ID '{job_id}' started")
87+
88+
return job_id
89+
90+
def _wait_for_job(self, job_id: str) -> None:
91+
while True:
92+
time.sleep(1)
93+
94+
res = requests.get(f"{self._compute_cluster_web_uri}/api/machine-learning/status/{job_id}")
95+
96+
res_json = res.json()
97+
if res_json["job_status"] == "exited":
98+
logging.info("FastPath job completed!")
99+
return
100+
elif res_json["job_status"] == "failed":
101+
error = f"FastPath job failed with errors:{os.linesep}{os.linesep.join(res_json['errors'])}"
102+
if res.status_code == 400:
103+
raise ValueError(error)
104+
else:
105+
raise RuntimeError(error)
106+
107+
# def _stream_results(self, job_id: str) -> DataFrame:
108+
# client = pa.flight.connect(self._compute_cluster_arrow_uri)
109+
110+
# upload_descriptor = pa.flight.FlightDescriptor.for_path(f"{job_id}.nodes")
111+
# flight = client.get_flight_info(upload_descriptor)
112+
# reader = client.do_get(flight.endpoints[0].ticket)
113+
# read_table = reader.read_all()
114+
115+
# return read_table.to_pandas()

0 commit comments

Comments
 (0)