Skip to content

Commit 1bb9c7a

Browse files
committed
refactor indexed grammar intersection
1 parent 568f311 commit 1bb9c7a

File tree

4 files changed

+158
-179
lines changed

4 files changed

+158
-179
lines changed

pyformlang/fst/fst.py

Lines changed: 1 addition & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
from networkx import MultiDiGraph
88
from networkx.drawing.nx_pydot import write_dot
99

10-
from pyformlang.indexed_grammar import IndexedGrammar, Rules, \
11-
DuplicationRule, ProductionRule, EndRule, ConsumptionRule
12-
from pyformlang.indexed_grammar.reduced_rule import ReducedRule
13-
1410
from .utils import StateRenaming
1511
from ..objects.finite_automaton_objects import State, Symbol, Epsilon
1612
from ..objects.finite_automaton_objects.utils import to_state, to_symbol
@@ -216,7 +212,7 @@ def translate(self,
216212
The translation of the input word
217213
"""
218214
# (remaining in the input, generated so far, current_state)
219-
input_word = [to_symbol(symbol) for symbol in input_word]
215+
input_word = [to_symbol(x) for x in input_word if x != Epsilon()]
220216
to_process: List[Tuple[List[Symbol], List[Symbol], State]] = []
221217
seen_by_state = {state: [] for state in self.states}
222218
for start_state in self._start_states:
@@ -244,134 +240,6 @@ def translate(self,
244240
generated + list(output_symbols),
245241
next_state))
246242

247-
def intersection(self, indexed_grammar: IndexedGrammar) -> IndexedGrammar:
248-
""" Compute the intersection with an other object
249-
250-
Equivalent to:
251-
>> fst and indexed_grammar
252-
"""
253-
rules = indexed_grammar.rules
254-
new_rules: List[ReducedRule] = [EndRule("T", str(Epsilon()))]
255-
self._extract_consumption_rules_intersection(rules, new_rules)
256-
self._extract_indexed_grammar_rules_intersection(rules, new_rules)
257-
self._extract_terminals_intersection(rules, new_rules)
258-
self._extract_epsilon_transitions_intersection(new_rules)
259-
self._extract_fst_delta_intersection(new_rules)
260-
self._extract_fst_epsilon_intersection(new_rules)
261-
self._extract_fst_duplication_rules_intersection(new_rules)
262-
rules = Rules(new_rules, rules.optim)
263-
return IndexedGrammar(rules).remove_useless_rules()
264-
265-
def _extract_fst_duplication_rules_intersection(
266-
self,
267-
new_rules: List[ReducedRule]) \
268-
-> None:
269-
for state_p in self._final_states:
270-
for start_state in self._start_states:
271-
new_rules.append(DuplicationRule(
272-
"S",
273-
str((start_state, "S", state_p)),
274-
"T"))
275-
276-
def _extract_fst_epsilon_intersection(
277-
self,
278-
new_rules: List[ReducedRule]) \
279-
-> None:
280-
for state_p in self._states:
281-
new_rules.append(EndRule(
282-
str((state_p, Epsilon(), state_p)), str(Epsilon())))
283-
284-
def _extract_fst_delta_intersection(
285-
self,
286-
new_rules:List[ReducedRule]) \
287-
-> None:
288-
for key, pair in self._delta.items():
289-
state_p = key[0]
290-
terminal = key[1]
291-
for transition in pair:
292-
state_q = transition[0]
293-
symbol = transition[1]
294-
new_rules.append(EndRule(str((state_p, terminal, state_q)),
295-
symbol))
296-
297-
def _extract_epsilon_transitions_intersection(
298-
self,
299-
new_rules: List[ReducedRule]) \
300-
-> None:
301-
for state_p in self._states:
302-
for state_q in self._states:
303-
for state_r in self._states:
304-
new_rules.append(DuplicationRule(
305-
str((state_p, Epsilon(), state_q)),
306-
str((state_p, Epsilon(), state_r)),
307-
str((state_r, Epsilon(), state_q))))
308-
309-
def _extract_indexed_grammar_rules_intersection(
310-
self,
311-
rules: Rules,
312-
new_rules: List[ReducedRule]) \
313-
-> None:
314-
for rule in rules.rules:
315-
if isinstance(rule, DuplicationRule):
316-
for state_p in self._states:
317-
for state_q in self._states:
318-
for state_r in self._states:
319-
new_rules.append(DuplicationRule(
320-
str((state_p, rule.left_term, state_q)),
321-
str((state_p, rule.right_terms[0], state_r)),
322-
str((state_r, rule.right_terms[1], state_q))))
323-
elif isinstance(rule, ProductionRule):
324-
for state_p in self._states:
325-
for state_q in self._states:
326-
new_rules.append(ProductionRule(
327-
str((state_p, rule.left_term, state_q)),
328-
str((state_p, rule.right_term, state_q)),
329-
str(rule.production)))
330-
elif isinstance(rule, EndRule):
331-
for state_p in self._states:
332-
for state_q in self._states:
333-
new_rules.append(DuplicationRule(
334-
str((state_p, rule.left_term, state_q)),
335-
str((state_p, rule.right_term, state_q)),
336-
"T"))
337-
338-
def _extract_terminals_intersection(
339-
self,
340-
rules: Rules,
341-
new_rules: List[ReducedRule]) \
342-
-> None:
343-
terminals = rules.terminals
344-
for terminal in terminals:
345-
for state_p in self._states:
346-
for state_q in self._states:
347-
for state_r in self._states:
348-
new_rules.append(DuplicationRule(
349-
str((state_p, terminal, state_q)),
350-
str((state_p, Epsilon(), state_r)),
351-
str((state_r, terminal, state_q))))
352-
new_rules.append(DuplicationRule(
353-
str((state_p, terminal, state_q)),
354-
str((state_p, terminal, state_r)),
355-
str((state_r, Epsilon(), state_q))))
356-
357-
def _extract_consumption_rules_intersection(
358-
self,
359-
rules: Rules,
360-
new_rules: List[ReducedRule]) \
361-
-> None:
362-
consumptions = rules.consumption_rules
363-
for consumption_rule in consumptions:
364-
for consumption in consumptions[consumption_rule]:
365-
for state_r in self._states:
366-
for state_s in self._states:
367-
new_rules.append(ConsumptionRule(
368-
consumption.f_parameter,
369-
str((state_r, consumption.left_term, state_s)),
370-
str((state_r, consumption.right_term, state_s))))
371-
372-
def __and__(self, other: IndexedGrammar) -> IndexedGrammar:
373-
return self.intersection(other)
374-
375243
def union(self, other_fst: "FST") -> "FST":
376244
"""
377245
Makes the union of two fst

pyformlang/fst/tests/test_fst.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import pytest
66

77
from pyformlang.fst import FST
8-
from pyformlang.indexed_grammar import (
9-
DuplicationRule, ProductionRule, EndRule,
10-
ConsumptionRule, IndexedGrammar, Rules)
118

129

1310
@pytest.fixture
@@ -94,34 +91,6 @@ def test_translate(self):
9491
assert ["b", "c"] in translation
9592
assert ["b"] + ["c"] * 9 in translation
9693

97-
def test_intersection_indexed_grammar(self):
98-
""" Test the intersection with indexed grammar """
99-
l_rules = []
100-
rules = Rules(l_rules)
101-
indexed_grammar = IndexedGrammar(rules)
102-
fst = FST()
103-
intersection = fst & indexed_grammar
104-
assert intersection.is_empty()
105-
106-
l_rules.append(ProductionRule("S", "D", "f"))
107-
l_rules.append(DuplicationRule("D", "A", "B"))
108-
l_rules.append(ConsumptionRule("f", "A", "Afinal"))
109-
l_rules.append(ConsumptionRule("f", "B", "Bfinal"))
110-
l_rules.append(EndRule("Afinal", "a"))
111-
l_rules.append(EndRule("Bfinal", "b"))
112-
113-
rules = Rules(l_rules)
114-
indexed_grammar = IndexedGrammar(rules)
115-
intersection = fst.intersection(indexed_grammar)
116-
assert intersection.is_empty()
117-
118-
fst.add_start_state("q0")
119-
fst.add_final_state("final")
120-
fst.add_transition("q0", "a", "q1", ["a"])
121-
fst.add_transition("q1", "b", "final", ["b"])
122-
intersection = fst.intersection(indexed_grammar)
123-
assert not intersection.is_empty()
124-
12594
def test_union(self, fst0, fst1):
12695
""" Tests the union"""
12796
fst_union = fst0.union(fst1)

pyformlang/indexed_grammar/indexed_grammar.py

Lines changed: 125 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
# pylint: disable=cell-var-from-loop
66

7-
from typing import Dict, List, Set, FrozenSet, Tuple, Hashable, Any
7+
from typing import Dict, List, Set, FrozenSet, Tuple, Hashable
88

99
from pyformlang.cfg import CFGObject, Variable, Terminal
10-
from pyformlang.finite_automaton import FiniteAutomaton
11-
from pyformlang.regular_expression import Regex
10+
from pyformlang.fst import FST
1211

1312
from .rules import Rules
13+
from .reduced_rule import ReducedRule
1414
from .duplication_rule import DuplicationRule
1515
from .production_rule import ProductionRule
16+
from .consumption_rule import ConsumptionRule
1617
from .end_rule import EndRule
1718
from .utils import exists, addrec_bis
1819
from ..objects.cfg_objects.utils import to_variable
@@ -47,7 +48,7 @@ def __init__(self,
4748
# Mark all end symbols
4849
for non_terminal_a in non_terminals:
4950
if exists(self._rules.rules,
50-
lambda x: x.is_end_rule()
51+
lambda x: isinstance(x, EndRule)
5152
and x.left_term == non_terminal_a):
5253
self._marked[non_terminal_a].add(frozenset())
5354

@@ -363,7 +364,7 @@ def remove_useless_rules(self) -> "IndexedGrammar":
363364
rules = Rules(l_rules, self._rules.optim)
364365
return IndexedGrammar(rules)
365366

366-
def intersection(self, other: Any) -> "IndexedGrammar":
367+
def intersection(self, other: FST) -> "IndexedGrammar":
367368
""" Computes the intersection of the current indexed grammar with the \
368369
other object
369370
@@ -387,14 +388,18 @@ def intersection(self, other: Any) -> "IndexedGrammar":
387388
When trying to intersection with something else than a regular
388389
expression or a finite automaton
389390
"""
390-
if isinstance(other, Regex):
391-
other = other.to_epsilon_nfa()
392-
if isinstance(other, FiniteAutomaton):
393-
fst = other.to_fst()
394-
return fst.intersection(self)
395-
raise NotImplementedError
396-
397-
def __and__(self, other: Any) -> "IndexedGrammar":
391+
new_rules: List[ReducedRule] = [EndRule("T", "epsilon")]
392+
self._extract_consumption_rules_intersection(other, new_rules)
393+
self._extract_indexed_grammar_rules_intersection(other, new_rules)
394+
self._extract_terminals_intersection(other, new_rules)
395+
self._extract_epsilon_transitions_intersection(other, new_rules)
396+
self._extract_fst_delta_intersection(other, new_rules)
397+
self._extract_fst_epsilon_intersection(other, new_rules)
398+
self._extract_fst_duplication_rules_intersection(other, new_rules)
399+
rules = Rules(new_rules, self.rules.optim)
400+
return IndexedGrammar(rules).remove_useless_rules()
401+
402+
def __and__(self, other: FST) -> "IndexedGrammar":
398403
""" Computes the intersection of the current indexed grammar with the
399404
other object
400405
@@ -409,3 +414,110 @@ def __and__(self, other: Any) -> "IndexedGrammar":
409414
The indexed grammar which useless rules
410415
"""
411416
return self.intersection(other)
417+
418+
def _extract_fst_duplication_rules_intersection(
419+
self,
420+
other: FST,
421+
new_rules: List[ReducedRule]) \
422+
-> None:
423+
for final_state in other.final_states:
424+
for start_state in other.start_states:
425+
new_rules.append(DuplicationRule(
426+
"S",
427+
(start_state, "S", final_state),
428+
"T"))
429+
430+
def _extract_fst_epsilon_intersection(
431+
self,
432+
other: FST,
433+
new_rules: List[ReducedRule]) \
434+
-> None:
435+
for state in other.states:
436+
new_rules.append(EndRule(
437+
(state, "epsilon", state),
438+
"epsilon"))
439+
440+
def _extract_fst_delta_intersection(
441+
self,
442+
other: FST,
443+
new_rules: List[ReducedRule]) \
444+
-> None:
445+
for (s_from, symb_from), (s_to, symb_to) in other:
446+
new_rules.append(EndRule(
447+
(s_from, symb_from, s_to),
448+
symb_to))
449+
450+
def _extract_epsilon_transitions_intersection(
451+
self,
452+
other: FST,
453+
new_rules: List[ReducedRule]) \
454+
-> None:
455+
for state_p in other.states:
456+
for state_q in other.states:
457+
for state_r in other.states:
458+
new_rules.append(DuplicationRule(
459+
(state_p, "epsilon", state_q),
460+
(state_p, "epsilon", state_r),
461+
(state_r, "epsilon", state_q)))
462+
463+
def _extract_indexed_grammar_rules_intersection(
464+
self,
465+
other: FST,
466+
new_rules: List[ReducedRule]) \
467+
-> None:
468+
for rule in self.rules.rules:
469+
if isinstance(rule, DuplicationRule):
470+
for state_p in other.states:
471+
for state_q in other.states:
472+
for state_r in other.states:
473+
new_rules.append(DuplicationRule(
474+
(state_p, rule.left_term, state_q),
475+
(state_p, rule.right_terms[0], state_r),
476+
(state_r, rule.right_terms[1], state_q)))
477+
elif isinstance(rule, ProductionRule):
478+
for state_p in other.states:
479+
for state_q in other.states:
480+
new_rules.append(ProductionRule(
481+
(state_p, rule.left_term, state_q),
482+
(state_p, rule.right_term, state_q),
483+
rule.production))
484+
elif isinstance(rule, EndRule):
485+
for state_p in other.states:
486+
for state_q in other.states:
487+
new_rules.append(DuplicationRule(
488+
(state_p, rule.left_term, state_q),
489+
(state_p, rule.right_term, state_q),
490+
"T"))
491+
492+
def _extract_terminals_intersection(
493+
self,
494+
other: FST,
495+
new_rules: List[ReducedRule]) \
496+
-> None:
497+
for terminal in self.rules.terminals:
498+
for state_p in other.states:
499+
for state_q in other.states:
500+
for state_r in other.states:
501+
new_rules.append(DuplicationRule(
502+
(state_p, terminal, state_q),
503+
(state_p, "epsilon", state_r),
504+
(state_r, terminal, state_q)))
505+
new_rules.append(DuplicationRule(
506+
(state_p, terminal, state_q),
507+
(state_p, terminal, state_r),
508+
(state_r, "epsilon", state_q)))
509+
510+
def _extract_consumption_rules_intersection(
511+
self,
512+
other: FST,
513+
new_rules: List[ReducedRule]) \
514+
-> None:
515+
consumptions = self.rules.consumption_rules
516+
for terminal in consumptions:
517+
for consumption in consumptions[terminal]:
518+
for state_r in other.states:
519+
for state_s in other.states:
520+
new_rules.append(ConsumptionRule(
521+
consumption.f_parameter,
522+
(state_r, consumption.left_term, state_s),
523+
(state_r, consumption.right_term, state_s)))

0 commit comments

Comments
 (0)