Skip to content

Commit add14ac

Browse files
Merge pull request #476 from linkml/schemaview_detect_cycles_node_list_args
schemaview.py: allow `detect_cycles` to receive a list of nodes as args
2 parents 9e0e62b + 2fe479e commit add14ac

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

linkml_runtime/utils/schemaview.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99
import warnings
1010
from collections import defaultdict, deque
11+
from collections.abc import Iterable
1112
from copy import copy, deepcopy
1213
from dataclasses import dataclass
1314
from enum import Enum
@@ -44,7 +45,7 @@
4445
from linkml_runtime.utils.pattern import PatternResolver
4546

4647
if TYPE_CHECKING:
47-
from collections.abc import Callable, Iterable, Mapping
48+
from collections.abc import Callable, Mapping
4849
from types import NotImplementedType
4950

5051
from linkml_runtime.utils.metamodelcore import URI, URIorCURIE
@@ -96,8 +97,13 @@ class OrderedBy(Enum):
9697
BLACK = 2
9798

9899

99-
def detect_cycles(f: Callable[[Any], Iterable[Any] | None], x: Any) -> None:
100-
"""Detect cycles in a graph, using function `f` to walk the graph, starting at node `x`.
100+
def detect_cycles(
101+
f: Callable[[Any], Iterable[Any] | None],
102+
node_list: Iterable[Any],
103+
) -> None:
104+
"""Detect cycles in a graph, using function `f` to walk the graph.
105+
106+
The starting nodes are supplied as a list (or other non-string iterable) and are used to seed the `todo` stack.
101107
102108
Uses the classic white/grey/black colour coding algorithm to track which nodes have been explored. In this
103109
case, "node" refers to any element in a schema and "neighbours" are elements that can be reached from that
@@ -107,21 +113,28 @@ def detect_cycles(f: Callable[[Any], Iterable[Any] | None], x: Any) -> None:
107113
GREY: node is being processed; processing includes exploring all neighbours reachable via f(node)
108114
BLACK: node and all of its neighbours (and their neighbours, etc.) have been processed
109115
110-
A directed cycle reachable from node `x` raises a ValueError.
116+
A directed cycle reachable from any node in `node_list` raises a ValueError.
111117
112118
:param f: function that returns an iterable of neighbouring nodes (parents or children)
113119
:type f: Callable[[Any], Iterable[Any] | None]
114-
:param x: graph node
115-
:type x: Any
116-
:raises ValueError: if a cycle is discovered through repeated calls to f(x)
120+
:param node_list: list or other iterable of values to process
121+
:type node_list: Iterable[Any]
122+
:raises ValueError: if a cycle is discovered through repeated calls to f(node)
117123
"""
124+
# ensure we have some nodes to start the analysis
125+
if not node_list or not isinstance(node_list, Iterable) or isinstance(node_list, str):
126+
err_msg = "detect_cycles requires a list of values to process"
127+
raise ValueError(err_msg)
128+
118129
# keep track of the processing state of nodes in the graph
119130
processing_state: dict[Any, int] = {}
120131

121132
# Stack entries are (node, processed_flag).
122133
# processed_flag == True means all neighbours (nodes generated by running `f(node)`)
123134
# have been added to the todo stack and the node can be marked BLACK.
124-
todo: list[tuple[Any, bool]] = [(x, False)]
135+
136+
# initialise the todo stack with every starting node's processed_flag set to False
137+
todo: list[tuple[Any, bool]] = [(node, False) for node in node_list]
125138

126139
while todo:
127140
node, processed_flag = todo.pop()
@@ -173,7 +186,7 @@ def _closure(
173186
:rtype: list[str | ElementName | ClassDefinitionName | EnumDefinitionName | SlotDefinitionName | TypeDefinitionName]
174187
"""
175188
if kwargs and kwargs.get("detect_cycles"):
176-
detect_cycles(f, x)
189+
detect_cycles(f, [x])
177190

178191
rv = [x] if reflexive else []
179192
visited = []

tests/test_utils/test_schemaview.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2981,12 +2981,25 @@ def test_class_name_mappings() -> None:
29812981
assert {snm_def.name: snm for snm, snm_def in view.slot_name_mappings().items()} == slot_names
29822982

29832983

2984+
"""
2985+
Tests of the detect_cycles function, which can identify cyclic relationships between classes, types, and other schema elements.
2986+
"""
2987+
2988+
2989+
@pytest.mark.parametrize("dodgy_input", [None, [], set(), {}, 12345, 123.45, "some string", ()])
2990+
def test_detect_cycles_input_error(dodgy_input: Any) -> None:
2991+
"""Ensure that `detect_cycles` raises a ValueError if the input is not a non-empty, non-string iterable."""
2992+
with pytest.raises(ValueError, match="detect_cycles requires a list of values to process"):
2993+
detect_cycles(lambda x: x, dodgy_input)
2994+
2995+
29842996
@pytest.fixture(scope="module")
29852997
def sv_cycles_schema() -> SchemaView:
29862998
"""A schema containing cycles!"""
29872999
return SchemaView(INPUT_DIR_PATH / "cycles.yaml")
29883000

29893001

