Skip to content

Commit cea569e

Browse files
committed
Next
1 parent ebfc92f commit cea569e

File tree

5 files changed

+68
-22
lines changed

5 files changed

+68
-22
lines changed

examples/kge-distmult.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,13 @@ def read_data():
9494

9595

9696
def put_data_in_db(dataset):
97-
for rel_split in tqdm(dataset, desc="Relationship"):
98-
for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):
97+
res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
98+
if res['num_nodes'].values[0] > 0:
99+
print("Data already in db, number of nodes: ", res['num_nodes'].values[0])
100+
return
101+
pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]))
102+
for rel_split in dataset:
103+
for rel_type in dataset[rel_split]:
99104
edges = dataset[rel_split][rel_type]
100105

101106
# MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
@@ -109,6 +114,8 @@ def put_data_in_db(dataset):
109114
""",
110115
params={"ll": edges},
111116
)
117+
pbar.update(len(edges))
118+
pbar.close()
112119

113120
for rel_split in dataset:
114121
res = gds.run_cypher(
@@ -120,7 +127,7 @@ def put_data_in_db(dataset):
120127
print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships)
121128

122129

123-
# put_data_in_db(dataset)
130+
put_data_in_db(dataset)
124131

125132
ALL_RELS = dataset["TRAIN"].keys()
126133
gds.graph.drop("trainGraph", failIfMissing=False)
@@ -159,13 +166,18 @@ def inspect_graph(G):
159166
gds.set_compute_cluster_ip("localhost")
160167

161168
kkge = gds.kge
169+
kmodel = gds.kge.model
170+
171+
print(gds.debug.arrow())
162172

163173
gds.kge.model.train(
164174
G_train,
165-
scoring_function="distmult",
166-
num_epochs=10,
167-
embedding_dimension=100,
175+
scoring_function="DistMult",
176+
num_epochs=1,
177+
embedding_dimension=10,
168178
)
179+
180+
print('Finished training')
169181
#
170182
# node_projection = {"Entity": {"properties": "id"}}
171183
# relationship_projection = [

graphdatascience/graph_data_science.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def __init__(
8787
None if arrow is True else arrow,
8888
)
8989

90-
if auth is not None:
91-
with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
92-
pub_key = rsa.PublicKey.load_pkcs1(f.read())
93-
self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
90+
# if auth is not None:
91+
# with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
92+
# pub_key = rsa.PublicKey.load_pkcs1(f.read())
93+
# self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
9494

9595
self._compute_cluster_ip = None
9696

@@ -143,20 +143,20 @@ def fastpath(self) -> FastPathRunner:
143143

144144
@property
145145
def kge(self) -> KgeRunner:
146-
print("!!!kge")
147-
# if not isinstance(self._query_runner, ArrowQueryRunner):
148-
# raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
146+
print("!kge")
147+
if not isinstance(self._query_runner, ArrowQueryRunner):
148+
raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
149149
if self._compute_cluster_ip is None:
150150
raise ValueError(
151151
"You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature"
152152
)
153153
return KgeRunner(
154154
self._query_runner,
155-
"gds.kge",
155+
"gds.kge.model",
156156
self._server_version,
157157
self._compute_cluster_ip,
158158
self._encrypted_db_password,
159-
self._query_runner.uri,
159+
None,
160160
)
161161

162162
def __getattr__(self, attr: str) -> IndirectCallBuilder:

graphdatascience/model/kge_runner.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,21 @@ def __init__(
2727
encrypted_db_password: str,
2828
arrow_uri: str,
2929
):
30-
print("!init", flush=True)
3130
self._query_runner = query_runner
3231
self._namespace = namespace
3332
self._server_version = server_version
3433
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
35-
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
34+
# self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8491"
3635
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
3736
self._encrypted_db_password = encrypted_db_password
3837
self._arrow_uri = arrow_uri
38+
print("KgeRunner __dict__:")
39+
print(self.__dict__)
3940

40-
@client_only_endpoint("gds.kge")
41+
@property
4142
def model(self):
42-
print("!model")
4343
return self
4444

45-
# @client_only_endpoint("gds.kge.model") and name is train
4645
# @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
4746
@client_only_endpoint("gds.kge.model")
4847
def train(
@@ -53,7 +52,6 @@ def train(
5352
embedding_dimension,
5453
mlflow_experiment_name: Optional[str] = None,
5554
) -> Series:
56-
print("!!!train")
5755
graph_config = {"name": G.name()}
5856

5957
algo_config = {
@@ -85,6 +83,7 @@ def train(
8583
return Series({"status": "finished"})
8684

8785
def _start_job(self, config: Dict[str, Any]) -> str:
86+
print(config)
8887
res = requests.post(f"{self._compute_cluster_web_uri}/api/machine-learning/start", json=config)
8988
res.raise_for_status()
9089
job_id = res.json()["job_id"]

graphdatascience/tests/integration/test_graph_construct.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,36 @@ def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience)
558558

559559
with pytest.warns(DeprecationWarning):
560560
gds.alpha.graph.construct("hello", nodes, relationships)
561+
562+
@pytest.mark.enterprise
563+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 1, 0))
564+
def test_graph_alpha_construct_backward_compat_with_arrow(gds: GraphDataScience) -> None:
565+
nodes = DataFrame({"nodeId": [0, 1, 2, 3]})
566+
relationships = DataFrame({"sourceNodeId": [0, 1, 2, 3], "targetNodeId": [1, 2, 3, 0]})
567+
568+
with pytest.warns(DeprecationWarning):
569+
gds.alpha.graph.construct("hello", nodes, relationships)
570+
571+
572+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
573+
def test_roundtrip_with_arrow(gds: GraphDataScience) -> None:
574+
G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}})
575+
576+
rel_df = gds.graph.relationshipProperty.stream(G, "relX")
577+
node_df = gds.graph.nodeProperty.stream(G, "x")
578+
579+
G_2 = gds.graph.construct("arrowGraph", node_df, rel_df)
580+
581+
res = gds.graph.list()
582+
try:
583+
assert set(res['graphName'].tolist()) == {'g', 'arrowGraph'}
584+
assert G.node_count() == G_2.node_count()
585+
assert G.relationship_count() == G_2.relationship_count()
586+
finally:
587+
G_2.drop()
588+
589+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
590+
def test_drop_list_warning_reproduction(gds: GraphDataScience) -> None:
591+
G, _ = gds.graph.project(GRAPH_NAME, {"Node": {"properties": ["x", "y"]}}, {"REL": {"properties": "relX"}})
592+
res = gds.graph.list()
593+
assert res['graphName'].tolist() == ['g']

graphdatascience/tests/integration/test_graph_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def test_graph_relationships_stream_without_arrow(gds_without_arrow: GraphDataSc
854854

855855
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
856856
def test_graph_relationships_stream_with_arrow(gds: GraphDataScience) -> None:
857-
G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL", "REL2"])
857+
G, _ = gds.graph.project(GRAPH_NAME, "*", ["REL_0", "REL2"])
858858

859859
if gds.server_version() >= ServerVersion(2, 5, 0):
860860
result = gds.graph.relationships.stream(G, ["REL", "REL2"])
@@ -1058,3 +1058,5 @@ def test_empty_relationships_stream(gds: GraphDataScience) -> None:
10581058

10591059
result = gds.graph.relationships.stream(G, ["SIMILAR"])
10601060
assert result.empty
1061+
1062+

0 commit comments

Comments
 (0)