Skip to content

Commit 1d6649a

Browse files
committed
add proper fst initialization, add copying of the fst
1 parent b0ee784 commit 1d6649a

File tree

6 files changed

+223
-101
lines changed

6 files changed

+223
-101
lines changed

pyformlang/fst/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
1313
"""
1414

15-
from .fst import FST, State, Symbol, Epsilon
15+
from .fst import FST, TransitionFunction, State, Symbol, Epsilon
1616

1717

1818
__all__ = ["FST",
19+
"TransitionFunction",
1920
"State",
2021
"Symbol",
2122
"Epsilon"]

pyformlang/fst/fst.py

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,38 @@
11
""" Finite State Transducer """
22

3-
from typing import Dict, List, Set, Tuple, Iterator, Iterable, Hashable
4-
from copy import deepcopy
3+
from typing import Dict, List, Set, AbstractSet, \
4+
Tuple, Iterator, Iterable, Hashable
55

66
from networkx import MultiDiGraph
77
from networkx.drawing.nx_pydot import write_dot
88

9+
from .transition_function import TransitionFunction
10+
from .transition_function import TransitionKey, TransitionValues, Transition
911
from .utils import StateRenaming
1012
from ..objects.finite_automaton_objects import State, Symbol, Epsilon
1113
from ..objects.finite_automaton_objects.utils import to_state, to_symbol
1214

13-
TransitionKey = Tuple[State, Symbol]
14-
TransitionValue = Tuple[State, Tuple[Symbol, ...]]
15-
TransitionValues = Set[TransitionValue]
16-
TransitionFunction = Dict[TransitionKey, TransitionValues]
17-
1815
InputTransition = Tuple[Hashable, Hashable, Hashable, Iterable[Hashable]]
19-
Transition = Tuple[TransitionKey, TransitionValue]
2016

2117

2218
class FST(Iterable[Transition]):
2319
""" Representation of a Finite State Transducer"""
2420

25-
def __init__(self) -> None:
26-
self._states: Set[State] = set() # Set of states
27-
self._input_symbols: Set[Symbol] = set() # Set of input symbols
28-
self._output_symbols: Set[Symbol] = set() # Set of output symbols
29-
# Dict from _states x _input_symbols U {epsilon} into a subset of
30-
# _states X _output_symbols*
31-
self._delta: TransitionFunction = {}
32-
self._start_states: Set[State] = set()
33-
self._final_states: Set[State] = set() # _final_states is final states
21+
def __init__(self,
             states: AbstractSet[Hashable] = None,
             input_symbols: AbstractSet[Hashable] = None,
             output_symbols: AbstractSet[Hashable] = None,
             transition_function: TransitionFunction = None,
             start_states: AbstractSet[Hashable] = None,
             final_states: AbstractSet[Hashable] = None) -> None:
    """ Initializes the FST from its (optional) components

    Every argument may be omitted, in which case the corresponding
    component starts empty.

    Parameters
    ----------
    states : set of any, optional
        The states, each one is passed through to_state
    input_symbols : set of any, optional
        The input symbols, each one is passed through to_symbol
    output_symbols : set of any, optional
        The output symbols, each one is passed through to_symbol
    transition_function : TransitionFunction, optional
        The transition function; it is stored as given (not copied)
    start_states : set of any, optional
        The start states, they are also added to the states
    final_states : set of any, optional
        The final states, they are also added to the states
    """
    self._states = {to_state(x) for x in states or set()}
    self._input_symbols = {to_symbol(x) for x in input_symbols or set()}
    self._output_symbols = {to_symbol(x) for x in output_symbols or set()}
    # Test against None explicitly: relying on truthiness would silently
    # replace a caller-supplied function that happens to evaluate as false
    if transition_function is None:
        transition_function = TransitionFunction()
    self._transition_function = transition_function
    self._start_states = {to_state(x) for x in start_states or set()}
    self._states.update(self._start_states)
    self._final_states = {to_state(x) for x in final_states or set()}
    self._states.update(self._final_states)
    # Register what the given transitions mention, so that the states and
    # alphabets agree with the transition function, as they do when the
    # transitions are inserted one by one with add_transition
    for (s_from, input_symbol), (s_to, output_symbols_) \
            in self._transition_function:
        self._states.add(s_from)
        self._states.add(s_to)
        # Epsilon is not a symbol of either alphabet
        if input_symbol != Epsilon():
            self._input_symbols.add(input_symbol)
        self._output_symbols.update(
            x for x in output_symbols_ if x != Epsilon())
3436

3537
@property
3638
def states(self) -> Set[State]:
@@ -87,16 +89,6 @@ def final_states(self) -> Set[State]:
8789
"""
8890
return self._final_states
8991

