Skip to content

Commit 8e42790

Browse files
committed
Improve knn tests
1 parent 1839c7f commit 8e42790

File tree

4 files changed

+49
-64
lines changed

4 files changed

+49
-64
lines changed

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def test_knn_stats(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> N
5353
"""Test KNN stats operation."""
5454
result = knn_endpoints.stats(G=sample_graph, node_properties=["prop"], top_k=2)
5555

56-
assert result.ran_iterations >= 0
57-
assert result.did_converge in [True, False]
56+
assert result.ran_iterations > 0
57+
assert result.did_converge
5858
assert result.compute_millis >= 0
5959
assert result.pre_processing_millis >= 0
6060
assert result.post_processing_millis >= 0
6161
assert result.nodes_compared > 0
62-
assert result.similarity_pairs >= 0
62+
assert result.similarity_pairs > 0
6363
assert result.node_pairs_considered > 0
6464
assert "p50" in result.similarity_distribution
6565

@@ -86,14 +86,14 @@ def test_knn_mutate(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) ->
8686
top_k=2,
8787
)
8888

89-
assert result.ran_iterations >= 0
90-
assert result.did_converge in [True, False]
89+
assert result.ran_iterations > 0
90+
assert result.did_converge
9191
assert result.pre_processing_millis >= 0
9292
assert result.compute_millis >= 0
9393
assert result.post_processing_millis >= 0
9494
assert result.mutate_millis >= 0
9595
assert result.relationships_written == sample_graph.node_count() * 2
96-
assert result.node_pairs_considered >= 0
96+
assert result.node_pairs_considered > 0
9797

9898

9999
@pytest.mark.db_integration
@@ -108,14 +108,14 @@ def test_knn_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRu
108108
)
109109

110110
assert isinstance(result, KnnWriteResult)
111-
assert result.ran_iterations >= 0
112-
assert result.did_converge in [True, False]
111+
assert result.ran_iterations > 0
112+
assert result.did_converge
113113
assert result.pre_processing_millis >= 0
114114
assert result.compute_millis >= 0
115115
assert result.post_processing_millis >= 0
116116
assert result.write_millis >= 0
117117
assert result.relationships_written == db_graph.node_count() * 2
118-
assert result.node_pairs_considered >= 0
118+
assert result.node_pairs_considered > 0
119119

120120
# Check that relationships were written to the database
121121
count_result = query_runner.run_cypher("MATCH ()-[r:SIMILAR]->() RETURN COUNT(r) AS count")

graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_filtered_arrow_endpoints.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredArrowEndpoints, s
6363
assert result.post_processing_millis >= 0
6464
assert result.nodes_compared >= 0
6565
assert result.similarity_pairs >= 0
66-
assert result.similarity_distribution is not None
67-
assert isinstance(result.did_converge, bool)
66+
assert "p50" in result.similarity_distribution
67+
assert result.did_converge
6868
assert result.ran_iterations >= 0
6969
assert result.node_pairs_considered >= 0
70-
assert result.configuration is not None
70+
assert "concurrency" in result.configuration
7171

7272

7373
def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:
@@ -99,12 +99,12 @@ def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredArrowEndpoints,
9999
assert result.mutate_millis >= 0
100100
assert result.post_processing_millis >= 0
101101
assert result.nodes_compared >= 0
102-
assert result.relationships_written >= 0
103-
assert result.similarity_distribution is not None
104-
assert isinstance(result.did_converge, bool)
105-
assert result.ran_iterations >= 0
102+
assert result.relationships_written > 0
103+
assert "p50" in result.similarity_distribution
104+
assert result.did_converge
105+
assert result.ran_iterations > 0
106106
assert result.node_pairs_considered >= 0
107-
assert result.configuration is not None
107+
assert "concurrency" in result.configuration
108108

109109

