Skip to content

Commit 95a8491

Browse files
committed
Next
1 parent 88dfca0 commit 95a8491

File tree

6 files changed

+76
-19
lines changed

6 files changed

+76
-19
lines changed

examples/kge-distmult.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,13 @@ def read_data():
9494

9595

9696
def put_data_in_db(dataset):
97-
for rel_split in tqdm(dataset, desc="Relationship"):
98-
for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):
97+
res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
98+
if res['num_nodes'].values[0] > 0:
99+
print("Data already in db, number of nodes: ", res['num_nodes'].values[0])
100+
return
101+
pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]))
102+
for rel_split in dataset:
103+
for rel_type in dataset[rel_split]:
99104
edges = dataset[rel_split][rel_type]
100105

101106
# MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
@@ -109,6 +114,8 @@ def put_data_in_db(dataset):
109114
""",
110115
params={"ll": edges},
111116
)
117+
pbar.update(len(edges))
118+
pbar.close()
112119

113120
for rel_split in dataset:
114121
res = gds.run_cypher(
@@ -120,7 +127,7 @@ def put_data_in_db(dataset):
120127
print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships)
121128

122129

123-
# put_data_in_db(dataset)
130+
put_data_in_db(dataset)
124131

125132
ALL_RELS = dataset["TRAIN"].keys()
126133
gds.graph.drop("trainGraph", failIfMissing=False)
@@ -159,13 +166,18 @@ def inspect_graph(G):
159166
gds.set_compute_cluster_ip("localhost")
160167

161168
kkge = gds.kge
169+
kmodel = gds.kge.model
170+
171+
print(gds.debug.arrow())
162172

163173
gds.kge.model.train(
164174
G_train,
165-
scoring_function="distmult",
166-
num_epochs=10,
167-
embedding_dimension=100,
175+
scoring_function="DistMult",
176+
num_epochs=1,
177+
embedding_dimension=10,
168178
)
179+
180+
print('Finished training')
169181
#
170182
# node_projection = {"Entity": {"properties": "id"}}
171183
# relationship_projection = [

graphdatascience/graph_data_science.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,12 @@ def __init__(
8787
None if arrow is True else arrow,
8888
)
8989

90-
# if auth is not None:
91-
# with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
92-
# pub_key = rsa.PublicKey.load_pkcs1(f.read())
93-
# self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
90+
if auth is not None:
91+
with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
92+
pub_key = rsa.PublicKey.load_pkcs1(f.read())
93+
self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
94+
# self._encrypted_db_password = None
95+
9496

9597
self._compute_cluster_ip = None
9698

@@ -143,7 +145,6 @@ def fastpath(self) -> FastPathRunner:
143145

144146
@property
145147
def kge(self) -> KgeRunner:
146-
print("!kge")
147148
if not isinstance(self._query_runner, ArrowQueryRunner):
148149
raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
149150
if self._compute_cluster_ip is None:
@@ -156,7 +157,7 @@ def kge(self) -> KgeRunner:
156157
self._server_version,
157158
self._compute_cluster_ip,
158159
self._encrypted_db_password,
159-
None,
160+
self._query_runner._gds_arrow_client._host + ":" + str(self._query_runner._gds_arrow_client._port),
160161
)
161162

162163
def __getattr__(self, attr: str) -> IndirectCallBuilder:

graphdatascience/model/kge_runner.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,21 @@ def __init__(
2727
encrypted_db_password: str,
2828
arrow_uri: str,
2929
):
30-
print("!init", flush=True)
3130
self._query_runner = query_runner
3231
self._namespace = namespace
3332
self._server_version = server_version
3433
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
35-
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
34+
# self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8491"
3635
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
3736
self._encrypted_db_password = encrypted_db_password
3837
self._arrow_uri = arrow_uri
38+
print("KgeRunner __dict__:")
39+
print(self.__dict__)
3940

40-
@client_only_endpoint("gds.kge")
41+
@property
4142
def model(self):
42-
print("!model")
4343
return self
4444

45-
# @client_only_endpoint("gds.kge.model") and name is train
4645
# @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
4746
@client_only_endpoint("gds.kge.model")
4847
def train(
@@ -53,7 +52,6 @@ def train(
5352
embedding_dimension,
5453
mlflow_experiment_name: Optional[str] = None,
5554
) -> Series:
56-
print("!!!train")
5755
graph_config = {"name": G.name()}
5856

5957
algo_config = {
@@ -85,6 +83,7 @@ def train(
8583
return Series({"status": "finished"})
8684

8785
def _start_job(self, config: Dict[str, Any]) -> str:
86+
print(config)
8887
res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
8988
res.raise_for_status()
9089
job_id = res.json()["job_id"]

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def __init__(
9292
if tls_root_certs:
9393
client_options["tls_root_certs"] = tls_root_certs
9494

95+
print("location:")
96+
print(location)
97+
print("client_options:")
98+
print(client_options)
99+
print("auth:")
100+
print(auth)
95101
self._flight_client = flight.FlightClient(location, **client_options)
96102

97103
def connection_info(self) -> Tuple[str, int]:
@@ -128,6 +134,10 @@ def get_property(
128134
}
129135

130136
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
137+
print("ticket:")
138+
print(ticket)
139+
print("_flight_client")
140+
print(self._flight_client)
131141
get = self._flight_client.do_get(ticket)
132142
arrow_table = get.read_all()
133143

graphdatascience/tests/integration/test_graph_construct.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,36 @@ def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience)
558558

559559
with pytest.warns(DeprecationWarning):
560560
gds.alpha.graph.construct("hello", nodes, relationships)
561+
562+
@pytest.mark.enterprise
563+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 1, 0))
564+
def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) -> None:
565+
nodes = DataFrame({"nodeId": [0, 1, 2, 3]})
566+
relationships = DataFrame({"sourceNodeId": [0, 1, 2, 3], "targetNodeId": [1, 2, 3, 0]})
567+
568+
with pytest.warns(DeprecationWarning):
569+
gds.alpha.graph.construct("hello", nodes, relationships)
570+
571+
572+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
573+
def test_roundtrip_with_arrow(gds: GraphDataScience) -> None:
574+
G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}})
575+
576+
rel_df = gds.graph.relationshipProperty.stream(G, "relX")
577+
node_df = gds.graph.nodeProperty.stream(G, "x")
578+
579+
G_2 = gds.graph.construct("arrowGraph", node_df, rel_df)
580+
581+
res = gds.graph.list()
582+
try:
583+
assert set(res['graphName'].tolist()) == {'g', 'arrowGraph'}
584+
assert G.node_count() == G_2.node_count()
585+
assert G.relationship_count() == G_2.relationship_count()
586+
finally:
587+
G_2.drop()
588+
589+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
590+
def test_drop_list_warning_reproduction(gds: GraphDataScience) -> None:
591+
G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}})
592+
res = gds.graph.list()
593+
assert res['graphName'].tolist() == ['g']

graphdatascience/tests/integration/test_graph_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def test_graph_relationships_stream_without_arrow(gds_without_arrow: GraphDataSc
854854

855855
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
856856
def test_graph_relationships_stream_with_arrow(gds: GraphDataScience) -> None:
857-
G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL", "REL2"])
857+
G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL_0", "REL2"])
858858

859859
if gds.server_version() >= ServerVersion(2, 5, 0):
860860
result = gds.graph.relationships.stream(G, ["REL", "REL2"])
@@ -1058,3 +1058,5 @@ def test_empty_relationships_stream(gds: GraphDataScience) -> None:
10581058

10591059
result = gds.graph.relationships.stream(G, ["SIMILAR"])
10601060
assert result.empty
1061+
1062+

0 commit comments

Comments
 (0)