Skip to content

Commit a99ee6e

Browse files
committed
Support stream mode
with the latest session release we can call the endpoints without a seg fault
1 parent 7a05bc6 commit a99ee6e

File tree

5 files changed

+69
-34
lines changed

5 files changed

+69
-34
lines changed

graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
KnnWriteResult,
1616
)
1717
from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
18+
from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result
1819

1920

2021
class KnnArrowEndpoints(KnnEndpoints):
@@ -145,29 +146,30 @@ def stream(
145146
concurrency: Any | None = None,
146147
job_id: Any | None = None,
147148
) -> DataFrame:
148-
# config = self._endpoints_helper.create_base_config(
149-
# G,
150-
# nodeProperties=node_properties,
151-
# topK=top_k,
152-
# similarityCutoff=similarity_cutoff,
153-
# deltaThreshold=delta_threshold,
154-
# maxIterations=max_iterations,
155-
# sampleRate=sample_rate,
156-
# perturbationRate=perturbation_rate,
157-
# randomJoins=random_joins,
158-
# randomSeed=random_seed,
159-
# initialSampler=initial_sampler,
160-
# relationshipTypes=relationship_types,
161-
# nodeLabels=node_labels,
162-
# sudo=sudo,
163-
# logProgress=log_progress,
164-
# username=username,
165-
# concurrency=concurrency,
166-
# jobId=job_id,
167-
# )
168-
# return self._endpoints_helper.run_job_and_stream("v2/similarity.knn", G, config)
149+
config = self._endpoints_helper.create_base_config(
150+
G,
151+
nodeProperties=node_properties,
152+
topK=top_k,
153+
similarityCutoff=similarity_cutoff,
154+
deltaThreshold=delta_threshold,
155+
maxIterations=max_iterations,
156+
sampleRate=sample_rate,
157+
perturbationRate=perturbation_rate,
158+
randomJoins=random_joins,
159+
randomSeed=random_seed,
160+
initialSampler=initial_sampler,
161+
relationshipTypes=relationship_types,
162+
nodeLabels=node_labels,
163+
sudo=sudo,
164+
logProgress=log_progress,
165+
username=username,
166+
concurrency=concurrency,
167+
jobId=job_id,
168+
)
169+
result = self._endpoints_helper.run_job_and_stream("v2/similarity.knn", G, config)
170+
rename_similarity_stream_result(result)
169171

170-
raise NotImplementedError()
172+
return result
171173

172174
def write(
173175
self,

graphdatascience/procedure_surface/arrow/similarity/knn_filtered_arrow_endpoints.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
1515
from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
16+
from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result
1617

1718

1819
class KnnFilteredArrowEndpoints(KnnFilteredEndpoints):
@@ -154,7 +155,34 @@ def stream(
154155
concurrency: Any | None = None,
155156
job_id: Any | None = None,
156157
) -> DataFrame:
157-
raise NotImplementedError("Filtered KNN stream endpoint is not available via Arrow")
158+
config = self._endpoints_helper.create_base_config(
159+
G,
160+
nodeProperties=node_properties,
161+
sourceNodeFilter=source_node_filter,
162+
targetNodeFilter=target_node_filter,
163+
seedTargetNodes=seed_target_nodes,
164+
nodeLabels=node_labels,
165+
relationshipTypes=relationship_types,
166+
similarityCutoff=similarity_cutoff,
167+
perturbationRate=perturbation_rate,
168+
deltaThreshold=delta_threshold,
169+
sampleRate=sample_rate,
170+
randomJoins=random_joins,
171+
initialSampler=initial_sampler,
172+
maxIterations=max_iterations,
173+
topK=top_k,
174+
randomSeed=random_seed,
175+
concurrency=concurrency,
176+
jobId=job_id,
177+
logProgress=log_progress,
178+
sudo=sudo,
179+
username=username,
180+
)
181+
182+
result = self._endpoints_helper.run_job_and_stream("v2/similarity.knn.filtered", G, config)
183+
rename_similarity_stream_result(result)
184+
185+
return result
158186

159187
def write(
160188
self,
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pandas import DataFrame
2+
3+
4+
def rename_similarity_stream_result(result: DataFrame) -> None:
5+
result.rename(columns={"sourceNodeId": "node1", "targetNodeId": "node2"}, inplace=True)
6+
if "relationshipType" in result.columns:
7+
result.drop(columns=["relationshipType"], inplace=True)

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def test_knn_stats(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> N
6464
assert "p50" in result.similarity_distribution
6565

6666

67-
@pytest.mark.skip(reason="SEGFAULT for custom metadata. tracked in GDSA-312")
6867
def test_knn_stream(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> None:
6968
"""Test KNN stream operation."""
7069
result_df = knn_endpoints.stream(
@@ -74,7 +73,7 @@ def test_knn_stream(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) ->
7473
)
7574

7675
assert set(result_df.columns) == {"node1", "node2", "similarity"}
77-
assert len(result_df) == 2
76+
assert len(result_df) == 8
7877

7978

8079
def test_knn_mutate(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> None:

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_filtered_arrow_endpoints.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def knn_filtered_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[
4949
yield KnnFilteredArrowEndpoints(arrow_client)
5050

5151

52-
def test_stats(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
52+
def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
5353
result = knn_filtered_endpoints.stats(
5454
sample_graph,
5555
node_properties="prop",
@@ -70,10 +70,7 @@ def test_stats(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph:
7070
assert result.configuration is not None
7171

7272

73-
@pytest.mark.skip(reason="SEGFAULT for custom metadata. tracked in GDSA-312")
74-
def test_stream_raises_not_implemented(
75-
knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2
76-
) -> None:
73+
def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
7774
result_df = knn_filtered_endpoints.stream(
7875
G=sample_graph,
7976
node_properties=["prop"],
@@ -83,10 +80,10 @@ def test_stream_raises_not_implemented(
8380
)
8481

8582
assert set(result_df.columns) == {"node1", "node2", "similarity"}
86-
assert len(result_df) == 2
83+
assert len(result_df) == 4
8784

8885

89-
def test_mutate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
86+
def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
9087
result = knn_filtered_endpoints.mutate(
9188
sample_graph,
9289
node_properties="prop",
@@ -110,7 +107,9 @@ def test_mutate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph:
110107
assert result.configuration is not None
111108

112109

113-
def test_knn_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: GraphV2) -> None:
110+
def test_knn_filtered_write(
111+
arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: GraphV2
112+
) -> None:
114113
endpoints = KnnFilteredArrowEndpoints(
115114
arrow_client, write_back_client=RemoteWriteBackClient(arrow_client, query_runner), show_progress=False
116115
)
@@ -139,7 +138,7 @@ def test_knn_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRu
139138
assert result.configuration is not None
140139

141140

142-
def test_estimate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
141+
def test_knn_filtered_estimate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
143142
result = knn_filtered_endpoints.estimate(
144143
sample_graph,
145144
node_properties="prop",

0 commit comments

Comments
 (0)