11from typing import Any , Dict , List , Optional , Union
22from serverlessworkflow .sdk .action import Action
3- from serverlessworkflow .sdk .callback_state import CallbackState
43from serverlessworkflow .sdk .function_ref import FunctionRef
5- from serverlessworkflow .sdk .sleep_state import SleepState
4+ from serverlessworkflow .sdk .state_machine_extensions import (
5+ CustomGraphMachine ,
6+ CustomHierarchicalGraphMachine ,
7+ CustomHierarchicalMachine ,
8+ )
69from serverlessworkflow .sdk .transition import Transition
710from serverlessworkflow .sdk .workflow import (
811 State ,
12+ EventState ,
13+ SleepState ,
14+ CallbackState ,
915 DataBasedSwitchState ,
16+ InjectState ,
1017 EventBasedSwitchState ,
1118 ParallelState ,
1219 OperationState ,
@@ -27,7 +34,7 @@ class StateMachineGenerator:
2734 def __init__ (
2835 self ,
2936 state : State ,
30- state_machine : Union [HierarchicalMachine , GraphMachine ],
37+ state_machine : Union [CustomHierarchicalMachine , CustomGraphMachine ],
3138 subflows : List [Workflow ] = [],
3239 is_first_state = False ,
3340 get_actions = False ,
@@ -38,13 +45,20 @@ def __init__(
3845 self .get_actions = get_actions
3946 self .subflows = subflows
4047
41- if self .get_actions and not isinstance (self .state_machine , HierarchicalMachine ):
48+ if (
49+ self .get_actions
50+ and not isinstance (self .state_machine , CustomHierarchicalMachine )
51+ and not isinstance (self .state_machine , CustomHierarchicalGraphMachine )
52+ ):
4253 raise AttributeError (
43- "The provided state machine must be of the HierarchicalMachine type ."
54+ "The provided state machine must be of the CustomHierarchicalMachine or CustomHierarchicalGraphMachine types ."
4455 )
45- if not self .get_actions and isinstance (self .state_machine , HierarchicalMachine ):
56+ if not self .get_actions and (
57+ isinstance (self .state_machine , CustomHierarchicalMachine )
58+ or isinstance (self .state_machine , CustomHierarchicalGraphMachine )
59+ ):
4660 raise AttributeError (
47- "The provided state machine can not be of the HierarchicalMachine type ."
61+ "The provided state machine can not be of the CustomHierarchicalMachine or CustomHierarchicalGraphMachine types ."
4862 )
4963
5064 def generate (self ):
@@ -65,12 +79,7 @@ def transitions(self):
6579
6680 def start_transition (self ):
6781 if self .is_first_state :
68- state_name = self .state .name
69- if state_name not in self .state_machine .states .keys ():
70- self .state_machine .add_states (state_name )
71- self .state_machine ._initial = state_name
72- else :
73- self .state_machine ._initial = state_name
82+ self .state_machine ._initial = self .state .name
7483
7584 def data_conditions_transitions (self ):
7685 if isinstance (self .state , DataBasedSwitchState ):
@@ -153,7 +162,7 @@ def definitions(self):
153162 if state_type == "sleep" :
154163 self .sleep_state_details ()
155164 elif state_type == "event" :
156- pass
165+ self . event_state_details ()
157166 elif state_type == "operation" :
158167 self .operation_state_details ()
159168 elif state_type == "parallel" :
@@ -166,7 +175,7 @@ def definitions(self):
166175 else :
167176 raise Exception (f"Unexpected switch type;\n state value= { self .state } " )
168177 elif state_type == "inject" :
169- pass
178+ self . inject_state_details ()
170179 elif state_type == "foreach" :
171180 self .foreach_state_details ()
172181 elif state_type == "callback" :
@@ -178,10 +187,10 @@ def definitions(self):
178187
179188 def parallel_state_details (self ):
180189 if isinstance (self .state , ParallelState ):
181- if self .state .name not in self . state_machine . states . keys ():
182- self .state_machine .add_states ( self . state . name )
183- if self .is_first_state :
184- self .state_machine ._initial = self . state . name
190+ state_name = self .state .name
191+ if state_name not in self .state_machine .states . keys ():
192+ self .state_machine . add_states ( state_name )
193+ self .state_machine .get_state ( state_name ). tags = [ "parallel_state" ]
185194
186195 state_name = self .state .name
187196 branches = self .state .branches
@@ -192,42 +201,82 @@ def parallel_state_details(self):
192201 if hasattr (branch , "actions" ) and branch .actions :
193202 branch_name = branch .name
194203 self .state_machine .get_state (state_name ).add_substates (
195- NestedState (branch_name )
204+ branch_state := self .state_machine .state_cls (
205+ branch_name
206+ )
196207 )
197208 self .state_machine .get_state (state_name ).initial .append (
198209 branch_name
199210 )
200- branch_state = self .state_machine .get_state (
201- state_name
202- ).states [branch .name ]
211+ branch_state .tags = ["branch" ]
203212 self .generate_actions_info (
204213 machine_state = branch_state ,
205214 state_name = f"{ state_name } .{ branch_name } " ,
206215 actions = branch .actions ,
207216 )
208217
209- def event_based_switch_state_details (self ): ...
218+ def event_based_switch_state_details (self ):
219+ if isinstance (self .state , EventBasedSwitchState ):
220+ state_name = self .state .name
221+ if state_name not in self .state_machine .states .keys ():
222+ self .state_machine .add_states (state_name )
223+ self .state_machine .get_state (state_name ).tags = [
224+ "event_based_switch_state" ,
225+ "switch_state" ,
226+ ]
210227
211- def data_based_switch_state_details (self ): ...
228+ def data_based_switch_state_details (self ):
229+ if isinstance (self .state , DataBasedSwitchState ):
230+ state_name = self .state .name
231+ if state_name not in self .state_machine .states .keys ():
232+ self .state_machine .add_states (state_name )
233+ self .state_machine .get_state (state_name ).tags = [
234+ "data_based_switch_state" ,
235+ "switch_state" ,
236+ ]
212237
213- def operation_state_details (self ):
214- if self .state .name not in self .state_machine .states .keys ():
215- self .state_machine .add_states (self .state .name )
216- if self .is_first_state :
217- self .state_machine ._initial = self .state .name
238+ def inject_state_details (self ):
239+ if isinstance (self .state , InjectState ):
240+ state_name = self .state .name
241+ if state_name not in self .state_machine .states .keys ():
242+ self .state_machine .add_states (state_name )
243+ self .state_machine .get_state (state_name ).tags = ["inject_state" ]
218244
245+ def operation_state_details (self ):
219246 if isinstance (self .state , OperationState ):
247+ state_name = self .state .name
248+ if state_name not in self .state_machine .states .keys ():
249+ self .state_machine .add_states (state_name )
250+ (machine_state := self .state_machine .get_state (state_name )).tags = [
251+ "operation_state"
252+ ]
220253 self .generate_actions_info (
221- machine_state = self . state_machine . get_state ( self . state . name ) ,
254+ machine_state = machine_state ,
222255 state_name = self .state .name ,
223256 actions = self .state .actions ,
224257 action_mode = self .state .actionMode ,
225258 )
226259
227- def sleep_state_details (self ): ...
260+ def sleep_state_details (self ):
261+ if isinstance (self .state , SleepState ):
262+ state_name = self .state .name
263+ if state_name not in self .state_machine .states .keys ():
264+ self .state_machine .add_states (state_name )
265+ self .state_machine .get_state (state_name ).tags = ["sleep_state" ]
266+
267+ def event_state_details (self ):
268+ if isinstance (self .state , EventState ):
269+ state_name = self .state .name
270+ if state_name not in self .state_machine .states .keys ():
271+ self .state_machine .add_states (state_name )
272+ self .state_machine .get_state (state_name ).tags = ["event_state" ]
228273
229274 def foreach_state_details (self ):
230275 if isinstance (self .state , ForEachState ):
276+ state_name = self .state .name
277+ if state_name not in self .state_machine .states .keys ():
278+ self .state_machine .add_states (state_name )
279+ self .state_machine .get_state (state_name ).tags = ["foreach_state" ]
231280 self .generate_actions_info (
232281 machine_state = self .state_machine .get_state (self .state .name ),
233282 state_name = self .state .name ,
@@ -237,6 +286,10 @@ def foreach_state_details(self):
237286
238287 def callback_state_details (self ):
239288 if isinstance (self .state , CallbackState ):
289+ state_name = self .state .name
290+ if state_name not in self .state_machine .states .keys ():
291+ self .state_machine .add_states (state_name )
292+ self .state_machine .get_state (state_name ).tags = ["callback_state" ]
240293 action = self .state .action
241294 if action and action .functionRef :
242295 self .generate_actions_info (
@@ -264,7 +317,7 @@ def get_subflow_state(
264317 or not workflow_version
265318 ):
266319 none_found = False
267- new_machine = HierarchicalMachine (
320+ new_machine = CustomHierarchicalMachine (
268321 model = None , initial = None , auto_transitions = False
269322 )
270323
@@ -282,7 +335,8 @@ def get_subflow_state(
282335 added_states [i ] = self .subflow_state_name (
283336 action = action , subflow = sf
284337 )
285- nested_state = NestedState (added_states [i ])
338+ nested_state = self .state_machine .state_cls (added_states [i ])
339+ nested_state .tags = ["subflow" ]
286340 machine_state .add_substate (nested_state )
287341 self .state_machine_to_nested_state (
288342 state_name = state_name ,
@@ -301,7 +355,7 @@ def generate_actions_info(
301355 self ,
302356 machine_state : NestedState ,
303357 state_name : str ,
304- actions : List [Dict [str , Any ]],
358+ actions : List [Dict [str , Action ]],
305359 action_mode : str = "sequential" ,
306360 ):
307361 parallel_states = []
@@ -322,9 +376,19 @@ def generate_actions_info(
322376 )
323377 )
324378 if name not in machine_state .states .keys ():
325- machine_state .add_substate (NestedState (name ))
379+ machine_state .add_substate (
380+ ns := self .state_machine .state_cls (name )
381+ )
382+ ns .tags = ["function" ]
326383 elif action .subFlowRef :
327384 name = new_subflows_names .get (i )
385+ elif action .eventRef :
386+ name = f"{ action .eventRef .triggerEventRef } /{ action .eventRef .resultEventRef } "
387+ if name not in machine_state .states .keys ():
388+ machine_state .add_substate (
389+ ns := self .state_machine .state_cls (name )
390+ )
391+ ns .tags = ["event" ]
328392 if name :
329393 if action_mode == "sequential" :
330394 if i < len (actions ) - 1 :
@@ -348,9 +412,24 @@ def generate_actions_info(
348412 state_name
349413 ).states .keys ()
350414 ):
351- machine_state .add_substate (NestedState (next_name ))
415+ machine_state .add_substate (
416+ ns := self .state_machine .state_cls (next_name )
417+ )
418+ ns .tags = ["function" ]
352419 elif actions [i + 1 ].subFlowRef :
353420 next_name = new_subflows_names .get (i + 1 )
421+ elif actions [i + 1 ].eventRef :
422+ next_name = f"{ action .eventRef .triggerEventRef } /{ action .eventRef .resultEventRef } "
423+ if (
424+ next_name
425+ not in self .state_machine .get_state (
426+ state_name
427+ ).states .keys ()
428+ ):
429+ machine_state .add_substate (
430+ ns := self .state_machine .state_cls (name )
431+ )
432+ ns .tags = ["event" ]
354433 self .state_machine .add_transition (
355434 trigger = "" ,
356435 source = f"{ state_name } .{ name } " ,
@@ -371,21 +450,22 @@ def subflow_state_name(self, action: Action, subflow: Workflow):
371450 )
372451
373452 def add_all_sub_states (
374- cls ,
375- original_state : Union [NestedState , HierarchicalMachine ],
453+ self ,
454+ original_state : Union [NestedState , CustomHierarchicalMachine ],
376455 new_state : NestedState ,
377456 ):
378457 if len (original_state .states ) == 0 :
379458 return
380459 for substate in original_state .states .values ():
381- new_state .add_substate (ns := NestedState (substate .name ))
382- cls .add_all_sub_states (substate , ns )
460+ new_state .add_substate (ns := self .state_machine .state_cls (substate .name ))
461+ ns .tags = substate .tags
462+ self .add_all_sub_states (substate , ns )
383463 new_state .initial = original_state .initial
384464
385465 def state_machine_to_nested_state (
386466 self ,
387467 state_name : str ,
388- state_machine : HierarchicalMachine ,
468+ state_machine : CustomHierarchicalMachine ,
389469 nested_state : NestedState ,
390470 ) -> NestedState :
391471 self .add_all_sub_states (state_machine , nested_state )
0 commit comments