Skip to content

Commit 2d1dd01

Browse files
committed
test: Nested states with general syntax
1 parent 932cdfc commit 2d1dd01

File tree

5 files changed

+40
-36
lines changed

5 files changed

+40
-36
lines changed

statemachine/factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ def _check(cls):
6868
if not cls.states:
6969
raise InvalidDefinition(_("There are no states."))
7070

71-
# TODO: Validate no events if has nested states
72-
# if not cls._events:
73-
# raise InvalidDefinition(_("There are no events."))
71+
if not cls._events:
72+
raise InvalidDefinition(_("There are no events."))
7473

7574
cls._check_disconnected_state()
7675

@@ -117,8 +116,9 @@ def _add_unbounded_callback(cls, attr_name, func):
117116

118117
def add_state(cls, id, state):
119118
state._set_id(id)
120-
cls.states.append(state)
121-
cls.states_map[state.value] = state
119+
if not state.parent:
120+
cls.states.append(state)
121+
cls.states_map[state.value] = state
122122

123123
# also register all events associated directly with transitions
124124
for event in state.transitions.unique_events:

statemachine/state.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Optional # noqa: F401, I001
1+
from typing import Any
2+
from typing import TypeAlias
23
from copy import deepcopy
34

45
from .callbacks import Callbacks
@@ -9,19 +10,22 @@
910

1011

1112
class NestedStateFactory(type):
12-
def __new__(cls, classname, bases, attrs, name=None, initial=False, parallel=False):
13+
def __new__( # type: ignore [misc]
14+
cls, classname, bases, attrs, name=None, initial=False, parallel=False
15+
) -> "State":
1316

1417
if not bases:
15-
return super().__new__(cls, classname, bases, attrs)
18+
return super().__new__(cls, classname, bases, attrs) # type: ignore [return-value]
1619

1720
substates = []
1821
for key, value in attrs.items():
19-
if not isinstance(value, State):
20-
continue
21-
value._set_id(key)
22-
substates.append(value)
22+
if isinstance(value, State):
23+
value._set_id(key)
24+
substates.append(value)
25+
if isinstance(value, TransitionList):
26+
value.add_event(key)
2327

24-
return State(name, initial=initial, parallel=parallel, substates=substates)
28+
return State(name=name, initial=initial, parallel=parallel, substates=substates)
2529

2630

2731
class NestedStateBuilder(metaclass=NestedStateFactory):
@@ -94,37 +98,37 @@ class State:
9498
9599
"""
96100

97-
Builder = NestedStateBuilder
101+
Builder: TypeAlias = NestedStateBuilder
98102

99103
def __init__(
100104
self,
101-
name,
102-
value=None,
103-
initial=False,
104-
final=False,
105-
parallel=False,
106-
substates=None,
107-
enter=None,
108-
exit=None,
105+
name: str = "",
106+
value: Any = None,
107+
initial: bool = False,
108+
final: bool = False,
109+
parallel: bool = False,
110+
substates: Any = None,
111+
enter: Any = None,
112+
exit: Any = None,
109113
):
110-
# type: (str, Optional[Any], bool, bool, bool, Optional[Any], Optional[Any], Optional[Any]) -> None # noqa
111114
self.name = name
112115
self.value = value
113116
self.parallel = parallel
114-
self.parent: "State" = None
115117
self.substates = substates or []
116-
self._id = None # type: Optional[str]
117-
self._storage = ""
118118
self._initial = initial
119-
self.transitions = TransitionList()
120119
self._final = final
120+
self._id: str = ""
121+
self._storage: str = ""
122+
self.parent: "State" = None
123+
self.transitions = TransitionList()
121124
self.enter = Callbacks().add(enter)
122125
self.exit = Callbacks().add(exit)
123126
self._init_substates()
124127

125128
def _init_substates(self):
126129
for substate in self.substates:
127130
substate.parent = self
131+
setattr(self, substate.id, substate)
128132

129133
def __eq__(self, other):
130134
return (
@@ -182,6 +186,8 @@ def _set_id(self, id):
182186
self._storage = f"_{id}"
183187
if self.value is None:
184188
self.value = id
189+
if not self.name:
190+
self.name = self._id.replace("_", " ").capitalize()
185191

186192
def _to_(self, *states, **kwargs):
187193
transitions = TransitionList(

statemachine/transition_list.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import OrderedDict
2-
31
from .utils import ensure_iterable
42

53

@@ -57,10 +55,9 @@ def add_event(self, event):
5755

5856
@property
5957
def unique_events(self):
60-
# Compat Python2.7: Using OrderedDict to get a unique ordered list
61-
tmp_list = OrderedDict()
58+
tmp_ordered_unique_events_as_keys_on_dict = {}
6259
for transition in self.transitions:
6360
for event in transition.events:
64-
tmp_list[event] = True
61+
tmp_ordered_unique_events_as_keys_on_dict[event] = True
6562

66-
return list(tmp_list.keys())
63+
return list(tmp_ordered_unique_events_as_keys_on_dict.keys())

tests/examples/microwave_inheritance_machine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class on(State.Builder, name="On"):
5050
cooking.to(idle, cond="open.is_active")
5151
cooking.to.itself(internal=True, on="increment_timer")
5252

53+
assert isinstance(on, State) # so mypy stop complaining
5354
turn_off = on.to(off)
5455
turn_on = off.to(on)
5556
on.to(off, cond="cook_time_is_over") # eventless transition

tests/test_compound.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ class engine(State.Builder, name="Engine", initial=True):
1111
off = State("Off", initial=True)
1212
on = State("On")
1313

14-
turn_off = on.to(off)
1514
turn_on = off.to(on)
15+
turn_off = on.to(off)
1616

1717
return TestMachine
1818

@@ -26,8 +26,8 @@ def test_capture_constructor_arguments(self, compound_engine_cls):
2626

2727
def test_list_children_states(self, compound_engine_cls):
2828
sm = compound_engine_cls()
29-
assert [s.id for s in sm.engine.children] == ["off", "on"]
29+
assert [s.id for s in sm.engine.substates] == ["off", "on"]
3030

3131
def test_list_events(self, compound_engine_cls):
3232
sm = compound_engine_cls()
33-
assert [e.name for e in sm.events] == ["turn_off", "turn_on"]
33+
assert [e.name for e in sm.events] == ["turn_on", "turn_off"]

0 commit comments

Comments
 (0)