Skip to content

Commit ffdbddf

Browse files
committed
Share logic + move result classes
1 parent a99ee6e commit ffdbddf

File tree

2 files changed

+65
-64
lines changed

2 files changed

+65
-64
lines changed

graphdatascience/arrow_client/v2/mutation_client.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,41 @@ class MutationClient:
1111

1212
@staticmethod
1313
def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult:
14-
mutate_config = {"jobId": job_id, "mutateProperty": mutate_property}
15-
start_time = time.time()
16-
mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, mutate_config)
17-
mutate_millis = math.ceil((time.time() - start_time) * 1000)
18-
return MutateResult(mutateMillis=mutate_millis, **deserialize_single(mutate_arrow_res))
14+
return MutationClient._mutate(
15+
client=client,
16+
job_id=job_id,
17+
mutate_property=mutate_property,
18+
)
1919

2020
@staticmethod
2121
def mutate_relationship_property(
2222
client: AuthenticatedArrowClient,
2323
job_id: str,
2424
mutate_relationship_type: str,
2525
mutate_property: str,
26+
) -> MutateResult:
27+
return MutationClient._mutate(
28+
client=client,
29+
job_id=job_id,
30+
mutate_property=mutate_property,
31+
mutate_relationship_type=mutate_relationship_type,
32+
)
33+
34+
@staticmethod
35+
def _mutate(
36+
client: AuthenticatedArrowClient,
37+
job_id: str,
38+
mutate_property: str | None = None,
39+
mutate_relationship_type: str | None = None,
2640
) -> MutateResult:
2741
mutate_config = {
2842
"jobId": job_id,
29-
"mutateProperty": mutate_property,
30-
"mutateRelationshipType": mutate_relationship_type,
3143
}
44+
if mutate_relationship_type:
45+
mutate_config["mutateRelationshipType"] = mutate_relationship_type
46+
if mutate_property:
47+
mutate_config["mutateProperty"] = mutate_property
48+
3249
start_time = time.time()
3350
mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, mutate_config)
3451
mutate_millis = math.ceil((time.time() - start_time) * 1000)

graphdatascience/procedure_surface/api/similarity/knn_endpoints.py

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -10,53 +10,6 @@
1010
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
1111

1212

13-
class KnnMutateResult(BaseResult):
14-
"""Represents the result of running K-Nearest Neighbors in mutate mode."""
15-
16-
pre_processing_millis: int
17-
compute_millis: int
18-
mutate_millis: int
19-
post_processing_millis: int
20-
nodes_compared: int
21-
relationships_written: int
22-
similarity_distribution: dict[str, Any]
23-
did_converge: bool
24-
ran_iterations: int
25-
node_pairs_considered: int
26-
configuration: dict[str, Any]
27-
28-
29-
class KnnStatsResult(BaseResult):
30-
"""Represents the result of running K-Nearest Neighbors in stats mode."""
31-
32-
pre_processing_millis: int
33-
compute_millis: int
34-
post_processing_millis: int
35-
nodes_compared: int
36-
similarity_pairs: int
37-
similarity_distribution: dict[str, Any]
38-
did_converge: bool
39-
ran_iterations: int
40-
node_pairs_considered: int
41-
configuration: dict[str, Any]
42-
43-
44-
class KnnWriteResult(BaseResult):
45-
"""Represents the result of running K-Nearest Neighbors in write mode."""
46-
47-
pre_processing_millis: int
48-
compute_millis: int
49-
write_millis: int
50-
post_processing_millis: int
51-
nodes_compared: int
52-
relationships_written: int
53-
did_converge: bool
54-
ran_iterations: int
55-
node_pairs_considered: int
56-
similarity_distribution: dict[str, Any]
57-
configuration: dict[str, Any]
58-
59-
6013
class KnnEndpoints(ABC):
6114
@abstractmethod
6215
def mutate(
@@ -85,8 +38,6 @@ def mutate(
8538
"""
8639
Runs the K-Nearest Neighbors algorithm and stores the results as new relationships in the graph catalog.
8740
88-
The K-Nearest Neighbors algorithm computes a distance value for all node pairs in the graph and creates new relationships between each node and its k nearest neighbors
89-
9041
Parameters
9142
----------
9243
G : GraphV2
@@ -161,8 +112,6 @@ def stats(
161112
"""
162113
Runs the K-Nearest Neighbors algorithm and returns execution statistics.
163114
164-
The K-Nearest Neighbors algorithm computes a distance value for all node pairs in the graph and creates new relationships between each node and its k nearest neighbors
165-
166115
Parameters
167116
----------
168117
G : GraphV2
@@ -233,8 +182,6 @@ def stream(
233182
"""
234183
Runs the K-Nearest Neighbors algorithm and returns the result as a DataFrame.
235184
236-
The K-Nearest Neighbors algorithm computes a distance value for all node pairs in the graph and creates new relationships between each node and its k nearest neighbors
237-
238185
Parameters
239186
----------
240187
G : GraphV2
@@ -308,8 +255,6 @@ def write(
308255
"""
309256
Runs the K-Nearest Neighbors algorithm and writes the results back to the database.
310257
311-
The K-Nearest Neighbors algorithm computes a distance value for all node pairs in the graph and creates new relationships between each node and its k nearest neighbors
312-
313258
Parameters
314259
----------
315260
G : GraphV2
@@ -386,8 +331,6 @@ def estimate(
386331
"""
387332
Estimates the memory requirements for running the K-Nearest Neighbors algorithm.
388333
389-
The K-Nearest Neighbors algorithm computes a distance value for all node pairs in the graph and creates new relationships between each node and its k nearest neighbors
390-
391334
Parameters
392335
----------
393336
G : GraphV2
@@ -432,3 +375,44 @@ def estimate(
432375
EstimationResult
433376
Object containing the estimated memory requirements.
434377
"""
378+
379+
380+
class KnnMutateResult(BaseResult):
381+
pre_processing_millis: int
382+
compute_millis: int
383+
mutate_millis: int
384+
post_processing_millis: int
385+
nodes_compared: int
386+
relationships_written: int
387+
similarity_distribution: dict[str, Any]
388+
did_converge: bool
389+
ran_iterations: int
390+
node_pairs_considered: int
391+
configuration: dict[str, Any]
392+
393+
394+
class KnnStatsResult(BaseResult):
395+
pre_processing_millis: int
396+
compute_millis: int
397+
post_processing_millis: int
398+
nodes_compared: int
399+
similarity_pairs: int
400+
similarity_distribution: dict[str, Any]
401+
did_converge: bool
402+
ran_iterations: int
403+
node_pairs_considered: int
404+
configuration: dict[str, Any]
405+
406+
407+
class KnnWriteResult(BaseResult):
408+
pre_processing_millis: int
409+
compute_millis: int
410+
write_millis: int
411+
post_processing_millis: int
412+
nodes_compared: int
413+
relationships_written: int
414+
did_converge: bool
415+
ran_iterations: int
416+
node_pairs_considered: int
417+
similarity_distribution: dict[str, Any]
418+
configuration: dict[str, Any]

0 commit comments

Comments
 (0)