Skip to content

Commit 43c711d

Browse files
committed
Merge branch 'hotfix/fix-concurrency'
2 parents abbd83c + 6e3c69b commit 43c711d

File tree

5 files changed

+134
-16
lines changed

5 files changed

+134
-16
lines changed

statemachine/state.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# coding: utf-8
22
import warnings
3+
from copy import deepcopy
34
from typing import Any
45
from typing import Optional
56
from typing import Text
@@ -84,15 +85,26 @@ def __init__(
8485
self.name = name
8586
self.value = value
8687
self._id = None # type: Optional[Text]
88+
self._storage = ""
8789
self._initial = initial
8890
self.transitions = TransitionList()
8991
self._final = final
9092
self.enter = Callbacks().add(enter)
9193
self.exit = Callbacks().add(exit)
9294

93-
def _setup(self, resolver):
95+
def __eq__(self, other):
96+
return (
97+
isinstance(other, State) and self.name == other.name and self.id == other.id
98+
)
99+
100+
def __hash__(self):
101+
return hash(repr(self))
102+
103+
def _setup(self, machine, resolver):
104+
self.machine = machine
94105
self.enter.setup(resolver)
95106
self.exit.setup(resolver)
107+
machine.__dict__[self._storage] = self
96108

97109
def _add_observer(self, *resolvers):
98110
for r in resolvers:
@@ -120,7 +132,8 @@ def __repr__(self):
120132
)
121133

122134
def __get__(self, machine, owner):
123-
self.machine = machine
135+
if machine and self._storage in machine.__dict__:
136+
return machine.__dict__[self._storage]
124137
return self
125138

126139
def __set__(self, instance, value):
@@ -130,6 +143,9 @@ def __set__(self, instance, value):
130143
)
131144
)
132145

146+
def clone(self):
147+
return deepcopy(self)
148+
133149
@property
134150
def id(self):
135151
return self._id
@@ -144,6 +160,7 @@ def identifier(self):
144160

145161
def _set_id(self, id):
146162
self._id = id
163+
self._storage = "_{}".format(id)
147164
if self.value is None:
148165
self.value = id
149166

statemachine/statemachine.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# coding: utf-8
22
import sys
33
import warnings
4+
from collections import OrderedDict
45
from typing import Any
56
from typing import Dict
67
from typing import List
@@ -30,7 +31,9 @@ def __init__(self, model=None, state_field="state", start_value=None):
3031
self.state_field = state_field
3132
self.start_value = start_value
3233

33-
initial_transition = Transition(None, None, event="__initial__")
34+
initial_transition = Transition(
35+
None, self._get_initial_state(), event="__initial__"
36+
)
3437
self._setup(initial_transition)
3538
self._activate_initial_state(initial_transition)
3639

@@ -42,21 +45,19 @@ def __repr__(self):
4245
self.current_state.id if self.current_state else None,
4346
)
4447

45-
def _activate_initial_state(self, initial_transition):
46-
48+
def _get_initial_state(self):
4749
current_state_value = (
4850
self.start_value if self.start_value else self.initial_state.value
4951
)
50-
if self.current_state_value is None:
51-
52-
try:
53-
initial_state = self.states_map[current_state_value]
54-
except KeyError:
55-
raise InvalidStateValue(current_state_value)
52+
try:
53+
return self.states_map[current_state_value]
54+
except KeyError:
55+
raise InvalidStateValue(current_state_value)
5656

57+
def _activate_initial_state(self, initial_transition):
58+
if self.current_state_value is None:
5759
# send an one-time event `__initial__` to enter the current state.
5860
# current_state = self.current_state
59-
initial_transition.target = initial_state
6061
initial_transition.before.clear()
6162
initial_transition.on.clear()
6263
initial_transition.after.clear()
@@ -96,8 +97,22 @@ def _setup(self, initial_transition):
9697
model = ObjectConfig(self.model, skip_attrs={self.state_field})
9798
default_resolver = resolver_factory(machine, model)
9899

