Skip to content

Commit 60ec5da

Browse files
authored
Merge pull request #658 from orazve/rel-prop-error
Support relationship type as str in `gds.graph.relationshipProperties.stream`
2 parents b476162 + f1cd3eb commit 60ec5da

File tree

8 files changed

+80
-39
lines changed

8 files changed

+80
-39
lines changed

changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
## Bug fixes
1818

1919
* Fixed a bug which caused the auth token returned from the GDS Arrow Server to not be received correctly.
20+
* Fixed a bug which prevented the user from specifying `relationship_types` as a string in `gds.graph.relationshipProperties.stream()`.
21+
* Fixed a bug in `kge-predict-transe-pyg-train.ipynb`: the notebook now uses the `gds.graph.relationshipProperty.stream()` call and correctly handles multiple relationships between the same pair of nodes. Issue ref: [#554](https://github.com/neo4j/graph-data-science-client/issues/554)
2022

2123
## Improvements
2224

doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ version (2.5+ or later) installed.
2929

3030
Additionally, the following Python libraries are required:
3131

32-
* `graphdatascience`
33-
(https://neo4j.com/docs/graph-data-science-client/current/installation/[see
34-
documentation for installation instructions])
35-
* PyG
36-
(https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
37-
PyG documentation for installation instructions])
32+
* `graphdatascience`,
33+
https://neo4j.com/docs/graph-data-science-client/current/installation/[see
34+
documentation for installation instructions]
35+
* `pytorch-geometric` version >= 2.5.0,
36+
https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
37+
PyG documentation for installation instructions]
3838

3939
== Setup
4040

@@ -233,15 +233,13 @@ format it into a `Data` structure suitable for training with PyG.
233233
[source, python, role=no-test]
234234
----
235235
def create_data_from_graph(relationship_type):
236-
rels_tmp = gds.graph.relationshipProperties.stream(
237-
ttv_G, ["rel_id"], relationship_type, separate_property_columns=True
238-
)
236+
rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
239237
topology = [
240238
rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
241239
rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
242240
]
243241
edge_index = torch.tensor(topology, dtype=torch.long)
244-
edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)
242+
edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)
245243
data = Data(edge_index=edge_index, edge_type=edge_type)
246244
data.num_nodes = len(nodeId_to_id)
247245
display(data)
@@ -303,7 +301,7 @@ def train_model_with_pyg():
303301
head_index=data.edge_index[0],
304302
rel_type=data.edge_type,
305303
tail_index=data.edge_index[1],
306-
batch_size=20000,
304+
batch_size=1000,
307305
k=10,
308306
)
309307
@@ -316,12 +314,11 @@ def train_model_with_pyg():
316314
rank, hits = test(val_tensor_data)
317315
print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}")
318316
319-
print(model)
320-
rank, hits_at_10 = test(test_tensor_data)
321-
print(f"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}")
322-
323317
torch.save(model, f"./model_{epoch_count}.pt")
324318
319+
mean_rank, mrr, hits_at_k = test(test_tensor_data)
320+
print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
321+
325322
return model
326323
----
327324

examples/kge-predict-transe-pyg-train.ipynb

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
"\n",
3838
"Additionally, the following Python libraries are required:\n",
3939
"\n",
40-
"- `graphdatascience` ([see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/))\n",
41-
"- PyG ([see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html))\n",
40+
"- `graphdatascience`, [see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/)\n",
41+
"- `pytorch-geometric` version >= 2.5.0, [see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)\n",
4242
"\n",
4343
"## Setup\n",
4444
"\n",
@@ -307,15 +307,13 @@
307307
"outputs": [],
308308
"source": [
309309
"def create_data_from_graph(relationship_type):\n",
310-
" rels_tmp = gds.graph.relationshipProperties.stream(\n",
311-
" ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True\n",
312-
" )\n",
310+
" rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, \"rel_id\", relationship_type)\n",
313311
" topology = [\n",
314312
" rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n",
315313
" rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n",
316314
" ]\n",
317315
" edge_index = torch.tensor(topology, dtype=torch.long)\n",
318-
" edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)\n",
316+
" edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)\n",
319317
" data = Data(edge_index=edge_index, edge_type=edge_type)\n",
320318
" data.num_nodes = len(nodeId_to_id)\n",
321319
" display(data)\n",
@@ -398,7 +396,7 @@
398396
" head_index=data.edge_index[0],\n",
399397
" rel_type=data.edge_type,\n",
400398
" tail_index=data.edge_index[1],\n",
401-
" batch_size=20000,\n",
399+
" batch_size=1000,\n",
402400
" k=10,\n",
403401
" )\n",
404402
"\n",
@@ -411,12 +409,11 @@
411409
" rank, hits = test(val_tensor_data)\n",
412410
" print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n",
413411
"\n",
414-
" print(model)\n",
415-
" rank, hits_at_10 = test(test_tensor_data)\n",
416-
" print(f\"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}\")\n",
417-
"\n",
418412
" torch.save(model, f\"./model_{epoch_count}.pt\")\n",
419413
"\n",
414+
" mean_rank, mrr, hits_at_k = test(test_tensor_data)\n",
415+
" print(f\"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}\")\n",
416+
"\n",
420417
" return model"
421418
]
422419
},

