11import re
22import xml .etree .ElementTree as ET
3+ from typing import Set
34
45from .schema import Action
56from .schema import AssignAction
@@ -32,6 +33,12 @@ def strip_namespaces(tree: ET.Element):
3233 attrib [new_name ] = attrib .pop (name )
3334
3435
36+ def _parse_initial (initial_content : "str | None" ) -> Set [str ]:
37+ if initial_content is None :
38+ return set ()
39+ return set (initial_content .split ())
40+
41+
3542def parse_scxml (scxml_content : str ) -> StateMachineDefinition :
3643 root = ET .fromstring (scxml_content )
3744 strip_namespaces (root )
@@ -40,9 +47,9 @@ def parse_scxml(scxml_content: str) -> StateMachineDefinition:
4047 if scxml is None :
4148 raise ValueError ("No scxml element found in document" )
4249
43- initial_state = scxml .get ("initial" )
50+ initial_state = _parse_initial ( scxml .get ("initial" ) )
4451
45- definition = StateMachineDefinition (initial_state = initial_state )
52+ definition = StateMachineDefinition (initial_states = initial_state )
4653
4754 # Parse datamodel
4855 datamodel = parse_datamodel (scxml )
@@ -52,19 +59,19 @@ def parse_scxml(scxml_content: str) -> StateMachineDefinition:
5259 # Parse states
5360 for state_elem in scxml :
5461 if state_elem .tag == "state" :
55- state = parse_state (state_elem , definition .initial_state )
62+ state = parse_state (state_elem , definition .initial_states )
5663 definition .states [state .id ] = state
5764 elif state_elem .tag == "final" :
58- state = parse_state (state_elem , definition .initial_state , is_final = True )
65+ state = parse_state (state_elem , definition .initial_states , is_final = True )
5966 definition .states [state .id ] = state
6067 elif state_elem .tag == "parallel" :
61- state = parse_state (state_elem , definition .initial_state , is_parallel = True )
68+ state = parse_state (state_elem , definition .initial_states , is_parallel = True )
6269 definition .states [state .id ] = state
6370
6471 # If no initial state was specified, pick the first state
65- if not definition .initial_state and definition .states :
66- definition .initial_state = next (iter (definition .states .keys ()))
67- definition .states [definition .initial_state ].initial = True
72+ if not definition .initial_states and definition .states :
73+ definition .initial_states = next (iter (definition .states .keys ()))
74+ definition .states [definition .initial_states ].initial = True
6875
6976 return definition
7077
@@ -95,15 +102,15 @@ def parse_datamodel(root: ET.Element) -> "DataModel | None":
95102
96103def parse_state (
97104 state_elem : ET .Element ,
98- initial_state : " str | None" ,
105+ initial_states : Set [ str ] ,
99106 is_final : bool = False ,
100107 is_parallel : bool = False ,
101108) -> State :
102109 state_id = state_elem .get ("id" )
103110 if not state_id :
104111 raise ValueError ("State must have an 'id' attribute" )
105112
106- initial = state_id == initial_state
113+ initial = state_id in initial_states
107114 state = State (id = state_id , initial = initial , final = is_final , parallel = is_parallel )
108115
109116 # Parse onentry actions
@@ -122,18 +129,25 @@ def parse_state(
122129 state .transitions .append (transition )
123130
124131 # Parse child states
125- initial_state = state_elem .get ("initial" )
132+ initial_states |= _parse_initial (state_elem .get ("initial" ))
133+ initial_elem = state_elem .find ("initial" )
134+ if initial_elem is not None :
135+ for trans_elem in initial_elem .findall ("transition" ):
136+ transition = parse_transition (trans_elem , initial = True )
137+ state .transitions .append (transition )
138+ initial_states |= _parse_initial (trans_elem .get ("target" ))
139+
126140 for child_state_elem in state_elem .findall ("state" ):
127- child_state = parse_state (child_state_elem , initial_state = initial_state )
141+ child_state = parse_state (child_state_elem , initial_states = initial_states )
128142 state .states [child_state .id ] = child_state
129143 for child_state_elem in state_elem .findall ("parallel" ):
130- state = parse_state (child_state_elem , initial_state = initial_state , is_parallel = True )
144+ state = parse_state (child_state_elem , initial_states = initial_states , is_parallel = True )
131145 state .states [child_state .id ] = child_state
132146
133147 return state
134148
135149
136- def parse_transition (trans_elem : ET .Element ) -> Transition :
150+ def parse_transition (trans_elem : ET .Element , initial : bool = False ) -> Transition :
137151 target = trans_elem .get ("target" )
138152 if not target :
139153 raise ValueError ("Transition must have a 'target' attribute" )
@@ -143,7 +157,9 @@ def parse_transition(trans_elem: ET.Element) -> Transition:
143157
144158 executable_content = parse_executable_content (trans_elem )
145159
146- return Transition (target = target , event = event , cond = cond , on = executable_content )
160+ return Transition (
161+ target = target , initial = initial , event = event , cond = cond , on = executable_content
162+ )
147163
148164
149165def parse_executable_content (element : ET .Element ) -> ExecutableContent :
0 commit comments