99-
initial_transition._setup(default_resolver)
100-
self._visit_states_and_transitions(lambda x: x._setup(default_resolver))
100+
# clone states and transitions to avoid sharing callbacks references between instances
101+
states = []
102+
self.states_map = OrderedDict()
103+
for state in self.states:
104+
new_state = state.clone()
105+
new_state._setup(self, default_resolver)
106+
states.append(new_state)
107+
self.states_map[new_state.value] = new_state
108+
109+
self.states = states
110+
111+
for state in self.states:
112+
for transition in state.transitions:
113+
transition._setup(self, default_resolver)
114+
115+
initial_transition._setup(self, default_resolver)
101116
self.add_observer(machine, model)
102117

103118
def add_observer(self, *observers):

statemachine/transition.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ def __repr__(self):
6161
type(self).__name__, self.source, self.target, self.event
6262
)
6363

64-
def _setup(self, resolver):
64+
def _upd_state_refs(self, machine):
65+
if self.source:
66+
self.source = machine.__dict__[self.source._storage]
67+
self.target = machine.__dict__[self.target._storage]
68+
69+
def _setup(self, machine, resolver):
70+
self._upd_state_refs(machine)
6571
self.validators.setup(resolver)
6672
self.cond.setup(resolver)
6773
self.before.setup(resolver)

tests/examples/order_control_rich_model_machine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class OrderControl(StateMachine):
6161
try:
6262
control = OrderControl()
6363
except AttrNotFound as e:
64-
assert str(e) == "Did not found name 'payment_received' from model or statemachine"
64+
assert str(e) == "Did not found name 'wait_for_payment' from model or statemachine"
6565

6666
# %%
6767
# Now initializing with a proper ``order`` instance.

tests/test_callbacks_isolation.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pytest
2+
3+
from statemachine import State
4+
from statemachine import StateMachine
5+
6+
7+
@pytest.fixture()
8+
def simple_sm_cls():
9+
class TestStateMachine(StateMachine):
10+
# States
11+
initial = State("Initial", initial=True)
12+
final = State("Final", final=True, enter="do_enter_final")
13+
14+
finish = initial.to(final, cond="can_finish", on="do_finish")
15+
16+
def __init__(self, name):
17+
self.name = name
18+
self.can_finish = False
19+
self.finalized = False
20+
super(TestStateMachine, self).__init__()
21+
22+
def do_finish(self):
23+
return self.name, self.can_finish
24+
25+
def do_enter_final(self):
26+
self.finalized = True
27+
28+
return TestStateMachine
29+
30+
31+
class TestCallbacksIsolation:
32+
def test_should_conditions_be_isolated(self, simple_sm_cls):
33+
sm1 = simple_sm_cls("sm1")
34+
sm2 = simple_sm_cls("sm2")
35+
sm3 = simple_sm_cls("sm3")
36+
37+
sm1.can_finish = True
38+
assert sm1.initial.transitions[0].cond.call() == [True]
39+
assert sm2.initial.transitions[0].cond.call() == [False]
40+
assert sm3.initial.transitions[0].cond.call() == [False]
41+
42+
def test_should_actions_be_isolated(self, simple_sm_cls):
43+
sm1 = simple_sm_cls("sm1")
44+
sm2 = simple_sm_cls("sm2")
45+
46+
sm1.can_finish = True
47+
sm2.can_finish = True
48+
49+
sm1_initial = sm1.initial
50+
sm1_final = sm1.final
51+
52+
assert sm2.finish() == ("sm2", True)
53+
54+
assert not sm2.initial.is_active
55+
assert sm2.final.is_active
56+
assert sm2.finalized is True
57+
58+
assert sm1_initial.is_active
59+
assert not sm1_final.is_active
60+
assert sm1.finalized is False
61+
62+
assert sm1.initial.is_active
63+
assert not sm1.final.is_active
64+
65+
assert sm1.finish() == ("sm1", True)
66+
67+
assert sm1.finalized is True
68+
assert not sm1.initial.is_active
69+
assert sm1.final.is_active
70+
71+
def test_instance_states_and_transitions_are_isolated(self, simple_sm_cls):
72+
sm1 = simple_sm_cls("sm1")
73+
74+
assert sm1.initial == simple_sm_cls.initial
75+
assert sm1.initial is not simple_sm_cls.initial
76+
77+
assert repr(sm1.initial.transitions[0]) == repr(
78+
simple_sm_cls.initial.transitions[0]
79+
)
80+
assert sm1.initial.transitions[0] is not simple_sm_cls.initial.transitions[0]

0 commit comments

Comments
 (0)