90-
def get_number_transitions(self) -> int:
91-
""" Get the number of transitions in the FST
92-
93-
Returns
94-
----------
95-
n_transitions : int
96-
The number of transitions
97-
"""
98-
return sum(len(x) for x in self._delta.values())
99-
10092
def add_transition(self,
10193
s_from: Hashable,
10294
input_symbol: Hashable,
@@ -125,11 +117,10 @@ def add_transition(self,
125117
if input_symbol != Epsilon():
126118
self._input_symbols.add(input_symbol)
127119
self._output_symbols.update(output_symbols)
128-
head = (s_from, input_symbol)
129-
if head in self._delta:
130-
self._delta[head].add((s_to, output_symbols))
131-
else:
132-
self._delta[head] = {(s_to, output_symbols)}
120+
self._transition_function.add_transition(s_from,
121+
input_symbol,
122+
s_to,
123+
output_symbols)
133124

134125
def add_transitions(self, transitions: Iterable[InputTransition]) -> None:
135126
"""
@@ -156,8 +147,20 @@ def remove_transition(self,
156147
input_symbol = to_symbol(input_symbol)
157148
s_to = to_state(s_to)
158149
output_symbols = tuple(to_symbol(x) for x in output_symbols)
159-
head = (s_from, input_symbol)
160-
self._delta.get(head, set()).discard((s_to, output_symbols))
150+
self._transition_function.remove_transition(s_from,
151+
input_symbol,
152+
s_to,
153+
output_symbols)
154+
155+
def get_number_transitions(self) -> int:
    """ Counts the transitions of the FST

    The count is delegated to the underlying transition function.

    Returns
    ----------
    n_transitions : int
        The number of transitions
    """
    transition_function = self._transition_function
    n_transitions = transition_function.get_number_transitions()
    return n_transitions
161164

162165
def add_start_state(self, start_state: Hashable) -> None:
163166
""" Add a start state
@@ -188,7 +191,7 @@ def __call__(self, s_from: Hashable, input_symbol: Hashable) \
188191
""" Calls the transition function of the FST """
189192
s_from = to_state(s_from)
190193
input_symbol = to_symbol(input_symbol)
191-
return self._delta.get((s_from, input_symbol), set())
194+
return self._transition_function(s_from, input_symbol)
192195

193196
def __contains__(self, transition: InputTransition) -> bool:
194197
""" Whether the given transition is present in the FST """
@@ -201,9 +204,7 @@ def __contains__(self, transition: InputTransition) -> bool:
201204

202205
def __iter__(self) -> Iterator[Transition]:
203206
""" Gets an iterator of transitions of the FST """
204-
for key, values in self._delta.items():
205-
for value in values:
206-
yield key, value
207+
yield from self._transition_function
207208

208209
def translate(self,
209210
input_word: Iterable[Hashable],
@@ -300,14 +301,12 @@ def _add_transitions_to(self,
300301
union_fst: "FST",
301302
state_renaming: StateRenaming,
302303
idx: int) -> None:
303-
for head, transition in self._delta.items():
304-
s_from, input_symbol = head
305-
for s_to, output_symbols in transition:
306-
union_fst.add_transition(
307-
state_renaming.get_renamed_state(s_from, idx),
308-
input_symbol,
309-
state_renaming.get_renamed_state(s_to, idx),
310-
output_symbols)
304+
for (s_from, input_symbol), (s_to, output_symbols) in self:
305+
union_fst.add_transition(
306+
state_renaming.get_renamed_state(s_from, idx),
307+
input_symbol,
308+
state_renaming.get_renamed_state(s_to, idx),
309+
output_symbols)
311310

312311
def _add_extremity_states_to(self,
313312
union_fst: "FST",
@@ -502,6 +501,18 @@ def write_as_dot(self, filename: str) -> None:
502501
"""
503502
write_dot(self.to_networkx(), filename)
504503

505-
def to_dict(self) -> TransitionFunction:
504+
def copy(self) -> "FST":
    """ Builds a new FST with the same components as this one

    The transition function is duplicated through its own copy method;
    the other components are handed to the FST constructor.

    Returns
    ----------
    fst : FST
        The copy of the FST
    """
    components = {
        "states": self.states,
        "input_symbols": self.input_symbols,
        "output_symbols": self.output_symbols,
        "transition_function": self._transition_function.copy(),
        "start_states": self.start_states,
        "final_states": self.final_states,
    }
    return FST(**components)
512+
513+
def __copy__(self) -> "FST":
    """ Shallow-copy hook of the copy module, delegates to copy() """
    duplicate = self.copy()
    return duplicate
515+
516+
def to_dict(self) -> Dict[TransitionKey, TransitionValues]:
506517
"""Gives the transitions as a dictionary"""
507-
return deepcopy(self._delta)
518+
return self._transition_function.to_dict()

pyformlang/fst/tests/test_fst.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from pyformlang.fst import FST
7+
from pyformlang.fst import FST, TransitionFunction, State, Symbol
88

99

1010
@pytest.fixture
@@ -194,9 +194,9 @@ def test_paper(self):
194194
def test_contains(self, fst0: FST):
195195
""" Tests the containment of transition in the FST """
196196
assert ("q0", "a", "q1", ["b"]) in fst0
197-
assert ("a", "b", "c", "d") not in fst0
198-
fst0.add_transition("a", "b", "c", "d")
199-
assert ("a", "b", "c", "d") in fst0
197+
assert ("a", "b", "c", ["d"]) not in fst0
198+
fst0.add_transition("a", "b", "c", {"d"})
199+
assert ("a", "b", "c", ["d"]) in fst0
200200

201201
def test_iter(self, fst0: FST):
202202
""" Tests the iteration of FST transitions """
@@ -210,9 +210,42 @@ def test_iter(self, fst0: FST):
210210

211211
def test_remove_transition(self, fst0: FST):
212212
""" Tests the removal of transition from the FST """
213-
assert ("q0", "a", "q1", "b") in fst0
214-
fst0.remove_transition("q0", "a", "q1", "b")
215-
assert ("q0", "a", "q1", "b") not in fst0
216-
fst0.remove_transition("q0", "a", "q1", "b")
217-
assert ("q0", "a", "q1", "b") not in fst0
213+
assert ("q0", "a", "q1", ["b"]) in fst0
214+
fst0.remove_transition("q0", "a", "q1", ["b"])
215+
assert ("q0", "a", "q1", ["b"]) not in fst0
216+
fst0.remove_transition("q0", "a", "q1", ["b"])
217+
assert ("q0", "a", "q1", ["b"]) not in fst0
218218
assert fst0.get_number_transitions() == 0
219+
220+
def test_initialization(self):
    """ Checks the construction of the FST from explicit components """
    # Construction from states and alphabets only: no transitions expected
    bare_fst = FST(states={0},
                   input_symbols={"a", "b"},
                   output_symbols={"c"},
                   start_states={1},
                   final_states={2})
    # Start and final states must have been merged into the states
    assert bare_fst.states == {0, 1, 2}
    assert bare_fst.input_symbols == {"a", "b"}
    assert bare_fst.output_symbols == {"c"}
    assert bare_fst.get_number_transitions() == 0
    assert not list(iter(bare_fst))

    # Construction from a pre-filled transition function
    prepared = TransitionFunction()
    for output in ("b", "c"):
        prepared.add_transition(State(1),
                                Symbol("a"),
                                State(2),
                                (Symbol(output),))
    loaded_fst = FST(transition_function=prepared)
    assert loaded_fst.get_number_transitions() == 2
    assert (1, "a", 2, ["b"]) in loaded_fst
    assert (1, "a", 2, ["c"]) in loaded_fst
    assert loaded_fst(1, "a") == {(2, ("b",)), (2, ("c",))}
241+
242+
def test_copy(self, fst0: FST):
    """ Checks that copy() gives a separate FST with equal components """
    duplicate = fst0.copy()
    compared_properties = ("states",
                           "input_symbols",
                           "output_symbols",
                           "start_states",
                           "final_states")
    for name in compared_properties:
        assert getattr(duplicate, name) == getattr(fst0, name)
    assert duplicate.to_dict() == fst0.to_dict()
    assert duplicate is not fst0
pyformlang/fst/transition_function.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
""" The transition function of Finite State Transducer """
2+
3+
from typing import Dict, Set, Tuple, Iterator, Iterable
4+
from copy import deepcopy
5+
6+
from ..objects.finite_automaton_objects import State, Symbol
7+
8+
TransitionKey = Tuple[State, Symbol]
9+
TransitionValue = Tuple[State, Tuple[Symbol, ...]]
10+
TransitionValues = Set[TransitionValue]
11+
Transition = Tuple[TransitionKey, TransitionValue]
12+
13+
14+
class TransitionFunction(Iterable[Transition]):
    """ The transition function of Finite State Transducer

    Stores, for every pair (state, input symbol), the set of pairs
    (next state, output symbols) it leads to. Iterating over the function
    yields every stored transition as a (key, value) pair.
    """

    def __init__(self) -> None:
        # Only keys owning at least one transition are kept in this dict
        self._transitions: Dict[TransitionKey, TransitionValues] = {}

    def add_transition(self,
                       s_from: State,
                       input_symbol: Symbol,
                       s_to: State,
                       output_symbols: Tuple[Symbol, ...]) -> None:
        """ Adds given transition to the function

        Adding a transition that is already present changes nothing.

        Parameters
        ----------
        s_from : State
            The source state
        input_symbol : Symbol
            The symbol read by the transition
        s_to : State
            The destination state
        output_symbols : tuple of Symbol
            The symbols written by the transition
        """
        key = (s_from, input_symbol)
        value = (s_to, output_symbols)
        self._transitions.setdefault(key, set()).add(value)

    def remove_transition(self,
                          s_from: State,
                          input_symbol: Symbol,
                          s_to: State,
                          output_symbols: Tuple[Symbol, ...]) -> None:
        """ Removes given transition from the function

        Removing a transition that is not present changes nothing.

        Parameters
        ----------
        s_from : State
            The source state
        input_symbol : Symbol
            The symbol read by the transition
        s_to : State
            The destination state
        output_symbols : tuple of Symbol
            The symbols written by the transition
        """
        key = (s_from, input_symbol)
        values = self._transitions.get(key)
        if values is None:
            return
        values.discard((s_to, output_symbols))
        if not values:
            # Drop the emptied entry, otherwise to_dict() would keep
            # reporting a key that has no transition left
            del self._transitions[key]

    def get_number_transitions(self) -> int:
        """ Gets the number of transitions in the function

        Returns
        ----------
        n_transitions : int
            The number of transitions
        """
        return sum(len(x) for x in self._transitions.values())

    def __call__(self, s_from: State, input_symbol: Symbol) \
            -> TransitionValues:
        """ Calls the transition function

        Parameters
        ----------
        s_from : State
            The source state
        input_symbol : Symbol
            The symbol to read

        Returns
        ----------
        values : set of (State, tuple of Symbol)
            The set stored for the given pair, returned as is (not a \
copy), or a new empty set when nothing is stored for it
        """
        return self._transitions.get((s_from, input_symbol), set())

    def __contains__(self, transition: Transition) -> bool:
        """ Whether the given transition is present in the function """
        key, value = transition
        return value in self(*key)

    def __iter__(self) -> Iterator[Transition]:
        """ Gets an iterator of transitions of the function """
        for key, values in self._transitions.items():
            for value in values:
                yield key, value

    def copy(self) -> "TransitionFunction":
        """ Copies the transition function

        Returns
        ----------
        function : TransitionFunction
            A new function holding the same transitions; the states and \
symbols themselves are shared with the original
        """
        new_tf = TransitionFunction()
        for key, value in self:
            new_tf.add_transition(*key, *value)
        return new_tf

    def __copy__(self) -> "TransitionFunction":
        """ Shallow-copy hook of the copy module, delegates to copy() """
        return self.copy()

    def to_dict(self) -> Dict[TransitionKey, TransitionValues]:
        """ Gives the transition function as a dictionary

        Returns
        ----------
        transitions : dict
            A deep copy of the stored transitions, so changing it does \
not affect the function
        """
        return deepcopy(self._transitions)

0 commit comments

Comments
 (0)