diff --git a/graphdatascience/procedure_surface/api/catalog/relationships_endpoints.py b/graphdatascience/procedure_surface/api/catalog/relationships_endpoints.py index 08ffc7c84..386b05f9a 100644 --- a/graphdatascience/procedure_surface/api/catalog/relationships_endpoints.py +++ b/graphdatascience/procedure_surface/api/catalog/relationships_endpoints.py @@ -208,6 +208,52 @@ def to_undirected( pass + @abstractmethod + def collapse_path( + self, + G: GraphV2, + path_templates: list[list[str]], + mutate_relationship_type: str, + *, + allow_self_loops: bool = False, + concurrency: int | None = None, + job_id: str | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + ) -> CollapsePathResult: + """ + Collapse each path that matches a path template into a single relationship. + + Parameters + ---------- + + G : GraphV2 + The graph in which the paths are traversed. + path_templates : list[list[str]] + Each path template is an ordered list of relationship types to traverse, in the given order. The same relationship type may appear multiple times, and several path templates can be processed in a single call. + mutate_relationship_type : str + The name of the new relationship type to be created. + allow_self_loops : bool, default=False + Whether the resulting relationships may have the same start and end node. + concurrency : int | None, default=None + Number of concurrent threads to use. + job_id : str | None, default=None + Unique identifier for the job associated with the computation. + sudo : bool, default=False + Override memory estimation limits. + log_progress : bool, default=True + Whether to log progress of the computation. + username : str | None, default=None + Username of the individual requesting the computation. + + Returns + ------- + CollapsePathResult: Metadata about the created relationships. + """ + + pass + class RelationshipsWriteResult(BaseResult): graph_name: str @@ -251,6 +297,14 @@ class RelationshipsToUndirectedResult(RelationshipsInverseIndexResult): relationships_written: int +class CollapsePathResult(BaseResult): + preProcessingMillis: int + computeMillis: int + mutateMillis: int + relationshipsWritten: int + configuration: dict[str, Any] + + class Aggregation(str, Enum): NONE = "NONE" SINGLE = "SINGLE" diff --git a/graphdatascience/procedure_surface/api/config_endpoints.py b/graphdatascience/procedure_surface/api/config_endpoints.py new file mode 100644 index 000000000..58c48b42a --- /dev/null +++ b/graphdatascience/procedure_surface/api/config_endpoints.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class ConfigEndpoints(ABC): + @property + @abstractmethod + def defaults(self) -> DefaultsEndpoints: + pass + + @property + @abstractmethod + def limits(self) -> LimitsEndpoints: + pass + + +class DefaultsEndpoints(ABC): + @abstractmethod + def set( + self, + key: str, + value: Any, + username: str | None = None, + ) -> None: + """ + Configure a new default configuration value. + + Parameters: + key : str + The configuration key for which the default value is being set. + value : Any + The value to set as the default for the given key. + username : str | None, default=None + If set, the default value will be set only for the given user.
+ + Returns: None + """ + pass + + @abstractmethod + def list( + self, + username: str | None = None, + key: str | None = None, + ) -> dict[str, Any]: + """ + List configured default configuration values. + + Parameters: + key : str | None, default=None + List only the default value for the given key. + username : str | None, default=None + List only default values for the given user. + + Returns: dict[str, Any] + A dictionary containing the default configuration values. + """ + pass + + +class LimitsEndpoints(ABC): + @abstractmethod + def set( + self, + key: str, + value: Any, + username: str | None = None, + ) -> None: + """ + Configure a new limit for a configuration value. + + Parameters: + key : str + The configuration key for which the limit is being set. + value : Any + The value to set as the limit for the given key. + username : str | None, default=None + If set, the limit will be set for the given user. + + Returns: None + """ + pass + + @abstractmethod + def list( + self, + username: str | None = None, + key: str | None = None, + ) -> dict[str, Any]: + """ + List configured limits for configuration values. + + Parameters: + key : str | None, default=None + List only the limits for the given key. + username : str | None, default=None + List only limits for the given user. + + Returns: dict[str, Any] + A dictionary containing the configuration limits. + """ + pass diff --git a/graphdatascience/procedure_surface/api/system_endpoints.py b/graphdatascience/procedure_surface/api/system_endpoints.py new file mode 100644 index 000000000..6a3a94be4 --- /dev/null +++ b/graphdatascience/procedure_surface/api/system_endpoints.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from graphdatascience.procedure_surface.api.base_result import BaseResult + + +class SystemEndpoints(ABC): + @abstractmethod + def list_progress( + self, + job_id: str | None = None, + show_completed: bool = False, + ) -> list[ProgressResult]: + pass + + +class ProgressResult(BaseResult): + username: str + job_id: str + task_name: str + progress: str + progress_bar: str + status: str + time_started: str + elapsed_time: str diff --git a/graphdatascience/procedure_surface/arrow/catalog/relationship_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog/relationship_arrow_endpoints.py index 9d7f95364..21533f5f3 100644 --- a/graphdatascience/procedure_surface/arrow/catalog/relationship_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog/relationship_arrow_endpoints.py @@ -7,6 +7,7 @@ from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 from graphdatascience.procedure_surface.api.catalog.relationships_endpoints import ( Aggregation, + CollapsePathResult, RelationshipsDropResult, RelationshipsEndpoints, RelationshipsInverseIndexResult, @@ -197,3 +198,35 @@ def to_undirected( ) result = JobClient.get_summary(self._arrow_client, job_id) return RelationshipsToUndirectedResult(**result) + + def collapse_path( + self, + G: GraphV2, + path_templates: list[list[str]], + mutate_relationship_type: str, + *, + allow_self_loops: bool = False, + concurrency: int | None = None, + job_id: str | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + ) -> CollapsePathResult: + config = ConfigConverter.convert_to_gds_config( + graph_name=G.name(), + path_templates=path_templates, + mutate_relationship_type=mutate_relationship_type, + allow_self_loops=allow_self_loops, + concurrency=concurrency, +
job_id=job_id, + sudo=sudo, + log_progress=log_progress, + username=username, + ) + + show_progress = self._show_progress and log_progress + job_id = JobClient.run_job_and_wait( + self._arrow_client, "v2/graph.relationships.collapsePath", config, show_progress=show_progress + ) + + return CollapsePathResult(**JobClient.get_summary(self._arrow_client, job_id)) diff --git a/graphdatascience/procedure_surface/arrow/config_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/config_arrow_endpoints.py new file mode 100644 index 000000000..9f43dbb93 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/config_arrow_endpoints.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Any + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize +from graphdatascience.procedure_surface.api.config_endpoints import ( + ConfigEndpoints, + DefaultsEndpoints, + LimitsEndpoints, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter + + +class ConfigArrowEndpoints(ConfigEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + + @property + def defaults(self) -> DefaultsEndpoints: + return DefaultsArrowEndpoints(self._arrow_client) + + @property + def limits(self) -> LimitsEndpoints: + return LimitsArrowEndpoints(self._arrow_client) + + +class DefaultsArrowEndpoints(DefaultsEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + + def set( + self, + key: str, + value: Any, + username: str | None = None, + ) -> None: + key = ConfigConverter.convert_to_camel_case(key) + deserialize(self._arrow_client.do_action_with_retry("v2/defaults.set", {key: value})) + + def list( + self, + username: str | None = None, + key: str | None = None, + ) -> dict[str, Any]: + config = ConfigConverter.convert_to_gds_config( + key=key, + ) + + rows = deserialize(self._arrow_client.do_action_with_retry("v2/defaults.list", config)) + result = {} + + for row in rows: + result[row["key"]] = row["value"] + + return result + + +class LimitsArrowEndpoints(LimitsEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + + def set( + self, + key: str, + value: Any, + username: str | None = None, + ) -> None: + key = ConfigConverter.convert_to_camel_case(key) + deserialize(self._arrow_client.do_action_with_retry("v2/limits.set", {key: value})) + + def list( + self, + username: str | None = None, + key: str | None = None, + ) -> dict[str, Any]: + config = ConfigConverter.convert_to_gds_config( + key=key, + ) + + rows = deserialize(self._arrow_client.do_action_with_retry("v2/limits.list", config)) + result = {} + + for row in rows: + result[row["key"]] = row["value"] + + return result diff --git a/graphdatascience/procedure_surface/arrow/system_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/system_arrow_endpoints.py new file mode 100644 index 000000000..3bc750c13 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/system_arrow_endpoints.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize +from graphdatascience.procedure_surface.api.system_endpoints import ProgressResult, SystemEndpoints +from 
graphdatascience.procedure_surface.utils.config_converter import ConfigConverter + + +class SystemArrowEndpoints(SystemEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + + def list_progress( + self, + job_id: str | None = None, + show_completed: bool = False, + ) -> list[ProgressResult]: + config = ConfigConverter.convert_to_gds_config( + job_id=job_id, + show_completed=show_completed, + ) + + rows = deserialize(self._arrow_client.do_action_with_retry("v2/listProgress", config)) + return [ProgressResult(**row) for row in rows] diff --git a/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py index f1e61b574..6c4cabed5 100644 --- a/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py +++ b/graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py @@ -5,6 +5,7 @@ from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 from graphdatascience.procedure_surface.api.catalog.relationships_endpoints import ( Aggregation, + CollapsePathResult, RelationshipsDropResult, RelationshipsEndpoints, RelationshipsInverseIndexResult, @@ -224,3 +225,36 @@ def to_undirected( ).squeeze() return RelationshipsToUndirectedResult(**result.to_dict()) + + def collapse_path( + self, + G: GraphV2, + path_templates: list[list[str]], + mutate_relationship_type: str, + *, + allow_self_loops: bool = False, + concurrency: int | None = None, + job_id: str | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + ) -> CollapsePathResult: + config = ConfigConverter.convert_to_gds_config( + path_templates=path_templates, + mutate_relationship_type=mutate_relationship_type, + allow_self_loops=allow_self_loops, + concurrency=concurrency, + job_id=job_id, + sudo=sudo, + log_progress=log_progress, + username=username, + ) + + params = CallParameters( + graph_name=G.name(), + config=config, + ) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.collapsePath.mutate", params=params).squeeze() + return CollapsePathResult(**result.to_dict()) diff --git a/graphdatascience/procedure_surface/cypher/config_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/config_cypher_endpoints.py new file mode 100644 index 000000000..735ed379c --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/config_cypher_endpoints.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Any + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.procedure_surface.api.config_endpoints import ( + ConfigEndpoints, + DefaultsEndpoints, + LimitsEndpoints, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter +from graphdatascience.query_runner.query_runner import QueryRunner + + +class ConfigCypherEndpoints(ConfigEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + @property + def defaults(self) -> DefaultsEndpoints: + return DefaultsCypherEndpoints(self._query_runner) + + @property + def limits(self) -> LimitsEndpoints: + return LimitsCypherEndpoints(self._query_runner) + + +class DefaultsCypherEndpoints(DefaultsEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def set( + self, + key: str, + value: Any, + username: str | None = None, 
+ ) -> None: + key = ConfigConverter.convert_to_camel_case(key) + + params = { + "key": key, + "value": value, + } + + if username: + params["username"] = username + + params = CallParameters(**params) + + self._query_runner.call_procedure(endpoint="gds.config.defaults.set", params=params) + + def list( + self, + username: str | None = None, + key: str | None = None, + ) -> dict[str, Any]: + config = ConfigConverter.convert_to_gds_config( + key=key, + username=username, + ) + + params = CallParameters( + config=config, + ) + + result = self._query_runner.call_procedure(endpoint="gds.config.defaults.list", params=params) + return {row["key"]: row["value"] for _, row in result.iterrows()} + + +class LimitsCypherEndpoints(LimitsEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def set( + self, + key: str, + value: Any, + username: str | None = None, + ) -> None: + key = ConfigConverter.convert_to_camel_case(key) + + params = { + "key": key, + "value": value, + } + + if username: + params["username"] = username + + params = CallParameters(**params) + + self._query_runner.call_procedure(endpoint="gds.config.limits.set", params=params) + + def list( + self, + username: str | None = None, + key: str | None = None, + ) -> dict[str, Any]: + config = ConfigConverter.convert_to_gds_config( + key=key, + username=username, + ) + + params = CallParameters( + config=config, + ) + + result = self._query_runner.call_procedure(endpoint="gds.config.limits.list", params=params) + return {row["key"]: row["value"] for _, row in result.iterrows()} diff --git a/graphdatascience/procedure_surface/cypher/system_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/system_cypher_endpoints.py new file mode 100644 index 000000000..ddeace20f --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/system_cypher_endpoints.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.procedure_surface.api.system_endpoints import ProgressResult, SystemEndpoints +from graphdatascience.query_runner.query_runner import QueryRunner + + +class SystemCypherEndpoints(SystemEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def list_progress( + self, + job_id: str | None = None, + show_completed: bool = False, + ) -> list[ProgressResult]: + params = CallParameters( + job_id=job_id if job_id else "", + show_completed=show_completed, + ) + + result = self._query_runner.call_procedure(endpoint="gds.listProgress", params=params) + return [ProgressResult(**row.to_dict()) for _, row in result.iterrows()] diff --git a/graphdatascience/procedure_surface/utils/config_converter.py b/graphdatascience/procedure_surface/utils/config_converter.py index c8f122b97..d49804532 100644 --- a/graphdatascience/procedure_surface/utils/config_converter.py +++ b/graphdatascience/procedure_surface/utils/config_converter.py @@ -13,7 +13,7 @@ def convert_to_gds_config(**kwargs: Any | None) -> dict[str, Any]: return config @staticmethod - def _convert_to_camel_case(name: str) -> str: + def convert_to_camel_case(name: str) -> str: """Convert a snake_case string to camelCase.""" parts = name.split("_") @@ -29,7 +29,7 @@ def _process_dict_values(input_dict: dict[str, Any]) -> dict[str, Any]: result = {} for key, value in input_dict.items(): if value is not None: - camel_key = ConfigConverter._convert_to_camel_case(key) + camel_key = ConfigConverter.convert_to_camel_case(key) #
Recursively process nested dictionaries if isinstance(value, dict): result[camel_key] = ConfigConverter._process_dict_values(value) diff --git a/graphdatascience/session/session_v2_endpoints.py b/graphdatascience/session/session_v2_endpoints.py index 70a55b290..b408d28bf 100644 --- a/graphdatascience/session/session_v2_endpoints.py +++ b/graphdatascience/session/session_v2_endpoints.py @@ -86,6 +86,7 @@ TriangleCountArrowEndpoints, ) from graphdatascience.procedure_surface.arrow.community.wcc_arrow_endpoints import WccArrowEndpoints +from graphdatascience.procedure_surface.arrow.config_arrow_endpoints import ConfigArrowEndpoints from graphdatascience.procedure_surface.arrow.node_embedding.fastrp_arrow_endpoints import FastRPArrowEndpoints from graphdatascience.procedure_surface.arrow.node_embedding.graphsage_predict_arrow_endpoints import ( GraphSagePredictArrowEndpoints, @@ -120,6 +121,7 @@ from graphdatascience.procedure_surface.arrow.similarity.node_similarity_arrow_endpoints import ( NodeSimilarityArrowEndpoints, ) +from graphdatascience.procedure_surface.arrow.system_arrow_endpoints import SystemArrowEndpoints from graphdatascience.query_runner.query_runner import QueryRunner @@ -143,6 +145,14 @@ def set_show_progress(self, show_progress: bool) -> None: def graph(self) -> CatalogArrowEndpoints: return CatalogArrowEndpoints(self._arrow_client, self._db_client, show_progress=self._show_progress) + @property + def config(self) -> ConfigArrowEndpoints: + return ConfigArrowEndpoints(self._arrow_client) + + @property + def system(self) -> SystemArrowEndpoints: + return SystemArrowEndpoints(self._arrow_client) + ## Algorithms @property diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/catalog/test_relationship_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/catalog/test_relationship_arrow_endpoints.py index 9f69f817b..cf5e168ed 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/catalog/test_relationship_arrow_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/catalog/test_relationship_arrow_endpoints.py @@ -231,3 +231,16 @@ def test_to_undirected_with_property_aggregation( assert result.mutate_millis >= 0 assert result.input_relationships == 3 assert result.relationships_written == 6 + + +def test_collapse_path(relationship_endpoints: RelationshipArrowEndpoints, sample_graph: GraphV2) -> None: + result = relationship_endpoints.collapse_path( + G=sample_graph, + path_templates=[["REL", "REL"]], + mutate_relationship_type="FoF", + ) + + assert result.relationshipsWritten == 3 + assert result.mutateMillis >= 0 + assert result.preProcessingMillis >= 0 + assert result.computeMillis >= 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_config_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_config_arrow_endpoints.py new file mode 100644 index 000000000..9ac987656 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_config_arrow_endpoints.py @@ -0,0 +1,58 @@ +from typing import Generator + +import pytest + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.config_arrow_endpoints import ConfigArrowEndpoints + + +@pytest.fixture +def config_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[ConfigArrowEndpoints, None, None]: + yield ConfigArrowEndpoints(arrow_client) + + +def 
test_defaults_set_and_list(config_endpoints: ConfigArrowEndpoints) -> None: + config_endpoints.defaults.set("test.key", 6) + + defaults = config_endpoints.defaults.list() + + assert ("test.key", 6) in defaults.items() + + +def test_defaults_list_by_key(config_endpoints: ConfigArrowEndpoints) -> None: + config_endpoints.defaults.set("test.specific.key", "specific_value") + + specific_defaults = config_endpoints.defaults.list(key="test.specific.key") + + assert specific_defaults == {"test.specific.key": "specific_value"} + + +def test_limits_set_and_list(config_endpoints: ConfigArrowEndpoints) -> None: + config_endpoints.limits.set("test.key", 6) + + limits = config_endpoints.limits.list() + + assert ("test.key", 6) in limits.items() + + +def test_limits_list_by_key(config_endpoints: ConfigArrowEndpoints) -> None: + config_endpoints.limits.set("test.specific.key", 42) + + specific_limits = config_endpoints.limits.list(key="test.specific.key") + + assert specific_limits == {"test.specific.key": 42} + + +def test_config_endpoints_properties(config_endpoints: ConfigArrowEndpoints) -> None: + """Test that the config endpoints have the required properties.""" + assert hasattr(config_endpoints, "defaults") + assert hasattr(config_endpoints, "limits") + + # Verify the properties return the correct endpoint types + defaults = config_endpoints.defaults + limits = config_endpoints.limits + + assert hasattr(defaults, "set") + assert hasattr(defaults, "list") + assert hasattr(limits, "set") + assert hasattr(limits, "list") diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_system_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_system_arrow_endpoints.py new file mode 100644 index 000000000..ec3f1860e --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_system_arrow_endpoints.py @@ -0,0 +1,50 @@ +from typing import Generator + +import pytest + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.arrow.community.wcc_arrow_endpoints import WccArrowEndpoints +from graphdatascience.procedure_surface.arrow.system_arrow_endpoints import SystemArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + +graph = """ + CREATE + (a1)-[:T]->(a2) + """ + + +@pytest.fixture(scope="class") +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]: + with create_graph(arrow_client, "g", graph) as G: + yield G + + +@pytest.fixture(scope="class") +def job_id(arrow_client: AuthenticatedArrowClient, sample_graph: GraphV2) -> Generator[str, None, None]: + job_id = "test_job_id" + WccArrowEndpoints(arrow_client).stats(sample_graph, job_id=job_id) + yield job_id + + +@pytest.fixture +def system_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[SystemArrowEndpoints, None, None]: + yield SystemArrowEndpoints(arrow_client) + + +def test_list_progress_job_id(system_endpoints: SystemArrowEndpoints, job_id: str) -> None: + results = system_endpoints.list_progress(job_id=job_id, show_completed=True) + + assert len(results) == 1 + + result = results[0] + assert result.username == "neo4j" + assert result.job_id == job_id + assert "WCC" in result.task_name + assert result.progress == "100%" + assert "#" in result.progress_bar + + +def 
test_list_nothing(system_endpoints: SystemArrowEndpoints) -> None: + results = system_endpoints.list_progress() + assert len(results) == 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_relationship_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_relationship_cypher_endpoints.py index 199936465..956424113 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_relationship_cypher_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_relationship_cypher_endpoints.py @@ -252,3 +252,16 @@ def test_to_undirected_with_property_aggregation( assert result.mutate_millis >= 0 assert result.input_relationships == 3 assert result.relationships_written == 6 + + +def test_collapse_path(relationship_endpoints: RelationshipCypherEndpoints, sample_graph: GraphV2) -> None: + result = relationship_endpoints.collapse_path( + G=sample_graph, + path_templates=[["REL", "REL"]], + mutate_relationship_type="FoF", + ) + + assert result.relationshipsWritten == 3 + assert result.mutateMillis >= 0 + assert result.preProcessingMillis >= 0 + assert result.computeMillis >= 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_config_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_config_cypher_endpoints.py new file mode 100644 index 000000000..8532a470f --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_config_cypher_endpoints.py @@ -0,0 +1,78 @@ +from typing import Generator + +import pytest + +from graphdatascience import QueryRunner +from graphdatascience.procedure_surface.cypher.config_cypher_endpoints import ConfigCypherEndpoints + + +@pytest.fixture +def config_endpoints(query_runner: QueryRunner) -> Generator[ConfigCypherEndpoints, None, None]: + yield ConfigCypherEndpoints(query_runner) + + +def test_defaults_set_and_list(config_endpoints: ConfigCypherEndpoints) -> None: + config_endpoints.defaults.set("test.key", 6) + + defaults = config_endpoints.defaults.list() + + assert ("test.key", 6) in defaults.items() + + +def test_defaults_set_with_username(config_endpoints: ConfigCypherEndpoints) -> None: + config_endpoints.defaults.set("test.user.key", "user_value", username="testuser") + config_endpoints.defaults.set("test.user.key", "user2_value", username="testuser2") + + user_defaults = config_endpoints.defaults.list(username="testuser") + + assert ("test.user.key", "user_value") in user_defaults.items() + assert ("test.user.key", "user2_value") not in user_defaults.items() + + +def test_defaults_list_by_key(config_endpoints: ConfigCypherEndpoints) -> None: + config_endpoints.defaults.set("test.specific.key", "specific_value") + + specific_defaults = config_endpoints.defaults.list(key="test.specific.key") + + assert specific_defaults == {"test.specific.key": "specific_value"} + + +def test_limits_set_and_list(config_endpoints: ConfigCypherEndpoints) -> None: + config_endpoints.limits.set("test.key", 6) + + limits = config_endpoints.limits.list() + + assert ("test.key", 6) in limits.items() + + +def test_limits_set_with_username(config_endpoints: ConfigCypherEndpoints) -> None: + config_endpoints.limits.set("test.user.key", 1, username="testuser") + config_endpoints.limits.set("test.user.key", 2, username="testuser2") + + user_limits = config_endpoints.limits.list(username="testuser") + + assert ("test.user.key", 1) in user_limits.items() + assert 
("test.user.key", 2) not in user_limits.items() + + +def test_limits_list_by_key(config_endpoints: ConfigCypherEndpoints) -> None: + config_endpoints.limits.set("test.specific.key", 42) + + specific_limits = config_endpoints.limits.list(key="test.specific.key") + + assert specific_limits == {"test.specific.key": 42} + + +def test_config_endpoints_properties(config_endpoints: ConfigCypherEndpoints) -> None: + """Test that the config endpoints have the required properties.""" + assert hasattr(config_endpoints, "defaults") + assert hasattr(config_endpoints, "limits") + + # Verify the properties return the correct endpoint types + defaults = config_endpoints.defaults + limits = config_endpoints.limits + + assert hasattr(defaults, "set") + assert hasattr(defaults, "list") + assert hasattr(limits, "set") + assert hasattr(limits, "list") diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_system_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_system_cypher_endpoints.py new file mode 100644 index 000000000..da67a087f --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_system_cypher_endpoints.py @@ -0,0 +1,58 @@ +from typing import Generator + +import pytest + +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.cypher.community.wcc_cypher_endpoints import WccCypherEndpoints +from graphdatascience.procedure_surface.cypher.system_cypher_endpoints import SystemCypherEndpoints +from graphdatascience.query_runner.query_runner import QueryRunner +from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph + + +@pytest.fixture +def system_endpoints(query_runner: QueryRunner) -> Generator[SystemCypherEndpoints, None, None]: + yield SystemCypherEndpoints(query_runner) + + +graph = """ + CREATE + (a1)-[:T]->(a2) + """ + +projection_query = """ + MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {}) AS G + RETURN G +""" + + +@pytest.fixture(scope="class") +def sample_graph(query_runner: QueryRunner) -> Generator[GraphV2, None, None]: + with create_graph(query_runner, "g", graph, projection_query) as G: + yield G + + +@pytest.fixture(scope="class") +def job_id(query_runner: QueryRunner, sample_graph: GraphV2) -> Generator[str, None, None]: + job_id = "test_job_id" + WccCypherEndpoints(query_runner).mutate(sample_graph, job_id=job_id, log_progress=True, mutate_property="wcc") + yield job_id + + +@pytest.mark.skip(reason="Enable when we figure out how to retain jobs") +def test_list_progress_job_id(system_endpoints: SystemCypherEndpoints, job_id: str) -> None: + results = system_endpoints.list_progress(show_completed=True) + + assert len(results) == 1 + + result = results[0] + assert result.username == "neo4j" + assert result.job_id == job_id + assert "WCC" in result.task_name + assert result.progress == "100%" + assert "#" in result.progress_bar + + +def test_list_nothing(system_endpoints: SystemCypherEndpoints) -> None: + results = system_endpoints.list_progress() + assert len(results) == 0