11""" Finite State Transducer """
22
3- from typing import Dict , List , Set , Tuple , Iterator , Iterable , Hashable
4- from copy import deepcopy
3+ from typing import Dict , List , Set , AbstractSet , \
4+ Tuple , Iterator , Iterable , Hashable
55
66from networkx import MultiDiGraph
77from networkx .drawing .nx_pydot import write_dot
88
9+ from .transition_function import TransitionFunction
10+ from .transition_function import TransitionKey , TransitionValues , Transition
911from .utils import StateRenaming
1012from ..objects .finite_automaton_objects import State , Symbol , Epsilon
1113from ..objects .finite_automaton_objects .utils import to_state , to_symbol
1214
13- TransitionKey = Tuple [State , Symbol ]
14- TransitionValue = Tuple [State , Tuple [Symbol , ...]]
15- TransitionValues = Set [TransitionValue ]
16- TransitionFunction = Dict [TransitionKey , TransitionValues ]
17-
1815InputTransition = Tuple [Hashable , Hashable , Hashable , Iterable [Hashable ]]
19- Transition = Tuple [TransitionKey , TransitionValue ]
2016
2117
2218class FST (Iterable [Transition ]):
2319 """ Representation of a Finite State Transducer"""
2420
25- def __init__ (self ) -> None :
26- self ._states : Set [State ] = set () # Set of states
27- self ._input_symbols : Set [Symbol ] = set () # Set of input symbols
28- self ._output_symbols : Set [Symbol ] = set () # Set of output symbols
29- # Dict from _states x _input_symbols U {epsilon} into a subset of
30- # _states X _output_symbols*
31- self ._delta : TransitionFunction = {}
32- self ._start_states : Set [State ] = set ()
33- self ._final_states : Set [State ] = set () # _final_states is final states
21+ def __init__ (self ,
22+ states : AbstractSet [Hashable ] = None ,
23+ input_symbols : AbstractSet [Hashable ] = None ,
24+ output_symbols : AbstractSet [Hashable ] = None ,
25+ transition_function : TransitionFunction = None ,
26+ start_states : AbstractSet [Hashable ] = None ,
27+ final_states : AbstractSet [Hashable ] = None ) -> None :
28+ self ._states = {to_state (x ) for x in states or set ()}
29+ self ._input_symbols = {to_symbol (x ) for x in input_symbols or set ()}
30+ self ._output_symbols = {to_symbol (x ) for x in output_symbols or set ()}
31+ self ._transition_function = transition_function or TransitionFunction ()
32+ self ._start_states = {to_state (x ) for x in start_states or set ()}
33+ self ._states .update (self ._start_states )
34+ self ._final_states = {to_state (x ) for x in final_states or set ()}
35+ self ._states .update (self ._final_states )
3436
3537 @property
3638 def states (self ) -> Set [State ]:
@@ -87,16 +89,6 @@ def final_states(self) -> Set[State]:
8789 """
8890 return self ._final_states
8991
90- def get_number_transitions (self ) -> int :
91- """ Get the number of transitions in the FST
92-
93- Returns
94- ----------
95- n_transitions : int
96- The number of transitions
97- """
98- return sum (len (x ) for x in self ._delta .values ())
99-
10092 def add_transition (self ,
10193 s_from : Hashable ,
10294 input_symbol : Hashable ,
@@ -125,11 +117,10 @@ def add_transition(self,
125117 if input_symbol != Epsilon ():
126118 self ._input_symbols .add (input_symbol )
127119 self ._output_symbols .update (output_symbols )
128- head = (s_from , input_symbol )
129- if head in self ._delta :
130- self ._delta [head ].add ((s_to , output_symbols ))
131- else :
132- self ._delta [head ] = {(s_to , output_symbols )}
120+ self ._transition_function .add_transition (s_from ,
121+ input_symbol ,
122+ s_to ,
123+ output_symbols )
133124
134125 def add_transitions (self , transitions : Iterable [InputTransition ]) -> None :
135126 """
@@ -156,8 +147,20 @@ def remove_transition(self,
156147 input_symbol = to_symbol (input_symbol )
157148 s_to = to_state (s_to )
158149 output_symbols = tuple (to_symbol (x ) for x in output_symbols )
159- head = (s_from , input_symbol )
160- self ._delta .get (head , set ()).discard ((s_to , output_symbols ))
150+ self ._transition_function .remove_transition (s_from ,
151+ input_symbol ,
152+ s_to ,
153+ output_symbols )
154+
155+ def get_number_transitions (self ) -> int :
156+ """ Get the number of transitions in the FST
157+
158+ Returns
159+ ----------
160+ n_transitions : int
161+ The number of transitions
162+ """
163+ return self ._transition_function .get_number_transitions ()
161164
162165 def add_start_state (self , start_state : Hashable ) -> None :
163166 """ Add a start state
@@ -188,7 +191,7 @@ def __call__(self, s_from: Hashable, input_symbol: Hashable) \
188191 """ Calls the transition function of the FST """
189192 s_from = to_state (s_from )
190193 input_symbol = to_symbol (input_symbol )
191- return self ._delta . get (( s_from , input_symbol ), set () )
194+ return self ._transition_function ( s_from , input_symbol )
192195
193196 def __contains__ (self , transition : InputTransition ) -> bool :
194197 """ Whether the given transition is present in the FST """
@@ -201,9 +204,7 @@ def __contains__(self, transition: InputTransition) -> bool:
201204
202205 def __iter__ (self ) -> Iterator [Transition ]:
203206 """ Gets an iterator of transitions of the FST """
204- for key , values in self ._delta .items ():
205- for value in values :
206- yield key , value
207+ yield from self ._transition_function
207208
208209 def translate (self ,
209210 input_word : Iterable [Hashable ],
@@ -300,14 +301,12 @@ def _add_transitions_to(self,
300301 union_fst : "FST" ,
301302 state_renaming : StateRenaming ,
302303 idx : int ) -> None :
303- for head , transition in self ._delta .items ():
304- s_from , input_symbol = head
305- for s_to , output_symbols in transition :
306- union_fst .add_transition (
307- state_renaming .get_renamed_state (s_from , idx ),
308- input_symbol ,
309- state_renaming .get_renamed_state (s_to , idx ),
310- output_symbols )
304+ for (s_from , input_symbol ), (s_to , output_symbols ) in self :
305+ union_fst .add_transition (
306+ state_renaming .get_renamed_state (s_from , idx ),
307+ input_symbol ,
308+ state_renaming .get_renamed_state (s_to , idx ),
309+ output_symbols )
311310
312311 def _add_extremity_states_to (self ,
313312 union_fst : "FST" ,
@@ -502,6 +501,18 @@ def write_as_dot(self, filename: str) -> None:
502501 """
503502 write_dot (self .to_networkx (), filename )
504503
505- def to_dict (self ) -> TransitionFunction :
504+ def copy (self ) -> "FST" :
505+ """ Copies the FST """
506+ return FST (states = self .states ,
507+ input_symbols = self .input_symbols ,
508+ output_symbols = self .output_symbols ,
509+ transition_function = self ._transition_function .copy (),
510+ start_states = self .start_states ,
511+ final_states = self .final_states )
512+
513+ def __copy__ (self ) -> "FST" :
514+ return self .copy ()
515+
516+ def to_dict (self ) -> Dict [TransitionKey , TransitionValues ]:
506517 """Gives the transitions as a dictionary"""
507- return deepcopy ( self ._delta )
518+ return self ._transition_function . to_dict ( )
0 commit comments