110110
def test_knn_filtered_write(
@@ -124,18 +124,17 @@ def test_knn_filtered_write(
124124
target_node_filter="TargetNode",
125125
)
126126

127-
assert isinstance(result, KnnWriteResult)
128127
assert result.pre_processing_millis >= 0
129128
assert result.compute_millis >= 0
130129
assert result.write_millis >= 0
131130
assert result.post_processing_millis >= 0
132131
assert result.nodes_compared >= 0
133-
assert result.relationships_written >= 0
134-
assert isinstance(result.did_converge, bool)
135-
assert result.ran_iterations >= 0
132+
assert result.relationships_written > 0
133+
assert result.did_converge
134+
assert result.ran_iterations > 0
136135
assert result.node_pairs_considered >= 0
137-
assert result.similarity_distribution is not None
138-
assert result.configuration is not None
136+
assert "p50" in result.similarity_distribution
137+
assert "concurrency" in result.configuration
139138

140139

141140
def test_knn_filtered_estimate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None:

graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_cypher_endpoints.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,37 +39,31 @@ def knn_endpoints(query_runner: QueryRunner) -> Generator[KnnCypherEndpoints, No
3939

4040

4141
def test_knn_stats(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None:
42-
"""Test KNN stats operation."""
4342
result = knn_endpoints.stats(G=sample_graph, node_properties=["prop"], top_k=2)
4443

45-
assert result.ran_iterations >= 0
46-
assert result.did_converge in [True, False]
44+
assert result.ran_iterations > 0
45+
assert result.did_converge
4746
assert result.compute_millis > 0
4847
assert result.pre_processing_millis >= 0
4948
assert result.post_processing_millis >= 0
5049
assert result.nodes_compared > 0
5150
assert result.similarity_pairs == 8
5251
assert result.node_pairs_considered > 0
53-
assert "p50" in result.similarity_distribution or "p10" in result.similarity_distribution
52+
assert "p50" in result.similarity_distribution
5453

5554

5655
def test_knn_stream(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None:
57-
"""Test KNN stream operation."""
5856
result_df = knn_endpoints.stream(
5957
G=sample_graph,
6058
node_properties=["prop"],
6159
top_k=2,
6260
)
6361

64-
assert "node1" in result_df.columns
65-
assert "node2" in result_df.columns
66-
assert "similarity" in result_df.columns
67-
assert len(result_df.columns) == 3
62+
assert set(result_df.columns) == {"node1", "node2", "similarity"}
6863
assert len(result_df) == 8
6964

7065

7166
def test_knn_mutate(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None:
72-
"""Test KNN mutate operation."""
7367
result = knn_endpoints.mutate(
7468
G=sample_graph,
7569
mutate_relationship_type="SIMILAR",
@@ -78,14 +72,14 @@ def test_knn_mutate(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) ->
7872
top_k=2,
7973
)
8074

81-
assert result.ran_iterations >= 0
82-
assert result.did_converge in [True, False]
75+
assert result.ran_iterations > 0
76+
assert result.did_converge
8377
assert result.pre_processing_millis >= 0
8478
assert result.compute_millis >= 0
8579
assert result.post_processing_millis >= 0
8680
assert result.mutate_millis >= 0
8781
assert result.relationships_written == 8
88-
assert result.node_pairs_considered >= 0
82+
assert result.node_pairs_considered > 0
8983

9084

9185
def test_knn_estimate(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None:

graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_filtered_cypher_endpoints.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def knn_filtered_endpoints(query_runner: QueryRunner) -> Generator[KnnFilteredCy
4141

4242

4343
def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
44-
"""Test KNN filtered stats operation."""
4544
result = knn_filtered_endpoints.stats(
4645
G=sample_graph,
4746
node_properties=["prop"],
@@ -50,20 +49,19 @@ def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredCypherEndpoints,
5049
target_node_filter="TargetNode",
5150
)
5251

53-
assert result.ran_iterations >= 0
54-
assert result.did_converge in [True, False]
52+
assert result.ran_iterations > 0
53+
assert result.did_converge
5554
assert result.compute_millis > 0
5655
assert result.pre_processing_millis >= 0
5756
assert result.post_processing_millis >= 0
5857
assert result.nodes_compared > 0
59-
assert result.similarity_pairs >= 0
60-
assert result.similarity_distribution is not None
61-
assert result.node_pairs_considered >= 0
62-
assert result.configuration is not None
58+
assert result.similarity_pairs > 0
59+
assert "p50" in result.similarity_distribution
60+
assert result.node_pairs_considered > 0
61+
assert "concurrency" in result.configuration
6362

6463

6564
def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
66-
"""Test KNN filtered stream operation."""
6765
result = knn_filtered_endpoints.stream(
6866
G=sample_graph,
6967
node_properties=["prop"],
@@ -72,15 +70,11 @@ def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredCypherEndpoints,
7270
target_node_filter="TargetNode",
7371
)
7472

75-
assert len(result) >= 0
76-
if len(result) > 0:
77-
assert "node1" in result.columns
78-
assert "node2" in result.columns
79-
assert "similarity" in result.columns
73+
assert set(result.columns) == {"node1", "node2", "similarity"}
74+
assert len(result) >= 4
8075

8176

8277
def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
83-
"""Test KNN filtered mutate operation."""
8478
result = knn_filtered_endpoints.mutate(
8579
G=sample_graph,
8680
node_properties=["prop"],
@@ -91,21 +85,20 @@ def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredCypherEndpoints,
9185
target_node_filter="TargetNode",
9286
)
9387

94-
assert result.ran_iterations >= 0
95-
assert result.did_converge in [True, False]
88+
assert result.ran_iterations > 0
89+
assert result.did_converge
9690
assert result.compute_millis > 0
9791
assert result.mutate_millis >= 0
9892
assert result.pre_processing_millis >= 0
9993
assert result.post_processing_millis >= 0
10094
assert result.nodes_compared > 0
101-
assert result.relationships_written >= 0
102-
assert result.similarity_distribution is not None
103-
assert result.node_pairs_considered >= 0
104-
assert result.configuration is not None
95+
assert result.relationships_written > 0
96+
assert "p50" in result.similarity_distribution
97+
assert result.node_pairs_considered > 0
98+
assert "concurrency" in result.configuration
10599

106100

107101
def test_knn_filtered_write(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
108-
"""Test KNN filtered write operation."""
109102
result = knn_filtered_endpoints.write(
110103
G=sample_graph,
111104
node_properties=["prop"],
@@ -116,21 +109,20 @@ def test_knn_filtered_write(knn_filtered_endpoints: KnnFilteredCypherEndpoints,
116109
target_node_filter="TargetNode",
117110
)
118111

119-
assert result.ran_iterations >= 0
120-
assert result.did_converge in [True, False]
112+
assert result.ran_iterations > 0
113+
assert result.did_converge
121114
assert result.compute_millis > 0
122115
assert result.write_millis >= 0
123116
assert result.pre_processing_millis >= 0
124117
assert result.post_processing_millis >= 0
125118
assert result.nodes_compared > 0
126-
assert result.relationships_written >= 0
127-
assert result.similarity_distribution is not None
128-
assert result.node_pairs_considered >= 0
129-
assert result.configuration is not None
119+
assert result.relationships_written > 0
120+
assert "p50" in result.similarity_distribution
121+
assert result.node_pairs_considered > 0
122+
assert "concurrency" in result.configuration
130123

131124

132125
def test_knn_filtered_estimate(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
133-
"""Test KNN filtered estimation operation."""
134126
result = knn_filtered_endpoints.estimate(
135127
G=sample_graph,
136128
node_properties=["prop"],

0 commit comments

Comments
 (0)