Skip to content

Commit a8421ec

Browse files
DarthMaxFlorentinD
authored and committed
Implement LabelPropagation
1 parent aa132ee commit a8421ec

File tree

5 files changed

+1092
-0
lines changed

5 files changed

+1092
-0
lines changed
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any, List, Optional, Union
5+
6+
from pandas import DataFrame
7+
8+
from graphdatascience.procedure_surface.api.base_result import BaseResult
9+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
10+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
11+
12+
13+
class LabelPropagationEndpoints(ABC):
    """
    Abstract interface for running the Label Propagation algorithm.

    Concrete subclasses must implement all five execution modes declared
    here: ``mutate``, ``stats``, ``stream``, ``write`` and ``estimate``.
    Apart from the graph (and the target property name for ``mutate`` and
    ``write``), every argument is keyword-only.
    """

    @abstractmethod
    def mutate(
        self,
        G: GraphV2,
        mutate_property: str,
        *,
        concurrency: Optional[int] = 4,
        consecutive_ids: Optional[bool] = False,
        job_id: Optional[str] = None,
        log_progress: bool = True,
        max_iterations: Optional[int] = 10,
        node_labels: Optional[List[str]] = None,
        node_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        relationship_weight_property: Optional[str] = None,
        seed_property: Optional[str] = None,
        sudo: Optional[bool] = False,
        username: Optional[str] = None,
    ) -> LabelPropagationMutateResult:
        """
        Executes the Label Propagation algorithm and writes the results to the in-memory graph as node properties.

        Parameters
        ----------
        G : GraphV2
            The graph to run the algorithm on
        mutate_property : str
            The property name to store the community ID for each node
        concurrency : Optional[int], default=4
            The number of concurrent threads
        consecutive_ids : Optional[bool], default=False
            Whether to use consecutive community IDs starting from 0
        job_id : Optional[str], default=None
            An identifier for the job
        log_progress : bool, default=True
            Whether to log progress
        max_iterations : Optional[int], default=10
            The maximum number of iterations
        node_labels : Optional[List[str]], default=None
            The node labels used to select nodes for this algorithm run
        node_weight_property : Optional[str], default=None
            The property name for node weights
        relationship_types : Optional[List[str]], default=None
            The relationship types used to select relationships for this algorithm run
        relationship_weight_property : Optional[str], default=None
            The property name for relationship weights
        seed_property : Optional[str], default=None
            The property name containing seed values for initial community assignment
        sudo : Optional[bool], default=False
            Override memory estimation limits
        username : Optional[str], default=None
            The username to attribute the procedure run to

        Returns
        -------
        LabelPropagationMutateResult
            Algorithm metrics and statistics
        """

    @abstractmethod
    def stats(
        self,
        G: GraphV2,
        *,
        concurrency: Optional[int] = 4,
        consecutive_ids: Optional[bool] = False,
        job_id: Optional[str] = None,
        log_progress: bool = True,
        max_iterations: Optional[int] = 10,
        node_labels: Optional[List[str]] = None,
        node_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        relationship_weight_property: Optional[str] = None,
        seed_property: Optional[str] = None,
        sudo: Optional[bool] = False,
        username: Optional[str] = None,
    ) -> LabelPropagationStatsResult:
        """
        Executes the Label Propagation algorithm and returns statistics.

        Parameters
        ----------
        G : GraphV2
            The graph to run the algorithm on
        concurrency : Optional[int], default=4
            The number of concurrent threads
        consecutive_ids : Optional[bool], default=False
            Whether to use consecutive community IDs starting from 0
        job_id : Optional[str], default=None
            An identifier for the job
        log_progress : bool, default=True
            Whether to log progress
        max_iterations : Optional[int], default=10
            The maximum number of iterations
        node_labels : Optional[List[str]], default=None
            The node labels used to select nodes for this algorithm run
        node_weight_property : Optional[str], default=None
            The property name for node weights
        relationship_types : Optional[List[str]], default=None
            The relationship types used to select relationships for this algorithm run
        relationship_weight_property : Optional[str], default=None
            The property name for relationship weights
        seed_property : Optional[str], default=None
            The property name containing seed values for initial community assignment
        sudo : Optional[bool], default=False
            Override memory estimation limits
        username : Optional[str], default=None
            The username to attribute the procedure run to

        Returns
        -------
        LabelPropagationStatsResult
            Algorithm metrics and statistics
        """

    @abstractmethod
    def stream(
        self,
        G: GraphV2,
        *,
        concurrency: Optional[int] = 4,
        consecutive_ids: Optional[bool] = False,
        job_id: Optional[str] = None,
        log_progress: bool = True,
        max_iterations: Optional[int] = 10,
        min_community_size: Optional[int] = None,
        node_labels: Optional[List[str]] = None,
        node_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        relationship_weight_property: Optional[str] = None,
        seed_property: Optional[str] = None,
        sudo: Optional[bool] = False,
        username: Optional[str] = None,
    ) -> DataFrame:
        """
        Executes the Label Propagation algorithm and returns a stream of results.

        Parameters
        ----------
        G : GraphV2
            The graph to run the algorithm on
        concurrency : Optional[int], default=4
            The number of concurrent threads
        consecutive_ids : Optional[bool], default=False
            Whether to use consecutive community IDs starting from 0
        job_id : Optional[str], default=None
            An identifier for the job
        log_progress : bool, default=True
            Whether to log progress
        max_iterations : Optional[int], default=10
            The maximum number of iterations
        min_community_size : Optional[int], default=None
            Minimum community size to include in results
        node_labels : Optional[List[str]], default=None
            The node labels used to select nodes for this algorithm run
        node_weight_property : Optional[str], default=None
            The property name for node weights
        relationship_types : Optional[List[str]], default=None
            The relationship types used to select relationships for this algorithm run
        relationship_weight_property : Optional[str], default=None
            The property name for relationship weights
        seed_property : Optional[str], default=None
            The property name containing seed values for initial community assignment
        sudo : Optional[bool], default=False
            Override memory estimation limits
        username : Optional[str], default=None
            The username to attribute the procedure run to

        Returns
        -------
        DataFrame
            DataFrame with the algorithm results containing nodeId and communityId
        """

    @abstractmethod
    def write(
        self,
        G: GraphV2,
        write_property: str,
        *,
        concurrency: Optional[int] = 4,
        consecutive_ids: Optional[bool] = False,
        job_id: Optional[str] = None,
        log_progress: bool = True,
        max_iterations: Optional[int] = 10,
        min_community_size: Optional[int] = None,
        node_labels: Optional[List[str]] = None,
        node_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        relationship_weight_property: Optional[str] = None,
        seed_property: Optional[str] = None,
        sudo: Optional[bool] = False,
        username: Optional[str] = None,
        write_concurrency: Optional[int] = None,
        write_to_result_store: Optional[bool] = False,
    ) -> LabelPropagationWriteResult:
        """
        Executes the Label Propagation algorithm and writes the results back to the database.

        Parameters
        ----------
        G : GraphV2
            The graph to run the algorithm on
        write_property : str
            The property name to write the community IDs to
        concurrency : Optional[int], default=4
            The number of concurrent threads
        consecutive_ids : Optional[bool], default=False
            Whether to use consecutive community IDs starting from 0
        job_id : Optional[str], default=None
            An identifier for the job
        log_progress : bool, default=True
            Whether to log progress
        max_iterations : Optional[int], default=10
            The maximum number of iterations
        min_community_size : Optional[int], default=None
            Minimum community size to include in results
        node_labels : Optional[List[str]], default=None
            The node labels used to select nodes for this algorithm run
        node_weight_property : Optional[str], default=None
            The property name for node weights
        relationship_types : Optional[List[str]], default=None
            The relationship types used to select relationships for this algorithm run
        relationship_weight_property : Optional[str], default=None
            The property name for relationship weights
        seed_property : Optional[str], default=None
            The property name containing seed values for initial community assignment
        sudo : Optional[bool], default=False
            Override memory estimation limits
        username : Optional[str], default=None
            The username to attribute the procedure run to
        write_concurrency : Optional[int], default=None
            The number of concurrent threads for write operations
        write_to_result_store : Optional[bool], default=False
            Whether to write to the result store

        Returns
        -------
        LabelPropagationWriteResult
            Algorithm metrics and statistics
        """

    @abstractmethod
    def estimate(
        self,
        G: Union[GraphV2, dict[str, Any]],
        *,
        concurrency: Optional[int] = 4,
        consecutive_ids: Optional[bool] = False,
        max_iterations: Optional[int] = 10,
        node_labels: Optional[List[str]] = None,
        node_weight_property: Optional[str] = None,
        relationship_types: Optional[List[str]] = None,
        relationship_weight_property: Optional[str] = None,
        seed_property: Optional[str] = None,
    ) -> EstimationResult:
        """
        Estimates the memory requirements for running the Label Propagation algorithm.

        Parameters
        ----------
        G : Union[GraphV2, dict[str, Any]]
            The graph or graph configuration to estimate for
        concurrency : Optional[int], default=4
            The number of concurrent threads
        consecutive_ids : Optional[bool], default=False
            Whether to use consecutive community IDs starting from 0
        max_iterations : Optional[int], default=10
            The maximum number of iterations
        node_labels : Optional[List[str]], default=None
            The node labels used to select nodes for this algorithm run
        node_weight_property : Optional[str], default=None
            The property name for node weights
        relationship_types : Optional[List[str]], default=None
            The relationship types used to select relationships for this algorithm run
        relationship_weight_property : Optional[str], default=None
            The property name for relationship weights
        seed_property : Optional[str], default=None
            The property name containing seed values for initial community assignment

        Returns
        -------
        EstimationResult
            The memory estimation result
        """
