Skip to content

Commit 8210da3

Browse files
committed
feat: New callback
1 parent 35f4bed commit 8210da3

37 files changed

+308
-527
lines changed

docs/actions.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ StateMachine in execution.
1616
There are callbacks that you can specify that are generic and will be called
1717
when something changes and are not bounded to a specific state or event:
1818

19+
- `prepare_event()`
20+
1921
- `before_transition()`
2022

2123
- `on_exit_state()`
@@ -297,6 +299,32 @@ In addition to {ref}`actions`, you can specify {ref}`validators and guards` that
297299
See {ref}`conditions` and {ref}`validators`.
298300
```
299301

302+
### Preparing events
303+
304+
You can use the `prepare_event` method to add custom information
305+
that will be included in `**kwargs` to all other callbacks.
306+
307+
A not so usefull example:
308+
309+
```py
310+
>>> class ExampleStateMachine(StateMachine):
311+
... initial = State(initial=True)
312+
...
313+
... loop = initial.to.itself()
314+
...
315+
... def prepare_event(self):
316+
... return {"foo": "bar"}
317+
...
318+
... def on_loop(self, foo):
319+
... return f"On loop: {foo}"
320+
...
321+
322+
>>> sm = ExampleStateMachine()
323+
324+
>>> sm.loop()
325+
'On loop: bar'
326+
327+
```
300328

301329
## Ordering
302330

@@ -314,6 +342,10 @@ Actions registered on the same group don't have order guaranties and are execute
314342
- Action
315343
- Current state
316344
- Description
345+
* - Preparation
346+
- `prepare_event()`
347+
- `source`
348+
- Add custom event metadata.
317349
* - Validators
318350
- `validators()`
319351
- `source`

docs/releases/3.0.0.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,42 @@ StateMachine 3.0.0 supports Python 3.9, 3.10, 3.11, 3.12, and 3.13.
1414

1515
Now a condition can check if the state machine current set of active states (a.k.a `configuration`) contains a state using the syntax `cond="In('<state-id>')"`.
1616

17+
### Preparing events
18+
19+
You can use the `prepare_event` method to add custom information
20+
that will be included in `**kwargs` to all other callbacks.
21+
22+
A not so usefull example:
23+
24+
```py
25+
>>> class ExampleStateMachine(StateMachine):
26+
... initial = State(initial=True)
27+
...
28+
... loop = initial.to.itself()
29+
...
30+
... def prepare_event(self):
31+
... return {"foo": "bar"}
32+
...
33+
... def on_loop(self, foo):
34+
... return f"On loop: {foo}"
35+
...
36+
37+
>>> sm = ExampleStateMachine()
38+
39+
>>> sm.loop()
40+
'On loop: bar'
41+
42+
```
43+
44+
### Event matching following SCXML spec
45+
46+
Now events matching follows the SCXML spec.
47+
48+
For example, a transition with an `event` attribute of `"error foo"` will match event names `error`, `error.send`, `error.send.failed`, etc. (or `foo`, `foo.bar` etc.)
49+
but would not match events named `errors.my.custom`, `errorhandler.mistake`, `error.send` or `foobar`.
50+
51+
An event designator consisting solely of `*` can be used as a wildcard matching any sequence of tokens, and thus any event.
52+
1753

1854
## Bugfixes in 3.0.0
1955

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ dev = [
5353
"sphinx-copybutton >=0.5.2",
5454
"pdbr>=0.8.9",
5555
"pytest-xdist>=3.6.1",
56+
"pytest-timeout>=2.3.1",
5657
]
5758

5859
[build-system]

statemachine/callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class SpecReference(IntFlag):
4343

4444

4545
class CallbackGroup(IntEnum):
46+
PREPARE = auto()
4647
ENTER = auto()
4748
EXIT = auto()
4849
VALIDATOR = auto()

statemachine/engines/base.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -103,26 +103,13 @@ def start(self):
103103
if self.sm.current_state_value is not None:
104104
return
105105

106-
BoundEvent("__initial__", _sm=self.sm).put(machine=self.sm)
106+
BoundEvent("__initial__", _sm=self.sm).put()
107107

108108
def _initial_transition(self, trigger_data):
109109
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
110110
transition._specs.clear()
111111
return transition
112112

113-
def select_eventless_transitions(self, trigger_data: TriggerData):
114-
"""
115-
Select the eventless transitions that match the trigger data.
116-
"""
117-
return self._select_transitions(trigger_data, lambda t, _e: t.is_eventless)
118-
119-
def _conditions_match(self, transition: Transition, trigger_data: TriggerData):
120-
event_data = EventData(trigger_data=trigger_data, transition=transition)
121-
args, kwargs = event_data.args, event_data.extended_kwargs
122-
123-
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
124-
return self.sm._callbacks.all(transition.cond.key, *args, **kwargs)
125-
126113
def _filter_conflicting_transitions(
127114
self, transitions: OrderedSet[Transition]
128115
) -> OrderedSet[Transition]:
@@ -234,6 +221,12 @@ def get_effective_target_states(self, transition: Transition) -> OrderedSet[Stat
234221
# TODO: Handle history states
235222
return OrderedSet([transition.target])
236223

224+
def select_eventless_transitions(self, trigger_data: TriggerData):
225+
"""
226+
Select the eventless transitions that match the trigger data.
227+
"""
228+
return self._select_transitions(trigger_data, lambda t, _e: t.is_eventless)
229+
237230
def select_transitions(self, trigger_data: TriggerData) -> OrderedSet[Transition]:
238231
"""
239232
Select the transitions that match the trigger data.
@@ -297,6 +290,27 @@ def microstep(self, transitions: List[Transition], trigger_data: TriggerData):
297290

298291
return result
299292

293+
def _get_args_kwargs(
294+
self, transition: Transition, trigger_data: TriggerData, set_target_as_state: bool = False
295+
):
296+
# TODO: Ideally this method should be called only once per microstep/transition
297+
event_data = EventData(trigger_data=trigger_data, transition=transition)
298+
if set_target_as_state:
299+
event_data.state = transition.target
300+
301+
args, kwargs = event_data.args, event_data.extended_kwargs
302+
303+
result = self.sm._callbacks.call(self.sm.prepare.key, *args, **kwargs)
304+
for new_kwargs in result:
305+
kwargs.update(new_kwargs)
306+
return args, kwargs
307+
308+
def _conditions_match(self, transition: Transition, trigger_data: TriggerData):
309+
args, kwargs = self._get_args_kwargs(transition, trigger_data)
310+
311+
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
312+
return self.sm._callbacks.all(transition.cond.key, *args, **kwargs)
313+
300314
def _exit_states(self, enabled_transitions: List[Transition], trigger_data: TriggerData):
301315
"""Compute and process the states to exit for the given transitions."""
302316
states_to_exit = self._compute_exit_set(enabled_transitions)
@@ -309,8 +323,7 @@ def _exit_states(self, enabled_transitions: List[Transition], trigger_data: Trig
309323
# states_to_exit = sorted(states_to_exit, key=self.exit_order)
310324

311325
for info in states_to_exit:
312-
event_data = EventData(trigger_data=trigger_data, transition=info.transition)
313-
args, kwargs = event_data.args, event_data.extended_kwargs
326+
args, kwargs = self._get_args_kwargs(info.transition, trigger_data)
314327

315328
# # TODO: Update history
316329
# for history in state.history:
@@ -342,10 +355,9 @@ def _execute_transition_content(
342355
):
343356
result = []
344357
for transition in enabled_transitions:
345-
event_data = EventData(trigger_data=trigger_data, transition=transition)
346-
if set_target_as_state:
347-
event_data.state = transition.target
348-
args, kwargs = event_data.args, event_data.extended_kwargs
358+
args, kwargs = self._get_args_kwargs(
359+
transition, trigger_data, set_target_as_state=set_target_as_state
360+
)
349361

350362
result += self.sm._callbacks.call(get_key(transition), *args, **kwargs)
351363

@@ -379,14 +391,14 @@ def _enter_states(
379391
)
380392

381393
# Sort states to enter in entry order
382-
# for state in sorted(states_to_enter, key=self.entry_order): # TODO: ordegin of states_to_enter # noqa: E501
394+
# for state in sorted(states_to_enter, key=self.entry_order): # TODO: order of states_to_enter # noqa: E501
383395
for info in states_to_enter:
384396
target = info.target
385397
assert target
386398
transition = info.transition
387-
event_data = EventData(trigger_data=trigger_data, transition=transition)
388-
event_data.state = target
389-
args, kwargs = event_data.args, event_data.extended_kwargs
399+
args, kwargs = self._get_args_kwargs(
400+
transition, trigger_data, set_target_as_state=True
401+
)
390402

391403
# Add state to the configuration
392404
# self.sm.configuration |= {target}
@@ -400,7 +412,6 @@ def _enter_states(
400412
# state.is_first_entry = False
401413

402414
# Execute `onentry` handlers
403-
# TODO: if not transition.internal:
404415
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
405416

406417
# Handle default initial states
@@ -420,18 +431,15 @@ def _enter_states(
420431
parent = target.parent
421432
grandparent = parent.parent
422433

423-
self.internal_queue.put(
424-
BoundEvent(f"done.state.{parent.id}", _sm=self.sm).build_trigger(
425-
machine=self.sm
426-
)
427-
)
434+
BoundEvent(
435+
f"done.state.{parent.id}",
436+
_sm=self.sm,
437+
internal=True,
438+
).put()
439+
428440
if grandparent.parallel:
429441
if all(child.final for child in grandparent.states):
430-
self.internal_queue.put(
431-
BoundEvent(f"done.state.{parent.id}", _sm=self.sm).build_trigger(
432-
machine=self.sm
433-
)
434-
)
442+
BoundEvent(f"done.state.{parent.id}", _sm=self.sm, internal=True).put()
435443

436444
def compute_entry_set(
437445
self, transitions, states_to_enter, states_for_default_entry, default_history_content
@@ -521,7 +529,6 @@ def add_descendant_states_to_enter(
521529
assert state
522530

523531
if state.parallel:
524-
# Handle parallel states
525532
for child_state in state.states:
526533
if not any(s.target.is_descendant(child_state) for s in states_to_enter):
527534
info_to_add = StateTransition(
@@ -536,7 +543,6 @@ def add_descendant_states_to_enter(
536543
default_history_content,
537544
)
538545
elif state.is_compound:
539-
# Handle compound states
540546
states_for_default_entry.add(info)
541547
initial_state = next(s for s in state.states if s.initial)
542548
transition = next(

statemachine/engines/sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def processing_loop(self): # noqa: C901
113113
sleep(0.001)
114114
continue
115115

116-
logger.debug("External event: %s", external_event)
116+
logger.debug("External event: %s", external_event.event)
117117
# # TODO: Handle cancel event
118118
# if self.is_cancel_event(external_event):
119119
# self.running = False

statemachine/event.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import TYPE_CHECKING
22
from typing import List
3+
from typing import cast
34
from uuid import uuid4
45

56
from .callbacks import CallbackGroup
@@ -118,13 +119,14 @@ def __get__(self, instance, owner):
118119
return self
119120
return BoundEvent(id=self.id, name=self.name, delay=self.delay, _sm=instance)
120121

121-
def put(self, *args, machine: "StateMachine", send_id: "str | None" = None, **kwargs):
122+
def put(self, *args, send_id: "str | None" = None, **kwargs):
122123
# The `__call__` is declared here to help IDEs knowing that an `Event`
123124
# can be called as a method. But it is not meant to be called without
124125
# an SM instance. Such SM instance is provided by `__get__` method when
125126
# used as a property descriptor.
126-
trigger_data = self.build_trigger(*args, machine=machine, send_id=send_id, **kwargs)
127-
machine._put_nonblocking(trigger_data, internal=self.internal)
127+
assert self._sm is not None
128+
trigger_data = self.build_trigger(*args, machine=self._sm, send_id=send_id, **kwargs)
129+
self._sm._put_nonblocking(trigger_data, internal=self.internal)
128130
return trigger_data
129131

130132
def build_trigger(
@@ -154,9 +156,8 @@ def __call__(self, *args, **kwargs):
154156
# can be called as a method. But it is not meant to be called without
155157
# an SM instance. Such SM instance is provided by `__get__` method when
156158
# used as a property descriptor.
157-
machine = self._sm
158-
self.put(*args, machine=machine, **kwargs)
159-
return machine._processing_loop()
159+
self.put(*args, **kwargs)
160+
return self._sm._processing_loop()
160161

161162
def split( # type: ignore[override]
162163
self, sep: "str | None" = None, maxsplit: int = -1
@@ -167,7 +168,31 @@ def split( # type: ignore[override]
167168
return [Event(event) for event in result]
168169

169170
def match(self, event: str) -> bool:
170-
return self == event or self == "*"
171+
if self == "*":
172+
return True
173+
174+
# Normalize descriptor by removing trailing '.*' or '.'
175+
# to handle cases like 'error', 'error.', 'error.*'
176+
descriptor = cast(str, self)
177+
if descriptor.endswith(".*"):
178+
descriptor = descriptor[:-2]
179+
elif descriptor.endswith("."):
180+
descriptor = descriptor[:-1]
181+
182+
# Check prefix match:
183+
# The descriptor must be a prefix of the event.
184+
# Split both descriptor and event into tokens
185+
descriptor_tokens = descriptor.split(".") if descriptor else []
186+
event_tokens = event.split(".") if event else []
187+
188+
if len(descriptor_tokens) > len(event_tokens):
189+
return False
190+
191+
for d_token, e_token in zip(descriptor_tokens, event_tokens): # noqa: B905
192+
if d_token != e_token:
193+
return False
194+
195+
return True
171196

172197

173198
class BoundEvent(Event):

statemachine/factory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from typing import Tuple
77

88
from . import registry
9+
from .callbacks import CallbackGroup
10+
from .callbacks import CallbackPriority
11+
from .callbacks import CallbackSpecList
912
from .event import Event
1013
from .exceptions import InvalidDefinition
1114
from .graph import iterate_states
@@ -43,7 +46,10 @@ def __init__(
4346
cls._events: Dict[Event, None] = {} # used Dict to preserve order and avoid duplicates
4447
cls._protected_attrs: set = set()
4548
cls._events_to_update: Dict[Event, Event | None] = {}
46-
49+
cls._specs = CallbackSpecList()
50+
cls.prepare = cls._specs.grouper(CallbackGroup.PREPARE).add(
51+
"prepare_event", priority=CallbackPriority.GENERIC, is_convention=True
52+
)
4753
cls.add_inherited(bases)
4854
cls.add_from_attributes(attrs)
4955
cls._unpack_builders_callbacks()

0 commit comments

Comments
 (0)