Skip to content

Commit aa132ee

Browse files
DarthMaxFlorentinD
authored andcommitted
Implement KMeans
1 parent 0c0b17e commit aa132ee

File tree

6 files changed

+1153
-4
lines changed

6 files changed

+1153
-4
lines changed
Lines changed: 394 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,394 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any, List, Optional, Union
5+
6+
from pandas import DataFrame
7+
8+
from graphdatascience.procedure_surface.api.base_result import BaseResult
9+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
10+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
11+
12+
13+
class KMeansEndpoints(ABC):
14+
@abstractmethod
15+
def mutate(
16+
self,
17+
G: GraphV2,
18+
node_property: str,
19+
mutate_property: str,
20+
*,
21+
compute_silhouette: Optional[bool] = False,
22+
concurrency: Optional[int] = 4,
23+
delta_threshold: Optional[float] = 0.05,
24+
initial_sampler: Optional[str] = "UNIFORM",
25+
job_id: Optional[str] = None,
26+
k: Optional[int] = 10,
27+
log_progress: bool = True,
28+
max_iterations: Optional[int] = 10,
29+
node_labels: Optional[List[str]] = None,
30+
number_of_restarts: Optional[int] = 1,
31+
random_seed: Optional[int] = None,
32+
relationship_types: Optional[List[str]] = None,
33+
seed_centroids: Optional[List[List[float]]] = None,
34+
sudo: Optional[bool] = False,
35+
username: Optional[str] = None,
36+
) -> KMeansMutateResult:
37+
"""
38+
Executes the K-Means algorithm and writes the results to the in-memory graph as node properties.
39+
40+
Parameters
41+
----------
42+
G : GraphV2
43+
The graph to run the algorithm on
44+
node_property : str
45+
The node property to use for clustering
46+
mutate_property : str
47+
The property name to store the community ID for each node
48+
compute_silhouette : Optional[bool], default=False
49+
Whether to compute silhouette coefficient
50+
concurrency : Optional[int], default=4
51+
The number of concurrent threads
52+
delta_threshold : Optional[float], default=0.05
53+
The convergence threshold for the algorithm
54+
initial_sampler : Optional[str], default="UNIFORM"
55+
The sampling method for initial centroids
56+
job_id : Optional[str], default=None
57+
An identifier for the job
58+
k : Optional[int], default=10
59+
The number of clusters
60+
log_progress : bool, default=True
61+
Whether to log progress
62+
max_iterations : Optional[int], default=10
63+
The maximum number of iterations
64+
node_labels : Optional[List[str]], default=None
65+
The node labels used to select nodes for this algorithm run
66+
number_of_restarts : Optional[int], default=1
67+
The number of times the algorithm should be restarted
68+
random_seed : Optional[int], default=None
69+
Random seed for reproducible results
70+
relationship_types : Optional[List[str]], default=None
71+
The relationship types used to select relationships for this algorithm run
72+
seed_centroids : Optional[List[List[float]]], default=None
73+
Initial centroids for the algorithm
74+
sudo : Optional[bool], default=False
75+
Override memory estimation limits
76+
username : Optional[str], default=None
77+
The username to attribute the procedure run to
78+
79+
Returns
80+
-------
81+
KMeansMutateResult
82+
Algorithm metrics and statistics
83+
"""
84+
pass
85+
86+
@abstractmethod
87+
def stats(
88+
self,
89+
G: GraphV2,
90+
node_property: str,
91+
*,
92+
compute_silhouette: Optional[bool] = False,
93+
concurrency: Optional[int] = 4,
94+
delta_threshold: Optional[float] = 0.05,
95+
initial_sampler: Optional[str] = "UNIFORM",
96+
job_id: Optional[str] = None,
97+
k: Optional[int] = 10,
98+
log_progress: bool = True,
99+
max_iterations: Optional[int] = 10,
100+
node_labels: Optional[List[str]] = None,
101+
number_of_restarts: Optional[int] = 1,
102+
random_seed: Optional[int] = None,
103+
relationship_types: Optional[List[str]] = None,
104+
seed_centroids: Optional[List[List[float]]] = None,
105+
sudo: Optional[bool] = False,
106+
username: Optional[str] = None,
107+
) -> KMeansStatsResult:
108+
"""
109+
Executes the K-Means algorithm and returns statistics.
110+
111+
Parameters
112+
----------
113+
G : GraphV2
114+
The graph to run the algorithm on
115+
node_property : str
116+
The node property to use for clustering
117+
compute_silhouette : Optional[bool], default=False
118+
Whether to compute silhouette coefficient
119+
concurrency : Optional[int], default=4
120+
The number of concurrent threads
121+
delta_threshold : Optional[float], default=0.05
122+
The convergence threshold for the algorithm
123+
initial_sampler : Optional[str], default="UNIFORM"
124+
The sampling method for initial centroids
125+
job_id : Optional[str], default=None
126+
An identifier for the job
127+
k : Optional[int], default=10
128+
The number of clusters
129+
log_progress : bool, default=True
130+
Whether to log progress
131+
max_iterations : Optional[int], default=10
132+
The maximum number of iterations
133+
node_labels : Optional[List[str]], default=None
134+
The node labels used to select nodes for this algorithm run
135+
number_of_restarts : Optional[int], default=1
136+
The number of times the algorithm should be restarted
137+
random_seed : Optional[int], default=None
138+
Random seed for reproducible results
139+
relationship_types : Optional[List[str]], default=None
140+
The relationship types used to select relationships for this algorithm run
141+
seed_centroids : Optional[List[List[float]]], default=None
142+
Initial centroids for the algorithm
143+
sudo : Optional[bool], default=False
144+
Override memory estimation limits
145+
username : Optional[str], default=None
146+
The username to attribute the procedure run to
147+
148+
Returns
149+
-------
150+
KMeansStatsResult
151+
Algorithm metrics and statistics
152+
"""
153+
pass
154+
155+
@abstractmethod
156+
def stream(
157+
self,
158+
G: GraphV2,
159+
node_property: str,
160+
*,
161+
compute_silhouette: Optional[bool] = False,
162+
concurrency: Optional[int] = 4,
163+
delta_threshold: Optional[float] = 0.05,
164+
initial_sampler: Optional[str] = "UNIFORM",
165+
job_id: Optional[str] = None,
166+
k: Optional[int] = 10,
167+
log_progress: bool = True,
168+
max_iterations: Optional[int] = 10,
169+
node_labels: Optional[List[str]] = None,
170+
number_of_restarts: Optional[int] = 1,
171+
random_seed: Optional[int] = None,
172+
relationship_types: Optional[List[str]] = None,
173+
seed_centroids: Optional[List[List[float]]] = None,
174+
sudo: Optional[bool] = False,
175+
username: Optional[str] = None,
176+
) -> DataFrame:
177+
"""
178+
Executes the K-Means algorithm and returns a stream of results.
179+
180+
Parameters
181+
----------
182+
G : GraphV2
183+
The graph to run the algorithm on
184+
node_property : str
185+
The node property to use for clustering
186+
compute_silhouette : Optional[bool], default=False
187+
Whether to compute silhouette coefficient
188+
concurrency : Optional[int], default=4
189+
The number of concurrent threads
190+
delta_threshold : Optional[float], default=0.05
191+
The convergence threshold for the algorithm
192+
initial_sampler : Optional[str], default="UNIFORM"
193+
The sampling method for initial centroids
194+
job_id : Optional[str], default=None
195+
An identifier for the job
196+
k : Optional[int], default=10
197+
The number of clusters
198+
log_progress : bool, default=True
199+
Whether to log progress
200+
max_iterations : Optional[int], default=10
201+
The maximum number of iterations
202+
node_labels : Optional[List[str]], default=None
203+
The node labels used to select nodes for this algorithm run
204+
number_of_restarts : Optional[int], default=1
205+
The number of times the algorithm should be restarted
206+
random_seed : Optional[int], default=None
207+
Random seed for reproducible results
208+
relationship_types : Optional[List[str]], default=None
209+
The relationship types used to select relationships for this algorithm run
210+
seed_centroids : Optional[List[List[float]]], default=None
211+
Initial centroids for the algorithm
212+
sudo : Optional[bool], default=False
213+
Override memory estimation limits
214+
username : Optional[str], default=None
215+
The username to attribute the procedure run to
216+
217+
Returns
218+
-------
219+
DataFrame
220+
DataFrame with the algorithm results containing nodeId, communityId, distanceFromCentroid, and silhouette
221+
"""
222+
pass
223+
224+
@abstractmethod
225+
def write(
226+
self,
227+
G: GraphV2,
228+
node_property: str,
229+
write_property: str,
230+
*,
231+
compute_silhouette: Optional[bool] = False,
232+
concurrency: Optional[int] = 4,
233+
delta_threshold: Optional[float] = 0.05,
234+
initial_sampler: Optional[str] = "UNIFORM",
235+
job_id: Optional[str] = None,
236+
k: Optional[int] = 10,
237+
log_progress: bool = True,
238+
max_iterations: Optional[int] = 10,
239+
node_labels: Optional[List[str]] = None,
240+
number_of_restarts: Optional[int] = 1,
241+
random_seed: Optional[int] = None,
242+
relationship_types: Optional[List[str]] = None,
243+
seed_centroids: Optional[List[List[float]]] = None,
244+
sudo: Optional[bool] = False,
245+
username: Optional[str] = None,
246+
write_concurrency: Optional[int] = None,
247+
write_to_result_store: Optional[bool] = False,
248+
) -> KMeansWriteResult:
249+
"""
250+
Executes the K-Means algorithm and writes the results back to the database.
251+
252+
Parameters
253+
----------
254+
G : GraphV2
255+
The graph to run the algorithm on
256+
node_property : str
257+
The node property to use for clustering
258+
write_property : str
259+
The property name to write the community IDs to
260+
compute_silhouette : Optional[bool], default=False
261+
Whether to compute silhouette coefficient
262+
concurrency : Optional[int], default=4
263+
The number of concurrent threads
264+
delta_threshold : Optional[float], default=0.05
265+
The convergence threshold for the algorithm
266+
initial_sampler : Optional[str], default="UNIFORM"
267+
The sampling method for initial centroids
268+
job_id : Optional[str], default=None
269+
An identifier for the job
270+
k : Optional[int], default=10
271+
The number of clusters
272+
log_progress : bool, default=True
273+
Whether to log progress
274+
max_iterations : Optional[int], default=10
275+
The maximum number of iterations
276+
node_labels : Optional[List[str]], default=None
277+
The node labels used to select nodes for this algorithm run
278+
number_of_restarts : Optional[int], default=1
279+
The number of times the algorithm should be restarted
280+
random_seed : Optional[int], default=None
281+
Random seed for reproducible results
282+
relationship_types : Optional[List[str]], default=None
283+
The relationship types used to select relationships for this algorithm run
284+
seed_centroids : Optional[List[List[float]]], default=None
285+
Initial centroids for the algorithm
286+
sudo : Optional[bool], default=False
287+
Override memory estimation limits
288+
username : Optional[str], default=None
289+
The username to attribute the procedure run to
290+
write_concurrency : Optional[int], default=None
291+
The number of concurrent threads for write operations
292+
write_to_result_store : Optional[bool], default=False
293+
Whether to write to the result store
294+
295+
Returns
296+
-------
297+
KMeansWriteResult
298+
Algorithm metrics and statistics
299+
"""
300+
pass
301+
302+
@abstractmethod
303+
def estimate(
304+
self,
305+
G: Union[GraphV2, dict[str, Any]],
306+
node_property: str,
307+
*,
308+
compute_silhouette: Optional[bool] = False,
309+
concurrency: Optional[int] = 4,
310+
delta_threshold: Optional[float] = 0.05,
311+
initial_sampler: Optional[str] = "UNIFORM",
312+
k: Optional[int] = 10,
313+
max_iterations: Optional[int] = 10,
314+
node_labels: Optional[List[str]] = None,
315+
number_of_restarts: Optional[int] = 1,
316+
random_seed: Optional[int] = None,
317+
relationship_types: Optional[List[str]] = None,
318+
seed_centroids: Optional[List[List[float]]] = None,
319+
) -> EstimationResult:
320+
"""
321+
Estimates the memory requirements for running the K-Means algorithm.
322+
323+
Parameters
324+
----------
325+
G : Union[GraphV2, dict[str, Any]]
326+
The graph or graph configuration to estimate for
327+
node_property : str
328+
The node property to use for clustering
329+
compute_silhouette : Optional[bool], default=False
330+
Whether to compute silhouette coefficient
331+
concurrency : Optional[int], default=4
332+
The number of concurrent threads
333+
delta_threshold : Optional[float], default=0.05
334+
The convergence threshold for the algorithm
335+
initial_sampler : Optional[str], default="UNIFORM"
336+
The sampling method for initial centroids
337+
k : Optional[int], default=10
338+
The number of clusters
339+
max_iterations : Optional[int], default=10
340+
The maximum number of iterations
341+
node_labels : Optional[List[str]], default=None
342+
The node labels used to select nodes for this algorithm run
343+
number_of_restarts : Optional[int], default=1
344+
The number of times the algorithm should be restarted
345+
random_seed : Optional[int], default=None
346+
Random seed for reproducible results
347+
relationship_types : Optional[List[str]], default=None
348+
The relationship types used to select relationships for this algorithm run
349+
seed_centroids : Optional[List[List[float]]], default=None
350+
Initial centroids for the algorithm
351+
352+
Returns
353+
-------
354+
EstimationResult
355+
The memory estimation result
356+
"""
357+
pass
358+
359+
360+
class KMeansMutateResult(BaseResult):
361+
average_distance_to_centroid: float
362+
average_silhouette: float
363+
centroids: List[Any]
364+
community_distribution: dict[str, Any]
365+
compute_millis: int
366+
configuration: dict[str, Any]
367+
mutate_millis: int
368+
node_properties_written: int
369+
post_processing_millis: int
370+
pre_processing_millis: int
371+
372+
373+
class KMeansStatsResult(BaseResult):
374+
average_distance_to_centroid: float
375+
average_silhouette: float
376+
centroids: List[Any]
377+
community_distribution: dict[str, Any]
378+
compute_millis: int
379+
configuration: dict[str, Any]
380+
post_processing_millis: int
381+
pre_processing_millis: int
382+
383+
384+
class KMeansWriteResult(BaseResult):
385+
average_distance_to_centroid: float
386+
average_silhouette: float
387+
centroids: List[Any]
388+
community_distribution: dict[str, Any]
389+
compute_millis: int
390+
configuration: dict[str, Any]
391+
node_properties_written: int
392+
post_processing_millis: int
393+
pre_processing_millis: int
394+
write_millis: int

0 commit comments

Comments
 (0)