Skip to content

Commit 7aff806

Browse files
committed
Fix typing + signature mismatch
1 parent 96b1ea6 commit 7aff806

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

graphdatascience/procedure_surface/api/centrality/celf_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def write(
191191
random_seed: int | None = None,
192192
relationship_types: list[str] = ALL_TYPES,
193193
node_labels: list[str] = ALL_LABELS,
194-
sudo: bool | None = None,
194+
sudo: bool = False,
195195
log_progress: bool = True,
196196
username: str | None = None,
197197
concurrency: int | None = None,

graphdatascience/tests/integrationV2/procedure_surface/session/test_session_api_spec_coverage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
],
103103
}
104104

105-
ADJUSTED_PARAM_DEFAULT_VALUES = {
105+
ADJUSTED_PARAM_DEFAULT_VALUES: dict[str, dict[str, str | None]] = {
106106
".*": {
107107
"concurrency": None, # default value differs for Aura Graph Analytics compared to plugin (spec is off)
108108
"job_id": None, # default value in spec is `random id`
@@ -248,7 +248,7 @@ def verify_configuration_fields(callable_object: MethodType, endpoint_spec: Endp
248248
)
249249

250250
# validate default values match
251-
default_adjustments: dict[str, str] = {}
251+
default_adjustments: dict[str, str | None] = {}
252252
for endpoint_pattern, adjustments in ADJUSTED_PARAM_DEFAULT_VALUES.items():
253253
if re.match(endpoint_pattern, py_endpoint):
254254
default_adjustments.update(adjustments)

graphdatascience/tests/unit/procedure_surface/cypher/community/test_unit_k1coloring_cypher_endpoints.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
K1ColoringStatsResult,
88
K1ColoringWriteResult,
99
)
10+
from graphdatascience.procedure_surface.api.default_values import ALL_LABELS, ALL_TYPES
1011
from graphdatascience.procedure_surface.cypher.community.k1coloring_cypher_endpoints import K1ColoringCypherEndpoints
1112
from graphdatascience.tests.unit.conftest import DEFAULT_SERVER_VERSION, CollectingQueryRunner
1213
from graphdatascience.tests.unit.procedure_surface.cypher.conftest import estimate_mock_result
@@ -317,7 +318,12 @@ def test_estimate_with_graph_name(graph: GraphV2) -> None:
317318
assert "gds.k1coloring.stats.estimate" in query_runner.queries[0]
318319
params = query_runner.params[0]
319320
assert params["graphNameOrConfiguration"] == "test_graph"
320-
assert params["algoConfig"] == {}
321+
assert params["algoConfig"] == {
322+
"batchSize": 10000,
323+
"maxIterations": 10,
324+
"nodeLabels": ALL_LABELS,
325+
"relationshipTypes": ALL_TYPES,
326+
}
321327

322328

323329
def test_estimate_with_projection_config() -> None:
@@ -331,4 +337,9 @@ def test_estimate_with_projection_config() -> None:
331337
assert "gds.k1coloring.stats.estimate" in query_runner.queries[0]
332338
params = query_runner.params[0]
333339
assert params["graphNameOrConfiguration"] == {"foo": "bar"}
334-
assert params["algoConfig"] == {}
340+
assert params["algoConfig"] == {
341+
"batchSize": 10000,
342+
"maxIterations": 10,
343+
"nodeLabels": ALL_LABELS,
344+
"relationshipTypes": ALL_TYPES,
345+
}

0 commit comments

Comments
 (0)