graphdatascience/graph/base_graph_proc_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from ..server_version.compatible_with import compatible_with
1616
from ..server_version.server_version import ServerVersion
1717
from .graph_entity_ops_runner import (
18-
GraphElementPropertyRunner,
1918
GraphLabelRunner,
2019
GraphNodePropertiesRunner,
2120
GraphNodePropertyRunner,
2221
GraphPropertyRunner,
2322
GraphRelationshipPropertiesRunner,
23+
GraphRelationshipPropertyRunner,
2424
GraphRelationshipRunner,
2525
GraphRelationshipsRunner,
2626
)
@@ -390,9 +390,9 @@ def nodeProperties(self) -> GraphNodePropertiesRunner:
390390
return GraphNodePropertiesRunner(self._query_runner, self._namespace, self._server_version)
391391

392392
@property
393-
def relationshipProperty(self) -> GraphElementPropertyRunner:
393+
def relationshipProperty(self) -> GraphRelationshipPropertyRunner:
394394
self._namespace += ".relationshipProperty"
395-
return GraphElementPropertyRunner(self._query_runner, self._namespace, self._server_version)
395+
return GraphRelationshipPropertyRunner(self._query_runner, self._namespace, self._server_version)
396396

397397
@property
398398
def relationshipProperties(self) -> GraphRelationshipPropertiesRunner:

graphdatascience/graph/graph_entity_ops_runner.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,6 @@ def _handle_properties(
7070
)
7171

7272

73-
class GraphElementPropertyRunner(GraphEntityOpsBaseRunner):
74-
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
75-
def stream(self, G: Graph, node_properties: str, node_labels: Strings = ["*"], **config: Any) -> DataFrame:
76-
self._namespace += ".stream"
77-
return self._handle_properties(G, node_properties, node_labels, config)
78-
79-
8073
class GraphNodePropertyRunner(GraphEntityOpsBaseRunner):
8174
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
8275
@filter_id_func_deprecation_warning()
@@ -197,6 +190,16 @@ def drop(self, G: Graph, node_properties: List[str], **config: Any) -> "Series[A
197190
).squeeze()
198191

199192

193+
class GraphRelationshipPropertyRunner(GraphEntityOpsBaseRunner):
194+
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
195+
def stream(
196+
self, G: Graph, relationship_property: str, relationship_types: Strings = ["*"], **config: Any
197+
) -> DataFrame:
198+
self._namespace += ".stream"
199+
relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types
200+
return self._handle_properties(G, relationship_property, relationship_types, config)
201+
202+
200203
class GraphRelationshipPropertiesRunner(GraphEntityOpsBaseRunner):
201204
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
202205
def stream(
@@ -209,6 +212,8 @@ def stream(
209212
) -> DataFrame:
210213
self._namespace += ".stream"
211214

215+
relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types
216+
212217
result = self._handle_properties(G, relationship_properties, relationship_types, config)
213218

214219
# new format was requested, but the query was run via Cypher

graphdatascience/tests/integration/test_graph_ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,46 @@ def test_graph_relationshipProperties_stream_with_arrow_separate_property_column
695695
assert {e for e in result["relY"]} == {5, 6, 7}
696696

697697

698+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
699+
def test_graph_relationshipProperties_stream_with_arrow_rel_as_str(gds: GraphDataScience) -> None:
700+
G, _ = gds.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}})
701+
702+
result = gds.graph.relationshipProperties.stream(G, ["relX", "relY"], "REL", concurrency=2)
703+
704+
assert list(result.keys()) == [
705+
"sourceNodeId",
706+
"targetNodeId",
707+
"relationshipType",
708+
"relationshipProperty",
709+
"propertyValue",
710+
]
711+
712+
x_values = result[result.relationshipProperty == "relX"]
713+
assert {e for e in x_values["propertyValue"]} == {4, 5, 6}
714+
y_values = result[result.relationshipProperty == "relY"]
715+
assert {e for e in y_values["propertyValue"]} == {5, 6, 7}
716+
717+
718+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 2, 0))
719+
def test_graph_relationshipProperties_stream_with_arrow_rel_as_str_sep(gds: GraphDataScience) -> None:
720+
G, _ = gds.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}})
721+
722+
result = gds.graph.relationshipProperties.stream(
723+
G, ["relX", "relY"], "REL", separate_property_columns=True, concurrency=2
724+
)
725+
726+
assert list(result.keys()) == [
727+
"sourceNodeId",
728+
"targetNodeId",
729+
"relationshipType",
730+
"relX",
731+
"relY",
732+
]
733+
734+
assert {e for e in result["relX"]} == {4, 5, 6}
735+
assert {e for e in result["relY"]} == {5, 6, 7}
736+
737+
698738
def test_graph_streamRelationshipProperties_without_arrow(gds_without_arrow: GraphDataScience) -> None:
699739
G, _ = gds_without_arrow.graph.project(GRAPH_NAME, "*", {"REL": {"properties": ["relX", "relY"]}})
700740

graphdatascience/tests/unit/test_graph_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def test_graph_relationshipProperty_stream(runner: CollectingQueryRunner, gds: G
305305
assert runner.last_params() == {
306306
"graph_name": "g",
307307
"properties": "dummyProp",
308-
"entities": "dummyType",
308+
"entities": ["dummyType"],
309309
"config": {"concurrency": 2},
310310
}
311311

@@ -390,7 +390,7 @@ def test_graph_relationshipProperties_stream(runner: CollectingQueryRunner, gds:
390390
assert runner.last_params() == {
391391
"graph_name": "g",
392392
"properties": ["dummyProp"],
393-
"entities": "dummyType",
393+
"entities": ["dummyType"],
394394
"config": {"concurrency": 2},
395395
}
396396

requirements/dev/notebook-ci.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ scipy == 1.10.1
66
torch==2.1.0
77
torch-scatter==2.1.1
88
torch-sparse==0.6.17
9-
torch-geometric==2.3.1
9+
torch-geometric>=2.5.0

0 commit comments

Comments
 (0)