3002+
# metadata for elements in the `sv_cycles_schema`
29903003
CYCLES = {
29913004
TYPES: {
29923005
# types in cycles, either directly or via ancestors
@@ -3028,8 +3041,8 @@ def sv_cycles_schema() -> SchemaView:
30283041
# key: class name, value: class ancestors
30293042
1: {
30303043
"BaseClass": {"BaseClass"},
3031-
"MixinA": {"MixinA"},
3032-
"MixinB": {"MixinB"},
3044+
"MixinA": {"MixinA"}, # no ID slot
3045+
"MixinB": {"MixinB"}, # no ID slot
30333046
"NonCycleClassA": {"NonCycleClassA", "BaseClass"},
30343047
"NonCycleClassB": {"MixinA", "NonCycleClassB", "NonCycleClassA", "BaseClass"},
30353048
"NonCycleClassC": {"MixinB", "NonCycleClassC", "NonCycleClassA", "BaseClass"},
@@ -3048,10 +3061,10 @@ def test_detect_type_cycles_error(sv_cycles_schema: SchemaView, target: str, cyc
30483061
"""Test detection of cycles in the types segment of the cycles schema."""
30493062
if fn == "detect_cycles":
30503063
with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"):
3051-
detect_cycles(lambda x: sv_cycles_schema.type_parents(x), target)
3064+
detect_cycles(sv_cycles_schema.type_parents, [target])
30523065
elif fn == "graph_closure":
30533066
with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"):
3054-
graph_closure(lambda x: sv_cycles_schema.type_parents(x), target, detect_cycles=True)
3067+
graph_closure(sv_cycles_schema.type_parents, target, detect_cycles=True)
30553068
else:
30563069
with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"):
30573070
sv_cycles_schema.type_ancestors(type_name=target, detect_cycles=True)
@@ -3062,9 +3075,9 @@ def test_detect_type_cycles_error(sv_cycles_schema: SchemaView, target: str, cyc
30623075
def test_detect_type_cycles_no_cycles(sv_cycles_schema: SchemaView, target: str, expected: set[str], fn: str) -> None:
30633076
"""Ensure that types without cycles in their ancestry do not throw an error."""
30643077
if fn == "detect_cycles":
3065-
detect_cycles(lambda x: sv_cycles_schema.type_parents(x), target)
3078+
detect_cycles(sv_cycles_schema.type_parents, [target])
30663079
elif fn == "graph_closure":
3067-
got = graph_closure(lambda x: sv_cycles_schema.type_parents(x), target, detect_cycles=True)
3080+
got = graph_closure(sv_cycles_schema.type_parents, target, detect_cycles=True)
30683081
assert set(got) == expected
30693082
else:
30703083
got = sv_cycles_schema.type_ancestors(target, detect_cycles=True)
@@ -3077,11 +3090,11 @@ def test_detect_class_cycles_error(sv_cycles_schema: SchemaView, target: str, cy
30773090
"""Test detection of class cycles in the cycles schema."""
30783091
if fn == "detect_cycles":
30793092
with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"):
3080-
detect_cycles(lambda x: sv_cycles_schema.class_parents(x), target)
3093+
detect_cycles(sv_cycles_schema.class_parents, [target])
30813094

30823095
elif fn == "graph_closure":
30833096
with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"):
3084-
graph_closure(lambda x: sv_cycles_schema.class_parents(x), target, detect_cycles=True)
3097+
graph_closure(sv_cycles_schema.class_parents, target, detect_cycles=True)
30853098
else:
30863099
with pytest.raises(ValueError, match=f"Cycle detected at node '{cycle_start_node}'"):
30873100
sv_cycles_schema.class_ancestors(target, detect_cycles=True)
@@ -3092,9 +3105,9 @@ def test_detect_class_cycles_error(sv_cycles_schema: SchemaView, target: str, cy
30923105
def test_detect_class_cycles_no_cycles(sv_cycles_schema: SchemaView, target: str, expected: set[str], fn: str) -> None:
30933106
"""Ensure that classes without cycles in their ancestry do not throw an error."""
30943107
if fn == "detect_cycles":
3095-
detect_cycles(lambda x: sv_cycles_schema.class_parents(x), target)
3108+
detect_cycles(sv_cycles_schema.class_parents, [target])
30963109
elif fn == "graph_closure":
3097-
got = graph_closure(lambda x: sv_cycles_schema.class_parents(x), target, detect_cycles=True)
3110+
got = graph_closure(sv_cycles_schema.class_parents, target, detect_cycles=True)
30983111
assert set(got) == expected
30993112
else:
31003113
got = sv_cycles_schema.class_ancestors(target, detect_cycles=True)
@@ -3116,10 +3129,10 @@ def check_recursive_id_slots(class_name: str) -> list[str]:
31163129
# classes with a cycle in the class identifier slot range are cunningly named
31173130
if "IdentifierCycle" in target:
31183131
with pytest.raises(ValueError, match="Cycle detected at node "):
3119-
detect_cycles(lambda x: check_recursive_id_slots(x), target)
3132+
detect_cycles(check_recursive_id_slots, [target])
31203133

31213134
else:
3122-
detect_cycles(lambda x: check_recursive_id_slots(x), target)
3135+
detect_cycles(check_recursive_id_slots, [target])
31233136

31243137

31253138
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)