304+
305+
306+
class LabelPropagationMutateResult(BaseResult):
    # Return type of LabelPropagationEndpoints.mutate.
    # NOTE(review): field meanings below are inferred from the field names
    # only; confirm against the procedure documentation.

    # Community summary (presumably count and size distribution).
    community_count: int
    community_distribution: dict[str, Any]

    # Presumably the time spent in the compute phase, in milliseconds.
    compute_millis: int

    # Presumably the configuration the run was executed with.
    configuration: dict[str, Any]

    did_converge: bool

    # Mutate-specific fields.
    mutate_millis: int
    node_properties_written: int

    post_processing_millis: int
    pre_processing_millis: int
    ran_iterations: int
317+
318+
319+
class LabelPropagationStatsResult(BaseResult):
    # Return type of LabelPropagationEndpoints.stats.
    # NOTE(review): field meanings below are inferred from the field names
    # only; confirm against the procedure documentation.

    # Community summary (presumably count and size distribution).
    community_count: int
    community_distribution: dict[str, Any]

    # Presumably the time spent in the compute phase, in milliseconds.
    compute_millis: int

    # Presumably the configuration the run was executed with.
    configuration: dict[str, Any]

    did_converge: bool
    post_processing_millis: int
    pre_processing_millis: int
    ran_iterations: int
328+
329+
330+
class LabelPropagationWriteResult(BaseResult):
    # Return type of LabelPropagationEndpoints.write.
    # NOTE(review): field meanings below are inferred from the field names
    # only; confirm against the procedure documentation.

    # Community summary (presumably count and size distribution).
    community_count: int
    community_distribution: dict[str, Any]

    # Presumably the time spent in the compute phase, in milliseconds.
    compute_millis: int

    # Presumably the configuration the run was executed with.
    configuration: dict[str, Any]

    did_converge: bool
    node_properties_written: int
    post_processing_millis: int
    pre_processing_millis: int
    ran_iterations: int

    # Write-specific timing field.
    write_millis: int

0 commit comments

Comments
 (0)