Skip to content

Commit 7a05bc6

Browse files
committed
Add doc string for filtered knn
1 parent 116aca4 commit 7a05bc6

File tree

3 files changed

+267
-42
lines changed

3 files changed

+267
-42
lines changed

graphdatascience/procedure_surface/api/similarity/knn_filtered_endpoints.py

Lines changed: 255 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616

1717
class KnnFilteredEndpoints(ABC):
18-
"""Base class for Filtered K-Nearest Neighbors endpoints."""
19-
2018
@abstractmethod
2119
def mutate(
2220
self,
@@ -44,7 +42,65 @@ def mutate(
4442
concurrency: Any | None = None,
4543
job_id: Any | None = None,
4644
) -> KnnMutateResult:
47-
"""Run filtered K-Nearest Neighbors in mutate mode."""
45+
"""
46+
Runs the Filtered K-Nearest Neighbors algorithm and stores the results as new relationships in the graph catalog.
47+
48+
The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset.
49+
50+
Parameters
51+
----------
52+
G : GraphV2
53+
The graph to run the algorithm on
54+
mutate_relationship_type : str
55+
The relationship type to use for the new relationships.
56+
mutate_property : str
57+
The relationship property to store the similarity score in.
58+
node_properties : str | list[str] | dict[str, str]
59+
The node properties to use for similarity computation.
60+
source_node_filter : str
61+
A Cypher expression to filter which nodes can be sources in the similarity computation.
62+
target_node_filter : str
63+
A Cypher expression to filter which nodes can be targets in the similarity computation.
64+
seed_target_nodes : bool | None, default=None
65+
Whether to use a seeded approach for target node selection.
66+
top_k : int | None, default=None
67+
The number of nearest neighbors to find for each node.
68+
similarity_cutoff : float | None, default=None
69+
The threshold for similarity scores.
70+
delta_threshold : float | None, default=None
71+
The threshold for convergence assessment.
72+
max_iterations : int | None, default=None
73+
The maximum number of iterations to run.
74+
sample_rate : float | None, default=None
75+
The sampling rate for the algorithm.
76+
perturbation_rate : float | None, default=None
77+
The rate at which to perturb the similarity graph.
78+
random_joins : int | None, default=None
79+
The number of random joins to perform.
80+
random_seed : int | None, default=None
81+
The seed for the random number generator.
82+
initial_sampler : Any | None, default=None
83+
The initial sampling strategy.
84+
relationship_types : list[str] | None, default=None
85+
Filter on relationship types.
86+
node_labels : list[str] | None, default=None
87+
Filter on node labels.
88+
sudo : bool | None, default=None
89+
Run the algorithm with elevated privileges.
90+
log_progress : bool, default=True
91+
Whether to log progress.
92+
username : str | None, default=None
93+
Username for the operation.
94+
concurrency : Any | None, default=None
95+
Concurrency configuration.
96+
job_id : Any | None, default=None
97+
Job ID for the operation.
98+
99+
Returns
100+
-------
101+
KnnMutateResult
102+
Object containing metadata from the execution.
103+
"""
48104
...
49105

50106
@abstractmethod
@@ -72,7 +128,61 @@ def stats(
72128
concurrency: Any | None = None,
73129
job_id: Any | None = None,
74130
) -> KnnStatsResult:
75-
"""Run filtered K-Nearest Neighbors in stats mode."""
131+
"""
132+
Runs the Filtered K-Nearest Neighbors algorithm and returns execution statistics.
133+
134+
The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset.
135+
136+
Parameters
137+
----------
138+
G : GraphV2
139+
The graph to run the algorithm on
140+
node_properties : str | list[str] | dict[str, str]
141+
The node properties to use for similarity computation.
142+
source_node_filter : str
143+
A Cypher expression to filter which nodes can be sources in the similarity computation.
144+
target_node_filter : str
145+
A Cypher expression to filter which nodes can be targets in the similarity computation.
146+
seed_target_nodes : bool | None, default=None
147+
Whether to use a seeded approach for target node selection.
148+
top_k : int | None, default=None
149+
The number of nearest neighbors to find for each node.
150+
similarity_cutoff : float | None, default=None
151+
The threshold for similarity scores.
152+
delta_threshold : float | None, default=None
153+
The threshold for convergence assessment.
154+
max_iterations : int | None, default=None
155+
The maximum number of iterations to run.
156+
sample_rate : float | None, default=None
157+
The sampling rate for the algorithm.
158+
perturbation_rate : float | None, default=None
159+
The rate at which to perturb the similarity graph.
160+
random_joins : int | None, default=None
161+
The number of random joins to perform.
162+
random_seed : int | None, default=None
163+
The seed for the random number generator.
164+
initial_sampler : Any | None, default=None
165+
The initial sampling strategy.
166+
relationship_types : list[str] | None, default=None
167+
Filter on relationship types.
168+
node_labels : list[str] | None, default=None
169+
Filter on node labels.
170+
sudo : bool | None, default=None
171+
Run the algorithm with elevated privileges.
172+
log_progress : bool, default=True
173+
Whether to log progress.
174+
username : str | None, default=None
175+
Username for the operation.
176+
concurrency : Any | None, default=None
177+
Concurrency configuration.
178+
job_id : Any | None, default=None
179+
Job ID for the operation.
180+
181+
Returns
182+
-------
183+
KnnStatsResult
184+
Object containing execution statistics and algorithm-specific results.
185+
"""
76186
...
77187

78188
@abstractmethod
@@ -100,7 +210,61 @@ def stream(
100210
concurrency: Any | None = None,
101211
job_id: Any | None = None,
102212
) -> DataFrame:
103-
"""Run filtered K-Nearest Neighbors in stream mode."""
213+
"""
214+
Runs the Filtered K-Nearest Neighbors algorithm and returns the result as a DataFrame.
215+
216+
The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset.
217+
218+
Parameters
219+
----------
220+
G : GraphV2
221+
The graph to run the algorithm on
222+
node_properties : str | list[str] | dict[str, str]
223+
The node properties to use for similarity computation.
224+
source_node_filter : str
225+
A Cypher expression to filter which nodes can be sources in the similarity computation.
226+
target_node_filter : str
227+
A Cypher expression to filter which nodes can be targets in the similarity computation.
228+
seed_target_nodes : bool | None, default=None
229+
Whether to use a seeded approach for target node selection.
230+
top_k : int | None, default=None
231+
The number of nearest neighbors to find for each node.
232+
similarity_cutoff : float | None, default=None
233+
The threshold for similarity scores.
234+
delta_threshold : float | None, default=None
235+
The threshold for convergence assessment.
236+
max_iterations : int | None, default=None
237+
The maximum number of iterations to run.
238+
sample_rate : float | None, default=None
239+
The sampling rate for the algorithm.
240+
perturbation_rate : float | None, default=None
241+
The rate at which to perturb the similarity graph.
242+
random_joins : int | None, default=None
243+
The number of random joins to perform.
244+
random_seed : int | None, default=None
245+
The seed for the random number generator.
246+
initial_sampler : Any | None, default=None
247+
The initial sampling strategy.
248+
relationship_types : list[str] | None, default=None
249+
Filter on relationship types.
250+
node_labels : list[str] | None, default=None
251+
Filter on node labels.
252+
sudo : bool | None, default=None
253+
Run the algorithm with elevated privileges.
254+
log_progress : bool, default=True
255+
Whether to log progress.
256+
username : str | None, default=None
257+
Username for the operation.
258+
concurrency : Any | None, default=None
259+
Concurrency configuration.
260+
job_id : Any | None, default=None
261+
Job ID for the operation.
262+
263+
Returns
264+
-------
265+
DataFrame
266+
The similarity results as a DataFrame with columns 'node1', 'node2', and 'similarity'.
267+
"""
104268
...
105269

106270
@abstractmethod
@@ -132,7 +296,69 @@ def write(
132296
concurrency: Any | None = None,
133297
job_id: Any | None = None,
134298
) -> KnnWriteResult:
135-
"""Run filtered K-Nearest Neighbors in write mode."""
299+
"""
300+
Runs the Filtered K-Nearest Neighbors algorithm and writes the results back to the database.
301+
302+
The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset.
303+
304+
Parameters
305+
----------
306+
G : GraphV2
307+
The graph to run the algorithm on
308+
write_relationship_type : str
309+
The relationship type to use for the new relationships.
310+
write_property : str
311+
The relationship property to store the similarity score in.
312+
node_properties : str | list[str] | dict[str, str]
313+
The node properties to use for similarity computation.
314+
source_node_filter : str
315+
A Cypher expression to filter which nodes can be sources in the similarity computation.
316+
target_node_filter : str
317+
A Cypher expression to filter which nodes can be targets in the similarity computation.
318+
seed_target_nodes : bool | None, default=None
319+
Whether to use a seeded approach for target node selection.
320+
top_k : int | None, default=None
321+
The number of nearest neighbors to find for each node.
322+
similarity_cutoff : float | None, default=None
323+
The threshold for similarity scores.
324+
delta_threshold : float | None, default=None
325+
The threshold for convergence assessment.
326+
max_iterations : int | None, default=None
327+
The maximum number of iterations to run.
328+
sample_rate : float | None, default=None
329+
The sampling rate for the algorithm.
330+
perturbation_rate : float | None, default=None
331+
The rate at which to perturb the similarity graph.
332+
random_joins : int | None, default=None
333+
The number of random joins to perform.
334+
random_seed : int | None, default=None
335+
The seed for the random number generator.
336+
initial_sampler : Any | None, default=None
337+
The initial sampling strategy.
338+
relationship_types : list[str] | None, default=None
339+
Filter on relationship types.
340+
node_labels : list[str] | None, default=None
341+
Filter on node labels.
342+
write_concurrency : int | None, default=None
343+
Concurrency for writing results.
344+
write_to_result_store : bool | None, default=None
345+
Whether to write results to the result store.
346+
sudo : bool | None, default=None
347+
Run the algorithm with elevated privileges.
348+
log_progress : bool, default=True
349+
Whether to log progress.
350+
username : str | None, default=None
351+
Username for the operation.
352+
concurrency : Any | None, default=None
353+
Concurrency configuration.
354+
job_id : Any | None, default=None
355+
Job ID for the operation.
356+
357+
Returns
358+
-------
359+
KnnWriteResult
360+
Object containing metadata from the execution.
361+
"""
136362
...
137363

138364
@abstractmethod
@@ -158,58 +384,55 @@ def estimate(
158384
username: str | None = None,
159385
concurrency: Any | None = None,
160386
) -> EstimationResult:
161-
"""Estimate filtered K-Nearest Neighbors execution requirements.
387+
"""
388+
Estimates the memory requirements for running the Filtered K-Nearest Neighbors algorithm.
389+
390+
The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset.
162391
163392
Parameters
164393
----------
165394
G : GraphV2 | dict[str, Any]
166395
The graph to run the algorithm on.
167-
node_properties : str | list[str]
396+
node_properties : str | list[str] | dict[str, str]
168397
The node properties to use for similarity computation.
169-
mutate_property : str
170-
The relationship property to store the similarity score in.
171-
mutate_relationship_type : str
172-
The relationship type to use for the new relationships.
173-
source_node_filter : str | None, default=None
398+
source_node_filter : str
174399
A Cypher expression to filter which nodes can be sources in the similarity computation.
175-
target_node_filter : str | None, default=None
400+
target_node_filter : str
176401
A Cypher expression to filter which nodes can be targets in the similarity computation.
177402
seed_target_nodes : bool | None, default=None
178403
Whether to use a seeded approach for target node selection.
404+
top_k : int | None, default=None
405+
The number of nearest neighbors to find for each node.
179406
similarity_cutoff : float | None, default=None
180407
The threshold for similarity scores.
181-
perturbation_rate : float | None, default=None
182-
The rate at which to perturb the similarity graph.
183408
delta_threshold : float | None, default=None
184409
The threshold for convergence assessment.
410+
max_iterations : int | None, default=None
411+
The maximum number of iterations to run.
185412
sample_rate : float | None, default=None
186413
The sampling rate for the algorithm.
414+
perturbation_rate : float | None, default=None
415+
The rate at which to perturb the similarity graph.
187416
random_joins : int | None, default=None
188417
The number of random joins to perform.
189-
initial_sampler : str | None, default=None
190-
The initial sampling strategy.
191-
max_iterations : int | None, default=None
192-
The maximum number of iterations to run.
193-
top_k : int | None, default=None
194-
The number of nearest neighbors to find for each node.
195418
random_seed : int | None, default=None
196419
The seed for the random number generator.
197-
concurrency : int | None, default=None
198-
Concurrency configuration.
199-
job_id : str | None, default=None
200-
Job ID for the operation.
201-
log_progress : bool | None, default=None
202-
Whether to log progress.
420+
initial_sampler : Any | None, default=None
421+
The initial sampling strategy.
422+
relationship_types : list[str] | None, default=None
423+
Filter on relationship types.
424+
node_labels : list[str] | None, default=None
425+
Filter on node labels.
203426
sudo : bool | None, default=None
204427
Run the algorithm with elevated privileges.
205428
username : str | None, default=None
206429
Username for the operation.
207-
**kwargs : Any
208-
Additional parameters.
430+
concurrency : Any | None, default=None
431+
Concurrency configuration.
209432
210433
Returns
211434
-------
212-
KnnMutateResult
213-
Object containing metadata from the execution.
435+
EstimationResult
436+
Object containing the estimated memory requirements.
214437
"""
215438
...

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ def test_knn_stream(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) ->
7373
top_k=2,
7474
)
7575

76-
# TODO the column names dont match the ones in the cypher endpoint
77-
assert set(result_df.columns) == {"sourceNodeId", "targetNodeId", "relationshipType", "similarity"}
76+
assert set(result_df.columns) == {"node1", "node2", "similarity"}
7877
assert len(result_df) == 2
7978

8079

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_filtered_arrow_endpoints.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,20 @@ def test_stats(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph:
7070
assert result.configuration is not None
7171

7272

73+
@pytest.mark.skip(reason="SEGFAULT for custom metadata. tracked in GDSA-312")
7374
def test_stream_raises_not_implemented(
7475
knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2
7576
) -> None:
76-
with pytest.raises(NotImplementedError, match="Filtered KNN stream endpoint is not available via Arrow"):
77-
knn_filtered_endpoints.stream(
78-
sample_graph,
79-
node_properties="prop",
80-
top_k=2,
81-
source_node_filter="SourceNode",
82-
target_node_filter="TargetNode",
83-
)
77+
result_df = knn_filtered_endpoints.stream(
78+
G=sample_graph,
79+
node_properties=["prop"],
80+
top_k=2,
81+
source_node_filter="SourceNode",
82+
target_node_filter="TargetNode",
83+
)
84+
85+
assert set(result_df.columns) == {"node1", "node2", "similarity"}
86+
assert len(result_df) == 2
8487

8588

8689
def test_mutate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:

0 commit comments

Comments
 (0)