Skip to content

Commit 186b480

Browse files
committed
Move filtered knn behind knn endpoints
1 parent 048497c commit 186b480

File tree

11 files changed

+86
-69
lines changed

11 files changed

+86
-69
lines changed

graphdatascience/procedure_surface/api/similarity/knn_endpoints.py

Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,21 @@
55

66
from pandas import DataFrame
77

8-
from graphdatascience.procedure_surface.api.base_result import BaseResult
98
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
109
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
10+
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
11+
from graphdatascience.procedure_surface.api.similarity.knn_results import (
12+
KnnMutateResult,
13+
KnnStatsResult,
14+
KnnWriteResult,
15+
)
1116

1217

1318
class KnnEndpoints(ABC):
19+
@abstractmethod
20+
def filtered(self) -> KnnFilteredEndpoints:
21+
pass
22+
1423
@abstractmethod
1524
def mutate(
1625
self,
@@ -374,44 +383,3 @@ def estimate(
374383
EstimationResult
375384
Object containing the estimated memory requirements.
376385
"""
377-
378-
379-
class KnnMutateResult(BaseResult):
380-
pre_processing_millis: int
381-
compute_millis: int
382-
mutate_millis: int
383-
post_processing_millis: int
384-
nodes_compared: int
385-
relationships_written: int
386-
similarity_distribution: dict[str, int | float]
387-
did_converge: bool
388-
ran_iterations: int
389-
node_pairs_considered: int
390-
configuration: dict[str, Any]
391-
392-
393-
class KnnStatsResult(BaseResult):
394-
pre_processing_millis: int
395-
compute_millis: int
396-
post_processing_millis: int
397-
nodes_compared: int
398-
similarity_pairs: int
399-
similarity_distribution: dict[str, int | float]
400-
did_converge: bool
401-
ran_iterations: int
402-
node_pairs_considered: int
403-
configuration: dict[str, Any]
404-
405-
406-
class KnnWriteResult(BaseResult):
407-
pre_processing_millis: int
408-
compute_millis: int
409-
write_millis: int
410-
post_processing_millis: int
411-
nodes_compared: int
412-
relationships_written: int
413-
did_converge: bool
414-
ran_iterations: int
415-
node_pairs_considered: int
416-
similarity_distribution: dict[str, int | float]
417-
configuration: dict[str, Any]

graphdatascience/procedure_surface/api/similarity/knn_filtered_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
99
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
10-
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import (
10+
from graphdatascience.procedure_surface.api.similarity.knn_results import (
1111
KnnMutateResult,
1212
KnnStatsResult,
1313
KnnWriteResult,
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any
2+
3+
from graphdatascience.procedure_surface.api.base_result import BaseResult
4+
5+
6+
class KnnMutateResult(BaseResult):
7+
pre_processing_millis: int
8+
compute_millis: int
9+
mutate_millis: int
10+
post_processing_millis: int
11+
nodes_compared: int
12+
relationships_written: int
13+
similarity_distribution: dict[str, int | float]
14+
did_converge: bool
15+
ran_iterations: int
16+
node_pairs_considered: int
17+
configuration: dict[str, Any]
18+
19+
20+
class KnnStatsResult(BaseResult):
21+
pre_processing_millis: int
22+
compute_millis: int
23+
post_processing_millis: int
24+
nodes_compared: int
25+
similarity_pairs: int
26+
similarity_distribution: dict[str, int | float]
27+
did_converge: bool
28+
ran_iterations: int
29+
node_pairs_considered: int
30+
configuration: dict[str, Any]
31+
32+
33+
class KnnWriteResult(BaseResult):
34+
pre_processing_millis: int
35+
compute_millis: int
36+
write_millis: int
37+
post_processing_millis: int
38+
nodes_compared: int
39+
relationships_written: int
40+
did_converge: bool
41+
ran_iterations: int
42+
node_pairs_considered: int
43+
similarity_distribution: dict[str, int | float]
44+
configuration: dict[str, Any]

graphdatascience/procedure_surface/arrow/relationship_endpoints_helper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from graphdatascience.procedure_surface.arrow.endpoints_helper_base import EndpointsHelperBase
55

66

7-
# TODO find common parts with node_property_endpoints and refactor into a base class
87
class RelationshipEndpointsHelper(EndpointsHelperBase):
98
"""
109
Helper class for Arrow algorithm endpoints that work with relationships.

graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
99
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
1010
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
11-
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import (
12-
KnnEndpoints,
11+
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints
12+
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
13+
from graphdatascience.procedure_surface.api.similarity.knn_results import (
1314
KnnMutateResult,
1415
KnnStatsResult,
1516
KnnWriteResult,
1617
)
1718
from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
19+
from graphdatascience.procedure_surface.arrow.similarity.knn_filtered_arrow_endpoints import KnnFilteredArrowEndpoints
1820
from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result
1921

2022

@@ -29,6 +31,14 @@ def __init__(
2931
arrow_client, write_back_client=write_back_client, show_progress=show_progress
3032
)
3133

34+
@property
35+
def filtered(self) -> KnnFilteredEndpoints:
36+
return KnnFilteredArrowEndpoints(
37+
self._endpoints_helper._arrow_client,
38+
self._endpoints_helper._write_back_client,
39+
self._endpoints_helper._show_progress,
40+
)
41+
3242
def mutate(
3343
self,
3444
G: GraphV2,

graphdatascience/procedure_surface/arrow/similarity/knn_filtered_arrow_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
77
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
88
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
9-
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import (
9+
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
10+
from graphdatascience.procedure_surface.api.similarity.knn_results import (
1011
KnnMutateResult,
1112
KnnStatsResult,
1213
KnnWriteResult,
1314
)
14-
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
1515
from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
1616
from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result
1717

graphdatascience/procedure_surface/cypher/similarity/knn_cypher_endpoints.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
from graphdatascience.call_parameters import CallParameters
88
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
99
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
10-
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import (
11-
KnnEndpoints,
10+
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints
11+
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
12+
from graphdatascience.procedure_surface.api.similarity.knn_results import (
1213
KnnMutateResult,
1314
KnnStatsResult,
1415
KnnWriteResult,
1516
)
1617
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
18+
from graphdatascience.procedure_surface.cypher.similarity.knn_filtered_cypher_endpoints import (
19+
KnnFilteredCypherEndpoints,
20+
)
1721
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1822
from graphdatascience.query_runner.query_runner import QueryRunner
1923

@@ -22,6 +26,10 @@ class KnnCypherEndpoints(KnnEndpoints):
2226
def __init__(self, query_runner: QueryRunner):
2327
self._query_runner = query_runner
2428

29+
@property
30+
def filtered(self) -> KnnFilteredEndpoints:
31+
return KnnFilteredCypherEndpoints(self._query_runner)
32+
2533
def mutate(
2634
self,
2735
G: GraphV2,

graphdatascience/procedure_surface/cypher/similarity/knn_filtered_cypher_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from graphdatascience.call_parameters import CallParameters
66
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
77
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
8-
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import (
8+
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
9+
from graphdatascience.procedure_surface.api.similarity.knn_results import (
910
KnnMutateResult,
1011
KnnStatsResult,
1112
KnnWriteResult,
1213
)
13-
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
1414
from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm
1515
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1616
from graphdatascience.query_runner.query_runner import QueryRunner

graphdatascience/session/session_v2_endpoints.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from graphdatascience.procedure_surface.api.community.triangle_count_endpoints import TriangleCountEndpoints
1717
from graphdatascience.procedure_surface.api.node_embedding.graphsage_endpoints import GraphSageEndpoints
1818
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints
19-
from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
2019
from graphdatascience.procedure_surface.arrow.catalog_arrow_endpoints import CatalogArrowEndpoints
2120
from graphdatascience.procedure_surface.arrow.centrality.articlerank_arrow_endpoints import ArticleRankArrowEndpoints
2221
from graphdatascience.procedure_surface.arrow.centrality.articulationpoints_arrow_endpoints import (
@@ -173,10 +172,6 @@ def kmeans(self) -> KMeansEndpoints:
173172
def knn(self) -> KnnEndpoints:
174173
return KnnArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress)
175174

176-
@property
177-
def knn_filtered(self) -> KnnFilteredEndpoints:
178-
return KnnFilteredArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress)
179-
180175
@property
181176
def label_propagation(self) -> LabelPropagationEndpoints:
182177
return LabelPropagationArrowEndpoints(

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
77
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
88
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
9-
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnWriteResult
109
from graphdatascience.procedure_surface.arrow.similarity.knn_arrow_endpoints import KnnArrowEndpoints
1110
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import (
1211
create_graph,
@@ -107,7 +106,6 @@ def test_knn_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRu
107106
G=db_graph, write_relationship_type="SIMILAR", write_property="similarity", node_properties=["prop"], top_k=2
108107
)
109108

110-
assert isinstance(result, KnnWriteResult)
111109
assert result.ran_iterations > 0
112110
assert result.did_converge
113111
assert result.pre_processing_millis >= 0

0 commit comments

Comments
 (0)