Skip to content

Commit 6f28c9c

Browse files
committed
Next
1 parent 7981a1d commit 6f28c9c

File tree

6 files changed: +76 additions, -19 deletions (95 lines changed)

examples/kge-distmult.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -94,8 +94,13 @@ def read_data():
9494

9595

9696
def put_data_in_db(dataset):
97-
for rel_split in tqdm(dataset, desc="Relationship"):
98-
for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):
97+
res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
98+
if res['num_nodes'].values[0] > 0:
99+
print("Data already in db, number of nodes: ", res['num_nodes'].values[0])
100+
return
101+
pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]))
102+
for rel_split in dataset:
103+
for rel_type in dataset[rel_split]:
99104
edges = dataset[rel_split][rel_type]
100105

101106
# MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
@@ -109,6 +114,8 @@ def put_data_in_db(dataset):
109114
""",
110115
params={"ll": edges},
111116
)
117+
pbar.update(len(edges))
118+
pbar.close()
112119

113120
for rel_split in dataset:
114121
res = gds.run_cypher(
@@ -120,7 +127,7 @@ def put_data_in_db(dataset):
120127
print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships)
121128

122129

123-
# put_data_in_db(dataset)
130+
put_data_in_db(dataset)
124131

125132
ALL_RELS = dataset["TRAIN"].keys()
126133
gds.graph.drop("trainGraph", failIfMissing=False)
@@ -159,13 +166,18 @@ def inspect_graph(G):
159166
gds.set_compute_cluster_ip("localhost")
160167

161168
kkge = gds.kge
169+
kmodel = gds.kge.model
170+
171+
print(gds.debug.arrow())
162172

163173
gds.kge.model.train(
164174
G_train,
165-
scoring_function="distmult",
166-
num_epochs=10,
167-
embedding_dimension=100,
175+
scoring_function="DistMult",
176+
num_epochs=1,
177+
embedding_dimension=10,
168178
)
179+
180+
print('Finished training')
169181
#
170182
# node_projection = {"Entity": {"properties": "id"}}
171183
# relationship_projection = [

graphdatascience/graph_data_science.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def __init__(
8686
None if arrow is True else arrow,
8787
)
8888

89-
# if auth is not None:
90-
# with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
91-
# pub_key = rsa.PublicKey.load_pkcs1(f.read())
92-
# self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
89+
if auth is not None:
90+
with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
91+
pub_key = rsa.PublicKey.load_pkcs1(f.read())
92+
self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
93+
# self._encrypted_db_password = None
94+
9395

9496
self._compute_cluster_ip = None
9597

@@ -142,7 +144,6 @@ def fastpath(self) -> FastPathRunner:
142144

143145
@property
144146
def kge(self) -> KgeRunner:
145-
print("!kge")
146147
if not isinstance(self._query_runner, ArrowQueryRunner):
147148
raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
148149
if self._compute_cluster_ip is None:
@@ -155,7 +156,7 @@ def kge(self) -> KgeRunner:
155156
self._server_version,
156157
self._compute_cluster_ip,
157158
self._encrypted_db_password,
158-
None,
159+
self._query_runner._gds_arrow_client._host + ":" + str(self._query_runner._gds_arrow_client._port),
159160
)
160161

161162
def __getattr__(self, attr: str) -> IndirectCallBuilder:

graphdatascience/model/kge_runner.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,21 @@ def __init__(
2727
encrypted_db_password: str,
2828
arrow_uri: str,
2929
):
30-
print("!init", flush=True)
3130
self._query_runner = query_runner
3231
self._namespace = namespace
3332
self._server_version = server_version
3433
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
35-
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
34+
# self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8491"
3635
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
3736
self._encrypted_db_password = encrypted_db_password
3837
self._arrow_uri = arrow_uri
38+
print("KgeRunner __dict__:")
39+
print(self.__dict__)
3940

40-
@client_only_endpoint("gds.kge")
41+
@property
4142
def model(self):
42-
print("!model")
4343
return self
4444

45-
# @client_only_endpoint("gds.kge.model") and name is train
4645
# @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
4746
@client_only_endpoint("gds.kge.model")
4847
def train(
@@ -53,7 +52,6 @@ def train(
5352
embedding_dimension,
5453
mlflow_experiment_name: Optional[str] = None,
5554
) -> Series:
56-
print("!!!train")
5755
graph_config = {"name": G.name()}
5856

5957
algo_config = {
@@ -85,6 +83,7 @@ def train(
8583
return Series({"status": "finished"})
8684

8785
def _start_job(self, config: Dict[str, Any]) -> str:
86+
print(config)
8887
res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
8988
res.raise_for_status()
9089
job_id = res.json()["job_id"]

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ def __init__(
8888
if tls_root_certs:
8989
client_options["tls_root_certs"] = tls_root_certs
9090

91+
print("location:")
92+
print(location)
93+
print("client_options:")
94+
print(client_options)
95+
print("auth:")
96+
print(auth)
9197
self._flight_client = flight.FlightClient(location, **client_options)
9298

9399
def connection_info(self) -> Tuple[str, int]:
@@ -124,6 +130,10 @@ def get_property(
124130
}
125131

126132
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
133+
print("ticket:")
134+
print(ticket)
135+
print("_flight_client")
136+
print(self._flight_client)
127137
get = self._flight_client.do_get(ticket)
128138
arrow_table = get.read_all()
129139

graphdatascience/tests/integration/test_graph_construct.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,36 @@ def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience)
558558

559559
with pytest.warns(DeprecationWarning):
560560
gds.alpha.graph.construct("hello", nodes, relationships)
561+
562+
@pytest.mark.enterprise
563+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 1, 0))
564+
def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) -> None:
565+
nodes = DataFrame({"nodeId": [0, 1, 2, 3]})
566+
relationships = DataFrame({"sourceNodeId": [0, 1, 2, 3], "targetNodeId": [1, 2, 3, 0]})
567+
568+
with pytest.warns(DeprecationWarning):
569+
gds.alpha.graph.construct("hello", nodes, relationships)
570+
571+
572+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
573+
def test_roundtrip_with_arrow(gds: GraphDataScience) -> None:
574+
G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}})
575+
576+
rel_df = gds.graph.relationshipProperty.stream(G, "relX")
577+
node_df = gds.graph.nodeProperty.stream(G, "x")
578+
579+
G_2 = gds.graph.construct("arrowGraph", node_df, rel_df)
580+
581+
res = gds.graph.list()
582+
try:
583+
assert set(res['graphName'].tolist()) == {'g', 'arrowGraph'}
584+
assert G.node_count() == G_2.node_count()
585+
assert G.relationship_count() == G_2.relationship_count()
586+
finally:
587+
G_2.drop()
588+
589+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
590+
def test_drop_list_warning_reproduction(gds: GraphDataScience) -> None:
591+
G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}})
592+
res = gds.graph.list()
593+
assert res['graphName'].tolist() == ['g']

graphdatascience/tests/integration/test_graph_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ def test_graph_relationships_stream_without_arrow(gds_without_arrow: GraphDataSc
848848

849849
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
850850
def test_graph_relationships_stream_with_arrow(gds: GraphDataScience) -> None:
851-
G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL", "REL2"])
851+
G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL_0", "REL2"])
852852

853853
if gds.server_version() >= ServerVersion(2, 5, 0):
854854
result = gds.graph.relationships.stream(G, ["REL", "REL2"])
@@ -1052,3 +1052,5 @@ def test_empty_relationships_stream(gds: GraphDataScience) -> None:
10521052

10531053
result = gds.graph.relationships.stream(G, ["SIMILAR"])
10541054
assert result.empty
1055+
1056+

0 commit comments

Comments
 (0)