From 0a0ecc3277e0e1327c8377c1a1450e4d6befd095 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Fri, 25 Oct 2024 17:12:51 +0300 Subject: [PATCH 01/12] add fst annotations --- .../finite_automaton/finite_automaton.py | 12 +- pyformlang/fst/__init__.py | 7 +- pyformlang/fst/fst.py | 361 +++++++++--------- pyformlang/fst/utils.py | 69 ++++ 4 files changed, 271 insertions(+), 178 deletions(-) create mode 100644 pyformlang/fst/utils.py diff --git a/pyformlang/finite_automaton/finite_automaton.py b/pyformlang/finite_automaton/finite_automaton.py index 654a039..3017abb 100644 --- a/pyformlang/finite_automaton/finite_automaton.py +++ b/pyformlang/finite_automaton/finite_automaton.py @@ -436,14 +436,14 @@ def to_fst(self) -> FST: """ fst = FST() for start_state in self._start_states: - fst.add_start_state(start_state.value) + fst.add_start_state(start_state) for final_state in self._final_states: - fst.add_final_state(final_state.value) + fst.add_final_state(final_state) for s_from, symb_by, s_to in self._transition_function: - fst.add_transition(s_from.value, - symb_by.value, - s_to.value, - [symb_by.value]) + fst.add_transition(s_from, + symb_by, + s_to, + [symb_by]) return fst def is_acyclic(self) -> bool: diff --git a/pyformlang/fst/__init__.py b/pyformlang/fst/__init__.py index afd33d1..c881874 100644 --- a/pyformlang/fst/__init__.py +++ b/pyformlang/fst/__init__.py @@ -12,7 +12,10 @@ """ -from .fst import FST +from .fst import FST, State, Symbol, Epsilon -__all__ = ["FST"] +__all__ = ["FST", + "State", + "Symbol", + "Epsilon"] diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index dcecafa..73f5cab 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -1,29 +1,44 @@ """ Finite State Transducer """ -import json -from typing import Any, Iterable -import networkx as nx +from typing import Dict, List, Set, Tuple, Iterator, Iterable, Hashable +from copy import deepcopy +from json import dumps, loads + +from networkx import MultiDiGraph from networkx.drawing.nx_pydot import write_dot -from pyformlang.indexed_grammar import DuplicationRule, ProductionRule, \ - EndRule, ConsumptionRule, IndexedGrammar, Rules +from pyformlang.indexed_grammar import IndexedGrammar, Rules, \ + DuplicationRule, ProductionRule, EndRule, ConsumptionRule +from pyformlang.indexed_grammar.reduced_rule import ReducedRule + +from .utils import StateRenaming +from ..objects.finite_automaton_objects import State, Symbol, Epsilon +from ..objects.finite_automaton_objects.utils import to_state, to_symbol + +TransitionKey = Tuple[State, Symbol] +TransitionValue = Tuple[State, Tuple[Symbol, ...]] +TransitionValues = Set[TransitionValue] +TransitionFunction = Dict[TransitionKey, TransitionValues] + +InputTransition = Tuple[Hashable, Hashable, Hashable, Iterable[Hashable]] +Transition = Tuple[TransitionKey, TransitionValue] -class FST: +class FST(Iterable[Transition]): """ Representation of a Finite State Transducer""" - def __init__(self): - self._states = set() # Set of states - self._input_symbols = set() # Set of input symbols - self._output_symbols = set() # Set of output symbols + def __init__(self) -> None: + self._states: Set[State] = set() # Set of states + self._input_symbols: Set[Symbol] = set() # Set of input symbols + self._output_symbols: Set[Symbol] = set() # Set of output symbols # Dict from _states x _input_symbols U {epsilon} into a subset of # _states X _output_symbols* - self._delta = {} - self._start_states = set() - self._final_states = set() # _final_states is final states + self._delta: TransitionFunction = {} + self._start_states: Set[State] = set() + self._final_states: Set[State] = set() # _final_states is final states @property - def states(self): + def states(self) -> Set[State]: """ Get the states of the FST Returns @@ -34,7 +49,7 @@ def states(self): return self._states @property - def input_symbols(self): + def input_symbols(self) -> Set[Symbol]: """ Get the input symbols of the FST Returns @@ -45,7 +60,7 @@ def input_symbols(self): return self._input_symbols @property - def output_symbols(self): + def output_symbols(self) -> Set[Symbol]: """ Get the output symbols of the FST Returns @@ -56,7 +71,7 @@ def output_symbols(self): return self._output_symbols @property - def start_states(self): + def start_states(self) -> Set[State]: """ Get the start states of the FST Returns @@ -67,7 +82,7 @@ def start_states(self): return self._start_states @property - def final_states(self): + def final_states(self) -> Set[State]: """ Get the final states of the FST Returns @@ -77,11 +92,6 @@ def final_states(self): """ return self._final_states - @property - def transitions(self): - """Gives the transitions as a dictionary""" - return self._delta - def get_number_transitions(self) -> int: """ Get the number of transitions in the FST @@ -92,10 +102,11 @@ def get_number_transitions(self) -> int: """ return sum(len(x) for x in self._delta.values()) - def add_transition(self, s_from: Any, - input_symbol: Any, - s_to: Any, - output_symbols: Iterable[Any]): + def add_transition(self, + s_from: Hashable, + input_symbol: Hashable, + s_to: Hashable, + output_symbols: Iterable[Hashable]) -> None: """ Add a transition to the FST Parameters @@ -109,20 +120,23 @@ def add_transition(self, s_from: Any, output_symbols : iterable of Any The symbols to output """ + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + s_to = to_state(s_to) + output_symbols = tuple(to_symbol(x) for x in output_symbols + if x != Epsilon()) self._states.add(s_from) self._states.add(s_to) - if input_symbol != "epsilon": + if input_symbol != Epsilon(): self._input_symbols.add(input_symbol) - for output_symbol in output_symbols: - if output_symbol != "epsilon": - self._output_symbols.add(output_symbol) + self._output_symbols.update(output_symbols) head = (s_from, input_symbol) if head in self._delta: - self._delta[head].append((s_to, output_symbols)) + self._delta[head].add((s_to, output_symbols)) else: - self._delta[head] = [(s_to, output_symbols)] + self._delta[head] = {(s_to, output_symbols)} - def add_transitions(self, transitions_list): + def add_transitions(self, transitions: Iterable[InputTransition]) -> None: """ Adds several transitions to the FST @@ -131,15 +145,13 @@ def add_transitions(self, transitions_list): transitions_list : list of tuples The tuples have the form (s_from, in_symbol, s_to, out_symbols) """ - for s_from, input_symbol, s_to, output_symbols in transitions_list: - self.add_transition( - s_from, - input_symbol, - s_to, - output_symbols - ) - - def add_start_state(self, start_state: Any): + for s_from, input_symbol, s_to, output_symbols in transitions: + self.add_transition(s_from, + input_symbol, + s_to, + output_symbols) + + def add_start_state(self, start_state: Hashable) -> None: """ Add a start state Parameters @@ -147,10 +159,11 @@ def add_start_state(self, start_state: Any): start_state : any The start state """ + start_state = to_state(start_state) self._states.add(start_state) self._start_states.add(start_state) - def add_final_state(self, final_state: Any): + def add_final_state(self, final_state: Hashable) -> None: """ Add a final state Parameters @@ -158,11 +171,35 @@ def add_final_state(self, final_state: Any): final_state : any The final state to add """ + final_state = to_state(final_state) self._final_states.add(final_state) self._states.add(final_state) - def translate(self, input_word: Iterable[Any], max_length: int = -1) -> \ - Iterable[Any]: + def __call__(self, s_from: Hashable, input_symbol: Hashable) \ + -> TransitionValues: + """ Calls the transition function of the FST """ + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + return self._delta.get((s_from, input_symbol), set()) + + def __contains__(self, transition: InputTransition) -> bool: + """ Whether the given transition is present in the FST """ + s_from, input_symbol, s_to, output_symbols = transition + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + s_to = to_state(s_to) + output_symbols = tuple(to_symbol(x) for x in output_symbols) + return (s_to, output_symbols) in self(s_from, input_symbol) + + def __iter__(self) -> Iterator[Transition]: + """ Gets an iterator of transitions of the FST """ + for key, values in self._delta.items(): + for value in values: + yield key, value + + def translate(self, + input_word: Iterable[Hashable], + max_length: int = -1) -> Iterable[List[Symbol]]: """ Translate a string into another using the FST Parameters @@ -179,7 +216,8 @@ def translate(self, input_word: Iterable[Any], max_length: int = -1) -> \ The translation of the input word """ # (remaining in the input, generated so far, current_state) - to_process = [] + input_word = [to_symbol(symbol) for symbol in input_word] + to_process: List[Tuple[List[Symbol], List[Symbol], State]] = [] seen_by_state = {state: [] for state in self.states} for start_state in self._start_states: to_process.append((input_word, [], start_state)) @@ -192,28 +230,28 @@ def translate(self, input_word: Iterable[Any], max_length: int = -1) -> \ yield generated # We try to read an input if len(remaining) != 0: - for next_state, output_string in self._delta.get( - (current_state, remaining[0]), []): + for next_state, output_symbols in self(current_state, + remaining[0]): to_process.append( (remaining[1:], - generated + output_string, + generated + list(output_symbols), next_state)) # We try to read an epsilon transition if max_length == -1 or len(generated) < max_length: - for next_state, output_string in self._delta.get( - (current_state, "epsilon"), []): + for next_state, output_symbols in self(current_state, + Epsilon()): to_process.append((remaining, - generated + output_string, + generated + list(output_symbols), next_state)) - def intersection(self, indexed_grammar): + def intersection(self, indexed_grammar: IndexedGrammar) -> IndexedGrammar: """ Compute the intersection with an other object Equivalent to: >> fst and indexed_grammar """ rules = indexed_grammar.rules - new_rules = [EndRule("T", "epsilon")] + new_rules: List[ReducedRule] = [EndRule("T", str(Epsilon()))] self._extract_consumption_rules_intersection(rules, new_rules) self._extract_indexed_grammar_rules_intersection(rules, new_rules) self._extract_terminals_intersection(rules, new_rules) @@ -224,7 +262,10 @@ def intersection(self, indexed_grammar): rules = Rules(new_rules, rules.optim) return IndexedGrammar(rules).remove_useless_rules() - def _extract_fst_duplication_rules_intersection(self, new_rules): + def _extract_fst_duplication_rules_intersection( + self, + new_rules: List[ReducedRule]) \ + -> None: for state_p in self._final_states: for start_state in self._start_states: new_rules.append(DuplicationRule( @@ -232,13 +273,18 @@ def _extract_fst_duplication_rules_intersection(self, new_rules): str((start_state, "S", state_p)), "T")) - def _extract_fst_epsilon_intersection(self, new_rules): + def _extract_fst_epsilon_intersection( + self, + new_rules: List[ReducedRule]) \ + -> None: for state_p in self._states: new_rules.append(EndRule( - str((state_p, "epsilon", state_p)), - "epsilon")) + str((state_p, Epsilon(), state_p)), str(Epsilon()))) - def _extract_fst_delta_intersection(self, new_rules): + def _extract_fst_delta_intersection( + self, + new_rules:List[ReducedRule]) \ + -> None: for key, pair in self._delta.items(): state_p = key[0] terminal = key[1] @@ -248,16 +294,23 @@ def _extract_fst_delta_intersection(self, new_rules): new_rules.append(EndRule(str((state_p, terminal, state_q)), symbol)) - def _extract_epsilon_transitions_intersection(self, new_rules): + def _extract_epsilon_transitions_intersection( + self, + new_rules: List[ReducedRule]) \ + -> None: for state_p in self._states: for state_q in self._states: for state_r in self._states: new_rules.append(DuplicationRule( - str((state_p, "epsilon", state_q)), - str((state_p, "epsilon", state_r)), - str((state_r, "epsilon", state_q)))) - - def _extract_indexed_grammar_rules_intersection(self, rules, new_rules): + str((state_p, Epsilon(), state_q)), + str((state_p, Epsilon(), state_r)), + str((state_r, Epsilon(), state_q)))) + + def _extract_indexed_grammar_rules_intersection( + self, + rules: Rules, + new_rules: List[ReducedRule]) \ + -> None: for rule in rules.rules: if rule.is_duplication(): for state_p in self._states: @@ -282,7 +335,11 @@ def _extract_indexed_grammar_rules_intersection(self, rules, new_rules): str((state_p, rule.right_term, state_q)), "T")) - def _extract_terminals_intersection(self, rules, new_rules): + def _extract_terminals_intersection( + self, + rules: Rules, + new_rules: List[ReducedRule]) \ + -> None: terminals = rules.terminals for terminal in terminals: for state_p in self._states: @@ -290,14 +347,18 @@ def _extract_terminals_intersection(self, rules, new_rules): for state_r in self._states: new_rules.append(DuplicationRule( str((state_p, terminal, state_q)), - str((state_p, "epsilon", state_r)), + str((state_p, Epsilon(), state_r)), str((state_r, terminal, state_q)))) new_rules.append(DuplicationRule( str((state_p, terminal, state_q)), str((state_p, terminal, state_r)), - str((state_r, "epsilon", state_q)))) + str((state_r, Epsilon(), state_q)))) - def _extract_consumption_rules_intersection(self, rules, new_rules): + def _extract_consumption_rules_intersection( + self, + rules: Rules, + new_rules: List[ReducedRule]) \ + -> None: consumptions = rules.consumption_rules for consumption_rule in consumptions: for consumption in consumptions[consumption_rule]: @@ -308,10 +369,10 @@ def _extract_consumption_rules_intersection(self, rules, new_rules): str((state_r, consumption.left_term, state_s)), str((state_r, consumption.right, state_s)))) - def __and__(self, other): + def __and__(self, other: IndexedGrammar) -> IndexedGrammar: return self.intersection(other) - def union(self, other_fst): + def union(self, other_fst: "FST") -> "FST": """ Makes the union of two fst Parameters @@ -332,7 +393,7 @@ def union(self, other_fst): other_fst._copy_into(union_fst, state_renaming, 1) return union_fst - def __or__(self, other_fst): + def __or__(self, other_fst: "FST") -> "FST": """ Makes the union of two fst Parameters @@ -348,33 +409,50 @@ def __or__(self, other_fst): """ return self.union(other_fst) - def _copy_into(self, union_fst, state_renaming, idx): + def _copy_into(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: self._add_extremity_states_to(union_fst, state_renaming, idx) self._add_transitions_to(union_fst, state_renaming, idx) - def _add_transitions_to(self, union_fst, state_renaming, idx): - for head, transition in self.transitions.items(): + def _add_transitions_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: + for head, transition in self._delta.items(): s_from, input_symbol = head - for s_to, output_symbols in transition: + for s_to, output_symbol in transition: union_fst.add_transition( - state_renaming.get_name(s_from, idx), + state_renaming.get_renamed_state(s_from, idx), input_symbol, - state_renaming.get_name(s_to, idx), - output_symbols) + state_renaming.get_renamed_state(s_to, idx), + output_symbol) - def _add_extremity_states_to(self, union_fst, state_renaming, idx): + def _add_extremity_states_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: self._add_start_states_to(union_fst, state_renaming, idx) self._add_final_states_to(union_fst, state_renaming, idx) - def _add_final_states_to(self, union_fst, state_renaming, idx): + def _add_final_states_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: for state in self.final_states: - union_fst.add_final_state(state_renaming.get_name(state, idx)) + union_fst.add_final_state( + state_renaming.get_renamed_state(state, idx)) - def _add_start_states_to(self, union_fst, state_renaming, idx): + def _add_start_states_to(self, + union_fst: "FST", + state_renaming: StateRenaming, + idx: int) -> None: for state in self.start_states: - union_fst.add_start_state(state_renaming.get_name(state, idx)) + union_fst.add_start_state( + state_renaming.get_renamed_state(state, idx)) - def concatenate(self, other_fst): + def concatenate(self, other_fst: "FST") -> "FST": """ Makes the concatenation of two fst Parameters @@ -398,14 +476,14 @@ def concatenate(self, other_fst): for final_state in self.final_states: for start_state in other_fst.start_states: fst_concatenate.add_transition( - state_renaming.get_name(final_state, 0), - "epsilon", - state_renaming.get_name(start_state, 1), + state_renaming.get_renamed_state(final_state, 0), + Epsilon(), + state_renaming.get_renamed_state(start_state, 1), [] ) return fst_concatenate - def __add__(self, other): + def __add__(self, other: "FST") -> "FST": """ Makes the concatenation of two fst Parameters @@ -421,13 +499,13 @@ def __add__(self, other): """ return self.concatenate(other) - def _get_state_renaming(self, other_fst): - state_renaming = FSTStateRemaining() - state_renaming.add_states(list(self.states), 0) + def _get_state_renaming(self, other_fst: "FST") -> StateRenaming: + state_renaming = StateRenaming() + state_renaming.add_states(self.states, 0) state_renaming.add_states(other_fst.states, 1) return state_renaming - def kleene_star(self): + def kleene_star(self) -> "FST": """ Computes the kleene star of the FST @@ -437,29 +515,29 @@ def kleene_star(self): A FST representing the kleene star of the FST """ fst_star = FST() - state_renaming = FSTStateRemaining() - state_renaming.add_states(list(self.states), 0) + state_renaming = StateRenaming() + state_renaming.add_states(self.states, 0) self._add_extremity_states_to(fst_star, state_renaming, 0) self._add_transitions_to(fst_star, state_renaming, 0) for final_state in self.final_states: for start_state in self.start_states: fst_star.add_transition( - state_renaming.get_name(final_state, 0), - "epsilon", - state_renaming.get_name(start_state, 0), + state_renaming.get_renamed_state(final_state, 0), + Epsilon(), + state_renaming.get_renamed_state(start_state, 0), [] ) for final_state in self.start_states: for start_state in self.final_states: fst_star.add_transition( - state_renaming.get_name(final_state, 0), - "epsilon", - state_renaming.get_name(start_state, 0), + state_renaming.get_renamed_state(final_state, 0), + Epsilon(), + state_renaming.get_renamed_state(start_state, 0), [] ) return fst_star - def to_networkx(self) -> nx.MultiDiGraph: + def to_networkx(self) -> MultiDiGraph: """ Transform the current fst into a networkx graph @@ -469,7 +547,7 @@ def to_networkx(self) -> nx.MultiDiGraph: A networkx MultiDiGraph representing the fst """ - graph = nx.MultiDiGraph() + graph = MultiDiGraph() for state in self._states: graph.add_node(state, is_start=state in self.start_states, @@ -489,12 +567,12 @@ def to_networkx(self) -> nx.MultiDiGraph: graph.add_edge( s_from, s_to, - label=(json.dumps(input_symbol) + " -> " + - json.dumps(output_symbols))) + label=(dumps(input_symbol) + " -> " + + dumps(output_symbols))) return graph @classmethod - def from_networkx(cls, graph): + def from_networkx(cls, graph: MultiDiGraph) -> "FST": """ Import a networkx graph into an finite state transducer. \ The imported graph requires to have the good format, i.e. to come \ @@ -521,8 +599,8 @@ def from_networkx(cls, graph): if "label" in transition: in_symbol, out_symbols = transition["label"].split( " -> ") - in_symbol = json.loads(in_symbol) - out_symbols = json.loads(out_symbols) + in_symbol = loads(in_symbol) + out_symbols = loads(out_symbols) fst.add_transition(s_from, in_symbol, s_to, @@ -534,7 +612,7 @@ def from_networkx(cls, graph): fst.add_final_state(node) return fst - def write_as_dot(self, filename): + def write_as_dot(self, filename: str) -> None: """ Write the FST in dot format into a file @@ -546,63 +624,6 @@ def write_as_dot(self, filename): """ write_dot(self.to_networkx(), filename) - -class FSTStateRemaining: - """Class for remaining the states in FST""" - - def __init__(self): - self._state_renaming = {} - self._seen_states = set() - - def add_state(self, state, idx): - """ - Add a state - Parameters - ---------- - state : str - The state to add - idx : int - The index of the FST - """ - if state in self._seen_states: - counter = 0 - new_state = state + str(counter) - while new_state in self._seen_states: - counter += 1 - new_state = state + str(counter) - self._state_renaming[(state, idx)] = new_state - self._seen_states.add(new_state) - else: - self._state_renaming[(state, idx)] = state - self._seen_states.add(state) - - def add_states(self, states, idx): - """ - Add states - Parameters - ---------- - states : list of str - The states to add - idx : int - The index of the FST - """ - for state in states: - self.add_state(state, idx) - - def get_name(self, state, idx): - """ - Get the renaming. - - Parameters - ---------- - state : str - The state to rename - idx : int - The index of the FST - - Returns - ------- - new_name : str - The new name of the state - """ - return self._state_renaming[(state, idx)] + def to_dict(self) -> TransitionFunction: + """Gives the transitions as a dictionary""" + return deepcopy(self._delta) diff --git a/pyformlang/fst/utils.py b/pyformlang/fst/utils.py new file mode 100644 index 0000000..6315edf --- /dev/null +++ b/pyformlang/fst/utils.py @@ -0,0 +1,69 @@ +""" Class for renaming the states in FST """ + +from typing import Dict, Set, Iterable, Tuple + +from ..objects.finite_automaton_objects import State +from ..objects.finite_automaton_objects.utils import to_state + + +class StateRenaming: + """ Class for renaming the states in FST """ + + def __init__(self) -> None: + self._state_renaming: Dict[Tuple[str, int], str] = {} + self._seen_states: Set[str] = set() + + def add_state(self, state: State, idx: int) -> None: + """ + Add a state + Parameters + ---------- + state : State + The state to add + idx : int + The index of the FST + """ + current_name = str(state) + if current_name in self._seen_states: + counter = 0 + new_name = current_name + str(counter) + while new_name in self._seen_states: + counter += 1 + new_name = current_name + str(counter) + self._state_renaming[(current_name, idx)] = new_name + self._seen_states.add(new_name) + else: + self._state_renaming[(current_name, idx)] = current_name + self._seen_states.add(current_name) + + def add_states(self, states: Iterable[State], idx: int) -> None: + """ + Add states + Parameters + ---------- + states : Iterable of States + The states to add + idx : int + The index of the FST + """ + for state in states: + self.add_state(state, idx) + + def get_renamed_state(self, state: State, idx: int) -> State: + """ + Get the renaming. + + Parameters + ---------- + state : State + The state to rename + idx : int + The index of the FST + + Returns + ------- + new_name : State + Renamed state + """ + renaming = self._state_renaming[(str(state), idx)] + return to_state(renaming) From a5ea2f9c795492862b9eaf81f651d9826e98ae07 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Wed, 30 Oct 2024 23:19:17 +0300 Subject: [PATCH 02/12] add indexed_grammar type annotations, rewrite rules in a stricter way --- pyformlang/fst/fst.py | 8 +- .../indexed_grammar/consumption_rule.py | 82 +++---- .../indexed_grammar/duplication_rule.py | 72 +++---- pyformlang/indexed_grammar/end_rule.py | 61 +++--- pyformlang/indexed_grammar/indexed_grammar.py | 201 +++++++++++------- pyformlang/indexed_grammar/production_rule.py | 73 ++++--- pyformlang/indexed_grammar/reduced_rule.py | 89 ++++---- pyformlang/indexed_grammar/rule_ordering.py | 44 ++-- pyformlang/indexed_grammar/rules.py | 75 ++++--- 9 files changed, 381 insertions(+), 324 deletions(-) diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index 73f5cab..b8dcc1d 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -312,7 +312,7 @@ def _extract_indexed_grammar_rules_intersection( new_rules: List[ReducedRule]) \ -> None: for rule in rules.rules: - if rule.is_duplication(): + if isinstance(rule, DuplicationRule): for state_p in self._states: for state_q in self._states: for state_r in self._states: @@ -320,14 +320,14 @@ def _extract_indexed_grammar_rules_intersection( str((state_p, rule.left_term, state_q)), str((state_p, rule.right_terms[0], state_r)), str((state_r, rule.right_terms[1], state_q)))) - elif rule.is_production(): + elif isinstance(rule, ProductionRule): for state_p in self._states: for state_q in self._states: new_rules.append(ProductionRule( str((state_p, rule.left_term, state_q)), str((state_p, rule.right_term, state_q)), str(rule.production))) - elif rule.is_end_rule(): + elif isinstance(rule, EndRule): for state_p in self._states: for state_q in self._states: new_rules.append(DuplicationRule( @@ -367,7 +367,7 @@ def _extract_consumption_rules_intersection( new_rules.append(ConsumptionRule( consumption.f_parameter, str((state_r, consumption.left_term, state_s)), - str((state_r, consumption.right, state_s)))) + str((state_r, consumption.right_term, state_s)))) def __and__(self, other: IndexedGrammar) -> IndexedGrammar: return self.intersection(other) diff --git a/pyformlang/indexed_grammar/consumption_rule.py b/pyformlang/indexed_grammar/consumption_rule.py index 39b2e16..eecc043 100644 --- a/pyformlang/indexed_grammar/consumption_rule.py +++ b/pyformlang/indexed_grammar/consumption_rule.py @@ -3,7 +3,11 @@ the stack """ -from typing import Any, Iterable, AbstractSet +from typing import List, Set, Any + +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.utils import to_variable, to_terminal +from pyformlang.cfg.cfg_object import CFGObject from .reduced_rule import ReducedRule @@ -24,30 +28,19 @@ class ConsumptionRule(ReducedRule): """ @property - def right_term(self): - raise NotImplementedError - - @property - def right_terms(self): + def production(self) -> Terminal: raise NotImplementedError - def __init__(self, f_param: Any, left: Any, right: Any): - self._f = f_param - self._right = right - self._left_term = left - - def is_consumption(self) -> bool: - """Whether the rule is a consumption rule or not - - Returns - ---------- - is_consumption : bool - Whether the rule is a consumption rule or not - """ - return True + def __init__(self, + f_param: Any, + left_term: Any, + right_term: Any) -> None: + self._f = to_terminal(f_param) + self._left_term = to_variable(left_term) + self._right_term = to_variable(right_term) @property - def f_parameter(self) -> Any: + def f_parameter(self) -> Terminal: """Gets the symbol which is consumed Returns @@ -58,38 +51,45 @@ def f_parameter(self) -> Any: return self._f @property - def production(self): - raise NotImplementedError + def left_term(self) -> Variable: + """Gets the symbol on the left of the rule + + left : any + The left symbol of the rule + """ + return self._left_term @property - def right(self) -> Any: - """Gets the symbole on the right of the rule + def right_term(self) -> Variable: + """Gets the symbol on the right of the rule right : any The right symbol """ - return self._right + return self._right_term @property - def left_term(self) -> Any: - """Gets the symbol on the left of the rule + def right_terms(self) -> List[CFGObject]: + """Gives the non-terminals on the right of the rule - left : any - The left symbol of the rule + Returns + --------- + right_terms : iterable of any + The right terms of the rule """ - return self._left_term + return [self._right_term] @property - def non_terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: """Gets the non-terminals used in the rule non_terminals : iterable of any The non_terminals used in the rule """ - return [self._left_term, self._right] + return {self._left_term, self._right_term} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used in the rule terminals : set of any @@ -97,10 +97,12 @@ def terminals(self) -> AbstractSet[Any]: """ return {self._f} - def __repr__(self): - return self._left_term + " [ " + self._f + " ] -> " + self._right + def __repr__(self) -> str: + return f"{self._left_term} [ {self._f} ] -> {self._right_term}" - def __eq__(self, other): - return other.is_consumption() and other.left_term == \ - self.left_term and other.right == self.right and \ - other.f_parameter() == self.f_parameter + def __eq__(self, other: Any) -> bool: + if not isinstance(other, ConsumptionRule): + return False + return other.left_term == self.left_term \ + and other.right_term == self.right_term \ + and other.f_parameter == self.f_parameter diff --git a/pyformlang/indexed_grammar/duplication_rule.py b/pyformlang/indexed_grammar/duplication_rule.py index cc719b5..5b81a32 100644 --- a/pyformlang/indexed_grammar/duplication_rule.py +++ b/pyformlang/indexed_grammar/duplication_rule.py @@ -2,7 +2,11 @@ A representation of a duplication rule, i.e. a rule that duplicates the stack """ -from typing import Any, Iterable, AbstractSet, Tuple +from typing import List, Set, Any + +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.utils import to_variable +from pyformlang.cfg.cfg_object import CFGObject from .reduced_rule import ReducedRule @@ -22,33 +26,38 @@ class DuplicationRule(ReducedRule): """ @property - def production(self): + def f_parameter(self) -> Terminal: raise NotImplementedError @property - def right_term(self): + def production(self) -> Terminal: raise NotImplementedError @property - def f_parameter(self): + def right_term(self) -> CFGObject: raise NotImplementedError - def __init__(self, left_term, right_term0, right_term1): - self._left_term = left_term - self._right_terms = (right_term0, right_term1) + def __init__(self, + left_term: Any, + right_term0: Any, + right_term1: Any) -> None: + self._left_term = to_variable(left_term) + self._right_terms = (to_variable(right_term0), + to_variable(right_term1)) - def is_duplication(self) -> bool: - """Whether the rule is a duplication rule or not + @property + def left_term(self) -> Variable: + """Gives the non-terminal on the left of the rule Returns - ---------- - is_duplication : bool - Whether the rule is a duplication rule or not + --------- + left_term : any + The left term of the rule """ - return True + return self._left_term @property - def right_terms(self) -> Tuple[Any, Any]: + def right_terms(self) -> List[CFGObject]: """Gives the non-terminals on the right of the rule Returns @@ -56,21 +65,10 @@ def right_terms(self) -> Tuple[Any, Any]: right_terms : iterable of any The right terms of the rule """ - return self._right_terms + return list(self._right_terms) @property - def left_term(self) -> Any: - """Gives the non-terminal on the left of the rule - - Returns - --------- - left_term : any - The left term of the rule - """ - return self._left_term - - @property - def non_terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: """Gives the set of non-terminals used in this rule Returns @@ -78,10 +76,10 @@ def non_terminals(self) -> Iterable[Any]: non_terminals : iterable of any The non terminals used in this rule """ - return [self._left_term, self._right_terms[0], self._right_terms[1]] + return {self._left_term, *self._right_terms} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used in the rule Returns @@ -91,11 +89,13 @@ def terminals(self) -> AbstractSet[Any]: """ return set() - def __repr__(self): + def __repr__(self) -> str: """Gives a string representation of the rule, ignoring the sigmas""" - return self._left_term + " -> " + self._right_terms[0] + \ - " " + self._right_terms[1] - - def __eq__(self, other): - return other.is_duplication() and other.left_term == \ - self._left_term and other.right_terms == self.right_terms + return f"{self._left_term} -> \ + {self._right_terms[0]} {self._right_terms[1]}" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DuplicationRule): + return False + return other.left_term == self._left_term \ + and other.right_terms == self.right_terms diff --git a/pyformlang/indexed_grammar/end_rule.py b/pyformlang/indexed_grammar/end_rule.py index 7979b84..78ed7a6 100644 --- a/pyformlang/indexed_grammar/end_rule.py +++ b/pyformlang/indexed_grammar/end_rule.py @@ -2,7 +2,11 @@ Represents a end rule, i.e. a rule which give only a terminal """ -from typing import Any, Iterable, AbstractSet +from typing import List, Set, Any + +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.utils import to_variable, to_terminal +from pyformlang.cfg.cfg_object import CFGObject from .reduced_rule import ReducedRule @@ -20,29 +24,30 @@ class EndRule(ReducedRule): """ @property - def production(self): + def f_parameter(self) -> Terminal: raise NotImplementedError @property - def right_terms(self): + def production(self) -> Terminal: raise NotImplementedError - def __init__(self, left, right): - self._left_term = left - self._right_term = right + def __init__(self, left_term: Any, right_term: Any) -> None: + self._left_term = to_variable(left_term) + self._right_term = to_terminal(right_term) - def is_end_rule(self) -> bool: - """Whether the rule is an end rule or not + @property + def left_term(self) -> Variable: + """Gets the non-terminal on the left of the rule Returns - ---------- - is_end : bool - Whether the rule is an end rule or not + --------- + left_term : any + The left non-terminal of the rule """ - return True + return self._left_term @property - def right_term(self) -> Any: + def right_term(self) -> Terminal: """Gets the terminal on the right of the rule Returns @@ -53,18 +58,18 @@ def right_term(self) -> Any: return self._right_term @property - def left_term(self) -> Any: - """Gets the non-terminal on the left of the rule + def right_terms(self) -> List[CFGObject]: + """Gives the terminals on the right of the rule Returns --------- - left_term : any - The left non-terminal of the rule + right_terms : iterable of any + The right terms of the rule """ - return self._left_term + return [self._right_term] @property - def non_terminals(self) -> Iterable[Any]: + def non_terminals(self) -> Set[Variable]: """Gets the non-terminals used Returns @@ -72,10 +77,10 @@ def non_terminals(self) -> Iterable[Any]: non_terminals : iterable of any The non terminals used in this rule """ - return [self._left_term] + return {self._left_term} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used Returns @@ -85,14 +90,12 @@ def terminals(self) -> AbstractSet[Any]: """ return {self._right_term} - def __repr__(self): + def __repr__(self) -> str: """Gets the string representation of the rule""" - return self._left_term + " -> " + self._right_term + return f"{self._left_term} -> {self._right_term}" - def __eq__(self, other): - return other.is_end_rule() and other.left_term == self.left_term\ + def __eq__(self, other: Any) -> bool: + if not isinstance(other, EndRule): + return False + return other.left_term == self.left_term \ and other.right_term == self.right_term - - @property - def f_parameter(self): - raise NotImplementedError diff --git a/pyformlang/indexed_grammar/indexed_grammar.py b/pyformlang/indexed_grammar/indexed_grammar.py index 5e66917..762557a 100644 --- a/pyformlang/indexed_grammar/indexed_grammar.py +++ b/pyformlang/indexed_grammar/indexed_grammar.py @@ -2,12 +2,20 @@ Representation of an indexed grammar """ -from typing import Any, Iterable, AbstractSet +# pylint: disable=cell-var-from-loop -import pyformlang +from typing import Callable, Dict, List, \ + Set, FrozenSet, Tuple, Iterable, Any + +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.utils import to_variable +from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.finite_automaton import FiniteAutomaton +from pyformlang.regular_expression import Regex from .duplication_rule import DuplicationRule from .production_rule import ProductionRule +from .end_rule import EndRule from .rules import Rules @@ -24,40 +32,62 @@ class IndexedGrammar: def __init__(self, rules: Rules, - start_variable: Any = "S"): - self.rules = rules - self.start_variable = start_variable + start_variable: Any = "S") -> None: + self._rules = rules + self._start_variable = to_variable(start_variable) # Precompute all non-terminals - self.non_terminals = rules.non_terminals - self.non_terminals.append(self.start_variable) - self.non_terminals = set(self.non_terminals) + non_terminals = self.non_terminals # We cache the marked items in case of future update of the query - self.marked = {} + self._marked: Dict[CFGObject, Set[FrozenSet[Variable]]] = {} # Initialize the marked symbols # Mark the identity - for non_terminal_a in self.non_terminals: - self.marked[non_terminal_a] = set() + for non_terminal_a in non_terminals: + self._marked[non_terminal_a] = set() temp = frozenset({non_terminal_a}) - self.marked[non_terminal_a].add(temp) + self._marked[non_terminal_a].add(temp) # Mark all end symbols - for non_terminal_a in self.non_terminals: - if exists(self.rules.rules, + for non_terminal_a in non_terminals: + if exists(self._rules.rules, lambda x: x.is_end_rule() and x.left_term == non_terminal_a): - self.marked[non_terminal_a].add(frozenset()) + self._marked[non_terminal_a].add(frozenset()) + + @property + def rules(self) -> Rules: + """ Get the rules of the grammar """ + return self._rules + + @property + def start_variable(self) -> Variable: + """ Get the start variable of the grammar """ + return self._start_variable + + @property + def non_terminals(self) -> Set[Variable]: + """Get all the non-terminals in the grammar + + Returns + ---------- + terminals : iterable of any + The non-terminals used in the grammar + """ + non_terminals = self._rules.non_terminals + non_terminals.add(self._start_variable) + return non_terminals @property - def terminals(self) -> Iterable[Any]: + def terminals(self) -> Set[Terminal]: """Get all the terminals in the grammar Returns ---------- terminals : iterable of any - The terminals used in the rules + The terminals used in the grammar """ - return self.rules.terminals + return self._rules.terminals - def _duplication_processing(self, rule: DuplicationRule): + def _duplication_processing(self, rule: DuplicationRule) \ + -> Tuple[bool, bool]: """Processes a duplication rule Parameters @@ -68,9 +98,9 @@ def _duplication_processing(self, rule: DuplicationRule): was_modified = False need_stop = False right_term_marked0 = [] - for marked_term0 in self.marked[rule.right_terms[0]]: + for marked_term0 in self._marked[rule.right_terms[0]]: right_term_marked1 = [] - for marked_term1 in self.marked[rule.right_terms[1]]: + for marked_term1 in self._marked[rule.right_terms[1]]: if marked_term0 <= marked_term1: temp = marked_term1 elif marked_term1 <= marked_term0: @@ -78,26 +108,27 @@ def _duplication_processing(self, rule: DuplicationRule): else: temp = marked_term0.union(marked_term1) # Check if it was marked before - if temp not in self.marked[rule.left_term]: + if temp not in self._marked[rule.left_term]: was_modified = True if rule.left_term == rule.right_terms[0]: right_term_marked0.append(temp) elif rule.left_term == rule.right_terms[1]: right_term_marked1.append(temp) else: - self.marked[rule.left_term].add(temp) + self._marked[rule.left_term].add(temp) # Stop condition, no need to continue - if rule.left_term == self.start_variable and len( + if rule.left_term == self._start_variable and len( temp) == 0: need_stop = True for temp in right_term_marked1: - self.marked[rule.right_terms[1]].add(temp) + self._marked[rule.right_terms[1]].add(temp) for temp in right_term_marked0: - self.marked[rule.right_terms[0]].add(temp) + self._marked[rule.right_terms[0]].add(temp) return was_modified, need_stop - def _production_process(self, rule: ProductionRule): + def _production_process(self, rule: ProductionRule) \ + -> Tuple[bool, bool]: """Processes a production rule Parameters @@ -108,19 +139,19 @@ def _production_process(self, rule: ProductionRule): was_modified = False # f_rules contains the consumption rules associated with # the current production symbol - f_rules = self.rules.consumption_rules.setdefault( + f_rules = self._rules.consumption_rules.setdefault( rule.production, []) # l_rules contains the left symbol plus what is marked on # the right side l_temp = [(x.left_term, - self.marked[x.right]) for x in f_rules] + self._marked[x.right_term]) for x in f_rules] marked_symbols = [x.left_term for x in f_rules] # Process all combinations of consumption rule was_modified |= addrec_bis(l_temp, - self.marked[rule.left_term], - self.marked[rule.right_term]) + self._marked[rule.left_term], + self._marked[rule.right_term]) # End condition - if frozenset() in self.marked[self.start_variable]: + if frozenset() in self._marked[self._start_variable]: return was_modified, True # Is it useful? if rule.right_term in marked_symbols: @@ -129,17 +160,17 @@ def _production_process(self, rule: ProductionRule): for sub_term in [sub_term for sub_term in term[1] if sub_term not in - self.marked[rule.left_term]]: + self._marked[rule.left_term]]: was_modified = True - self.marked[rule.left_term].add(sub_term) - if (rule.left_term == self.start_variable and + self._marked[rule.left_term].add(sub_term) + if (rule.left_term == self._start_variable and len(sub_term) == 0): return was_modified, True # Edge case - if frozenset() in self.marked[rule.right_term]: - if frozenset() not in self.marked[rule.left_term]: + if frozenset() in self._marked[rule.right_term]: + if frozenset() not in self._marked[rule.left_term]: was_modified = True - self.marked[rule.left_term].add(frozenset()) + self._marked[rule.left_term].add(frozenset()) return was_modified, False def is_empty(self) -> bool: @@ -154,28 +185,28 @@ def is_empty(self) -> bool: was_modified = True while was_modified: was_modified = False - for rule in self.rules.rules: + for rule in self._rules.rules: # If we have a duplication rule, we mark all combinations of # the sets marked on the right side for the symbol on the left # side - if rule.is_duplication(): + if isinstance(rule, DuplicationRule): dup_res = self._duplication_processing(rule) was_modified |= dup_res[0] if dup_res[1]: return False - elif rule.is_production(): + elif isinstance(rule, ProductionRule): prod_res = self._production_process(rule) if prod_res[1]: return False was_modified |= prod_res[0] - if frozenset() in self.marked[self.start_variable]: + if frozenset() in self._marked[self._start_variable]: return False return True - def __bool__(self): + def __bool__(self) -> bool: return not self.is_empty() - def get_reachable_non_terminals(self) -> AbstractSet[Any]: + def get_reachable_non_terminals(self) -> Set[Variable]: """ Get the reachable symbols Returns @@ -184,10 +215,10 @@ def get_reachable_non_terminals(self) -> AbstractSet[Any]: The reachable symbols from the start state """ # Preprocess - reachable_from = {} - consumption_rules = self.rules.consumption_rules - for rule in self.rules.rules: - if rule.is_duplication(): + reachable_from: Dict[Variable, Set[CFGObject]] = {} + consumption_rules = self._rules.consumption_rules + for rule in self._rules.rules: + if isinstance(rule, DuplicationRule): left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] @@ -195,7 +226,7 @@ def get_reachable_non_terminals(self) -> AbstractSet[Any]: reachable_from[left] = set() reachable_from[left].add(right0) reachable_from[left].add(right1) - if rule.is_production(): + if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term if left not in reachable_from: @@ -204,22 +235,23 @@ def get_reachable_non_terminals(self) -> AbstractSet[Any]: for key in consumption_rules: for rule in consumption_rules[key]: left = rule.left_term - right = rule.right + right = rule.right_term if left not in reachable_from: reachable_from[left] = set() reachable_from[left].add(right) # Processing - to_process = [self.start_variable] - reachables = {self.start_variable} + to_process = [self._start_variable] + reachables = {self._start_variable} while to_process: current = to_process.pop() - for symbol in reachable_from.get(current, []): + for symbol in reachable_from.get(current, set()): if symbol not in reachables: - reachables.add(symbol) - to_process.append(symbol) + variable = to_variable(symbol) + reachables.add(variable) + to_process.append(variable) return reachables - def get_generating_non_terminals(self) -> AbstractSet[Any]: + def get_generating_non_terminals(self) -> Set[Variable]: """ Get the generating symbols Returns @@ -250,20 +282,28 @@ def get_generating_non_terminals(self) -> AbstractSet[Any]: to_process.append(duplication[0]) return generating - def _preprocess_consumption_rules_generating(self, generating_from): - for key in self.rules.consumption_rules: - for rule in self.rules.consumption_rules[key]: + def _preprocess_consumption_rules_generating( + self, + generating_from: Dict[CFGObject, Set[Variable]]) \ + -> None: + for key in self._rules.consumption_rules: + for rule in self._rules.consumption_rules[key]: left = rule.left_term - right = rule.right + right = rule.right_term if right in generating_from: generating_from[right].add(left) else: generating_from[right] = {left} - def _preprocess_rules_generating(self, duplication_pointer, generating, - generating_from, to_process): - for rule in self.rules.rules: - if rule.is_duplication(): + def _preprocess_rules_generating( + self, + duplication_pointer: Dict[CFGObject, List], + generating: Set[Variable], + generating_from: Dict[CFGObject, Set[Variable]], + to_process: List[Variable]) \ + -> None: + for rule in self._rules.rules: + if isinstance(rule, DuplicationRule): left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] @@ -276,14 +316,14 @@ def _preprocess_rules_generating(self, duplication_pointer, generating, duplication_pointer[right1].append(temp) else: duplication_pointer[right1] = [temp] - if rule.is_production(): + if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term if right in generating_from: generating_from[right].add(left) else: generating_from[right] = {left} - if rule.is_end_rule(): + if isinstance(rule, EndRule): left = rule.left_term if left not in generating: generating.add(left) @@ -303,33 +343,33 @@ def remove_useless_rules(self) -> "IndexedGrammar": l_rules = [] generating = self.get_generating_non_terminals() reachables = self.get_reachable_non_terminals() - consumption_rules = self.rules.consumption_rules - for rule in self.rules.rules: - if rule.is_duplication(): + consumption_rules = self._rules.consumption_rules + for rule in self._rules.rules: + if isinstance(rule, DuplicationRule): left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] if all(x in generating and x in reachables for x in [left, right0, right1]): l_rules.append(rule) - if rule.is_production(): + if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term if all(x in generating and x in reachables for x in [left, right]): l_rules.append(rule) - if rule.is_end_rule(): + if isinstance(rule, EndRule): left = rule.left_term if left in generating and left in reachables: l_rules.append(rule) for key in consumption_rules: for rule in consumption_rules[key]: left = rule.left_term - right = rule.right + right = rule.right_term if all(x in generating and x in reachables for x in [left, right]): l_rules.append(rule) - rules = Rules(l_rules, self.rules.optim) + rules = Rules(l_rules, self._rules.optim) return IndexedGrammar(rules) def intersection(self, other: Any) -> "IndexedGrammar": @@ -356,14 +396,14 @@ def intersection(self, other: Any) -> "IndexedGrammar": When trying to intersection with something else than a regular expression or a finite automaton """ - if isinstance(other, pyformlang.regular_expression.Regex): + if isinstance(other, Regex): other = other.to_epsilon_nfa() - if isinstance(other, pyformlang.finite_automaton.FiniteAutomaton): + if isinstance(other, FiniteAutomaton): fst = other.to_fst() return fst.intersection(self) raise NotImplementedError - def __and__(self, other): + def __and__(self, other: Any) -> "IndexedGrammar": """ Computes the intersection of the current indexed grammar with the other object @@ -380,7 +420,8 @@ def __and__(self, other): return self.intersection(other) -def exists(list_elements, check_function): +def exists(list_elements: List[Any], + check_function: Callable[[Any], bool]) -> bool: """exists Check whether at least an element x of l is True for f(x) :param list_elements: A list of elements to test @@ -393,7 +434,9 @@ def exists(list_elements, check_function): return False -def addrec_bis(l_sets, marked_left, marked_right): +def addrec_bis(l_sets: Iterable[Any], + marked_left: Set[Any], + marked_right: Set[Any]) -> bool: """addrec_bis Optimized version of addrec :param l_sets: a list containing tuples (C, M) where: @@ -415,7 +458,7 @@ def addrec_bis(l_sets, marked_left, marked_right): return was_modified -def addrec_ter(l_sets, marked_left): +def addrec_ter(l_sets: List[Any], marked_left: Set[Any]) -> bool: """addrec Explores all possible combination of consumption rules to mark a production rule. diff --git a/pyformlang/indexed_grammar/production_rule.py b/pyformlang/indexed_grammar/production_rule.py index bccb68f..4321ea7 100644 --- a/pyformlang/indexed_grammar/production_rule.py +++ b/pyformlang/indexed_grammar/production_rule.py @@ -2,7 +2,11 @@ Represents a production rule, i.e. a rule that pushed on the stack """ -from typing import Any, Iterable, AbstractSet +from typing import List, Set, Any + +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.utils import to_variable, to_terminal +from pyformlang.cfg.cfg_object import CFGObject from .reduced_rule import ReducedRule @@ -22,30 +26,19 @@ class ProductionRule(ReducedRule): """ @property - def right_terms(self): - raise NotImplementedError - - @property - def f_parameter(self): + def f_parameter(self) -> Terminal: raise NotImplementedError - def __init__(self, left, right, prod): - self._production = prod - self._left_term = left - self._right_term = right - - def is_production(self) -> bool: - """Whether the rule is a production rule or not - - Returns - ---------- - is_production : bool - Whether the rule is a production rule or not - """ - return True + def __init__(self, + left_term: Any, + right_term: Any, + production: Any) -> None: + self._left_term = to_variable(left_term) + self._right_term = to_variable(right_term) + self._production = to_terminal(production) @property - def production(self) -> Any: + def production(self) -> Terminal: """Gets the terminal used in the production Returns @@ -56,7 +49,7 @@ def production(self) -> Any: return self._production @property - def left_term(self) -> Any: + def left_term(self) -> Variable: """Gets the non-terminal on the left side of the rule Returns @@ -67,7 +60,7 @@ def left_term(self) -> Any: return self._left_term @property - def right_term(self) -> Any: + def right_term(self) -> Variable: """Gets the non-terminal on the right side of the rule Returns @@ -78,7 +71,18 @@ def right_term(self) -> Any: return self._right_term @property - def non_terminals(self) -> Iterable[Any]: + def right_terms(self) -> List[CFGObject]: + """Gives the non-terminals on the right of the rule + + Returns + --------- + right_terms : iterable of any + The right terms of the rule + """ + return [self._right_term] + + @property + def non_terminals(self) -> Set[Variable]: """Gets the non-terminals used in the rule Returns @@ -86,10 +90,10 @@ def non_terminals(self) -> Iterable[Any]: non_terminals : any The non terminals used in this rules """ - return [self._left_term, self._right_term] + return {self._left_term, self._right_term} @property - def terminals(self) -> AbstractSet[Any]: + def terminals(self) -> Set[Terminal]: """Gets the terminals used in the rule Returns @@ -99,12 +103,13 @@ def terminals(self) -> AbstractSet[Any]: """ return {self._production} - def __repr__(self): + def __repr__(self) -> str: """Gets the string representation of the rule""" - return self._left_term + " -> " + \ - self._right_term + "[ " + self._production + " ]" - - def __eq__(self, other): - return other.is_production() and other.left_term == \ - self.left_term and other.right_term == self.right_term \ - and other.production == self.production + return f"{self._left_term} -> {self._right_term} [ {self._production} ]" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, ProductionRule): + return False + return other.left_term == self.left_term \ + and other.right_term == self.right_term \ + and other.production == self.production diff --git a/pyformlang/indexed_grammar/reduced_rule.py b/pyformlang/indexed_grammar/reduced_rule.py index 5d463ad..20ae7dd 100644 --- a/pyformlang/indexed_grammar/reduced_rule.py +++ b/pyformlang/indexed_grammar/reduced_rule.py @@ -1,8 +1,13 @@ """ Representation of a reduced rule """ + +from typing import List, Set from abc import abstractmethod +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.cfg_object import CFGObject + class ReducedRule: """Representation of all possible reduced forms. @@ -13,102 +18,82 @@ class ReducedRule: * Duplication """ - def is_consumption(self) -> bool: - """Whether the rule is a consumption rule or not - - Returns - ---------- - is_consumption : bool - Whether the rule is a consumption rule or not - """ - return False - - def is_duplication(self) -> bool: - """Whether the rule is a duplication rule or not - - Returns - ---------- - is_duplication : bool - Whether the rule is a duplication rule or not - """ - return False - - def is_production(self) -> bool: - """Whether the rule is a production rule or not + @property + @abstractmethod + def f_parameter(self) -> Terminal: + """The f parameter Returns ---------- - is_production : bool - Whether the rule is a production rule or not + f : cfg.Terminal + The f parameter """ - return False + raise NotImplementedError - def is_end_rule(self) -> bool: - """Whether the rule is an end rule or not + @property + @abstractmethod + def production(self) -> Terminal: + """The production Returns ---------- - is_end : bool - Whether the rule is an end rule or not + right_terms : any + The production """ - return False + raise NotImplementedError @property @abstractmethod - def f_parameter(self): - """The f parameter + def left_term(self) -> Variable: + """The left term Returns ---------- - f : any - The f parameter + left_term : cfg.Variable + The left term of the rule """ raise NotImplementedError @property @abstractmethod - def left_term(self): - """The left term + def right_term(self) -> CFGObject: + """The unique right term Returns ---------- - left_term : any - The left term of the rule + right_term : cfg.cfg_object.CFGObject + The unique right term of the rule """ raise NotImplementedError @property @abstractmethod - def right_terms(self): + def right_terms(self) -> List[CFGObject]: """The right terms Returns ---------- - right_terms : iterable of any + right_terms : list of cfg.cfg_object.CFGObject The right terms of the rule """ raise NotImplementedError @property @abstractmethod - def right_term(self): - """The unique right term + def non_terminals(self) -> Set[Variable]: + """Gets the non-terminals used in the rule - Returns - ---------- - right_term : iterable of any - The unique right term of the rule + terminals : set of cfg.Variable + The non-terminals used in the rule """ raise NotImplementedError @property @abstractmethod - def production(self): - """The production + def terminals(self) -> Set[Terminal]: + """Gets the terminals used in the rule - Returns - ---------- - right_terms : any - The production + terminals : set of cfg.Terminal + The terminals used in the rule """ raise NotImplementedError diff --git a/pyformlang/indexed_grammar/rule_ordering.py b/pyformlang/indexed_grammar/rule_ordering.py index af9236d..677c9f5 100644 --- a/pyformlang/indexed_grammar/rule_ordering.py +++ b/pyformlang/indexed_grammar/rule_ordering.py @@ -2,15 +2,18 @@ Representation of a way to order rules """ -from typing import Iterable, Dict, Any +from typing import List, Dict from queue import Queue -import random +from random import shuffle +from networkx import DiGraph, core_number, minimum_spanning_tree -import networkx as nx +from pyformlang.cfg import Terminal from .reduced_rule import ReducedRule from .consumption_rule import ConsumptionRule +from .duplication_rule import DuplicationRule +from .production_rule import ProductionRule class RuleOrdering: @@ -25,12 +28,13 @@ class RuleOrdering: The consumption rules of the indexed grammar """ - def __init__(self, rules: Iterable[ReducedRule], - conso_rules: Dict[Any, ConsumptionRule]): + def __init__(self, + rules: List[ReducedRule], + conso_rules: Dict[Terminal, List[ConsumptionRule]]) -> None: self.rules = rules self.conso_rules = conso_rules - def reverse(self) -> Iterable[ReducedRule]: + def reverse(self) -> List[ReducedRule]: """The reverser ordering, simply reverse the order. Returns @@ -41,26 +45,26 @@ def reverse(self) -> Iterable[ReducedRule]: """ return self.rules[::1] - def _get_graph(self): + def _get_graph(self) -> DiGraph: """ Get the graph of the non-terminals in the rules. If there there is a link between A and B (oriented), it means that modifying A may modify B""" - di_graph = nx.DiGraph() + di_graph = DiGraph() for rule in self.rules: - if rule.is_duplication(): + if isinstance(rule, DuplicationRule): if rule.right_terms[0] != rule.left_term: di_graph.add_edge(rule.right_terms[0], rule.left_term) if rule.right_terms[1] != rule.left_term: di_graph.add_edge(rule.right_terms[1], rule.left_term) - if rule.is_production(): + if isinstance(rule, ProductionRule): f_rules = self.conso_rules.setdefault( rule.production, []) for f_rule in f_rules: - if f_rule.right != rule.left_term: - di_graph.add_edge(f_rule.right, rule.left_term) + if f_rule.right_term != rule.left_term: + di_graph.add_edge(f_rule.right_term, rule.left_term) return di_graph - def order_by_core(self, reverse: bool = False) -> Iterable[ReducedRule]: + def order_by_core(self, reverse: bool = False) -> List[ReducedRule]: """Order the rules using the core numbers Parameters @@ -77,7 +81,7 @@ def order_by_core(self, reverse: bool = False) -> Iterable[ReducedRule]: # Graph construction di_graph = self._get_graph() # Get core number, careful the degree is in + out - core_numbers = nx.core_number(di_graph) + core_numbers = dict(core_number(di_graph)) new_order = sorted(self.rules, key=lambda x: core_numbers.setdefault( x.left_term, 0)) @@ -86,7 +90,7 @@ def order_by_core(self, reverse: bool = False) -> Iterable[ReducedRule]: return new_order def order_by_arborescence(self, reverse: bool = True) \ - -> Iterable[ReducedRule]: + -> List[ReducedRule]: """Order the rules using the arborescence method. Parameters @@ -101,7 +105,7 @@ def order_by_arborescence(self, reverse: bool = True) \ The rules ordered using core number """ di_graph = self._get_graph() - arborescence = nx.minimum_spanning_tree(di_graph.to_undirected()) + arborescence = minimum_spanning_tree(di_graph.to_undirected()) to_process = Queue() processed = set() res = {} @@ -126,7 +130,7 @@ def order_by_arborescence(self, reverse: bool = True) \ return new_order @staticmethod - def _get_len_out(di_graph, rule): + def _get_len_out(di_graph: DiGraph, rule: ReducedRule) -> int: """Get the number of out edges of a rule (more exactly, the non \ terminal at its left. @@ -141,7 +145,7 @@ def _get_len_out(di_graph, rule): return len(di_graph[rule.left_term]) return 0 - def order_by_edges(self, reverse=False): + def order_by_edges(self, reverse: bool = False) -> List[ReducedRule]: """Order using the number of edges. Parameters @@ -162,7 +166,7 @@ def order_by_edges(self, reverse=False): new_order.reverse() return new_order - def order_random(self): + def order_random(self) -> List[ReducedRule]: """The random ordering Returns @@ -171,5 +175,5 @@ def order_random(self): :class:`~pyformlang.indexed_grammar.ReducedRule` The rules ordered at random """ - random.shuffle(self.rules) + shuffle(self.rules) return self.rules diff --git a/pyformlang/indexed_grammar/rules.py b/pyformlang/indexed_grammar/rules.py index 017845a..9e431d7 100644 --- a/pyformlang/indexed_grammar/rules.py +++ b/pyformlang/indexed_grammar/rules.py @@ -2,7 +2,10 @@ Representations of rules in a indexed grammar """ -from typing import Iterable, Dict, Any, List +from typing import Dict, List, Set, Tuple, Iterable, Any + +from pyformlang.cfg import Variable, Terminal +from pyformlang.cfg.utils import to_variable, to_terminal from .production_rule import ProductionRule from .consumption_rule import ConsumptionRule @@ -30,13 +33,13 @@ class Rules: 8 -> random order """ - def __init__(self, rules: Iterable[ReducedRule], optim: int = 7): - self._rules = [] - self._consumption_rules = {} + def __init__(self, rules: Iterable[ReducedRule], optim: int = 7) -> None: + self._rules: List[ReducedRule] = [] + self._consumption_rules: Dict[Terminal, List[ConsumptionRule]] = {} self._optim = optim for rule in rules: # We separate consumption rule from other - if rule.is_consumption(): + if isinstance(rule, ConsumptionRule): temp = self._consumption_rules.setdefault(rule.f_parameter, []) if rule not in temp: temp.append(rule) @@ -63,7 +66,7 @@ def __init__(self, rules: Iterable[ReducedRule], optim: int = 7): self._rules = rule_ordering.order_random() @property - def optim(self): + def optim(self) -> int: """Gets the optimization number Returns @@ -74,7 +77,7 @@ def optim(self): return self._optim @property - def rules(self) -> Iterable[ReducedRule]: + def rules(self) -> List[ReducedRule]: """Gets the non consumption rules Returns @@ -86,7 +89,7 @@ def rules(self) -> Iterable[ReducedRule]: return self._rules @property - def length(self) -> (int, int): + def length(self) -> Tuple[int, int]: """Get the total number of rules Returns @@ -98,7 +101,7 @@ def length(self) -> (int, int): return len(self._rules), len(self._consumption_rules.values()) @property - def consumption_rules(self) -> Dict[Any, Iterable[ConsumptionRule]]: + def consumption_rules(self) -> Dict[Terminal, List[ConsumptionRule]]: """Gets the consumption rules Returns @@ -111,40 +114,43 @@ def consumption_rules(self) -> Dict[Any, Iterable[ConsumptionRule]]: return self._consumption_rules @property - def terminals(self) -> Iterable[Any]: - """Gets all the terminals used by all the rules + def non_terminals(self) -> Set[Variable]: + """Gets all the non-terminals used by all the rules Returns ---------- - terminals : iterable of any - The terminals used in the rules + non_terminals : iterable of any + The non terminals used in the rule """ terminals = set() for temp_rule in self._consumption_rules.values(): for rule in temp_rule: - terminals = terminals.union(rule.terminals) + terminals = terminals.union(rule.non_terminals) for rule in self._rules: - terminals = terminals.union(rule.terminals) + terminals = terminals.union(rule.non_terminals) return terminals @property - def non_terminals(self) -> List[Any]: - """Gets all the non-terminals used by all the rules + def terminals(self) -> Set[Terminal]: + """Gets all the terminals used by all the rules Returns ---------- - non_terminals : iterable of any - The non terminals used in the rule + terminals : iterable of any + The terminals used in the rules """ terminals = set() - for temp_rule in self._consumption_rules.values(): - for rule in temp_rule: - terminals = terminals.union(set(rule.non_terminals)) + for rules in self._consumption_rules.values(): + for rule in rules: + terminals = terminals.union(rule.terminals) for rule in self._rules: - terminals = terminals.union(set(rule.non_terminals)) - return list(terminals) + terminals = terminals.union(rule.terminals) + return terminals - def remove_production(self, left: Any, right: Any, prod: Any): + def remove_production(self, + left: Any, + right: Any, + prod: Any) -> None: """Remove the production rule: left[sigma] -> right[prod sigma] @@ -157,13 +163,19 @@ def remove_production(self, left: Any, right: Any, prod: Any): prod : any The production used in the rule """ - self._rules = list(filter(lambda x: not (x.is_production() and - x.left_term == left and - x.right_term == right and - x.production == prod), + left = to_variable(left) + right = to_variable(right) + prod = to_terminal(prod) + self._rules = list(filter(lambda x: not (isinstance(x, ProductionRule) + and x.left_term == left + and x.right_term == right + and x.production == prod), self._rules)) - def add_production(self, left: Any, right: Any, prod: Any): + def add_production(self, + left: Any, + right: Any, + prod: Any) -> None: """Add the production rule: left[sigma] -> right[prod sigma] @@ -176,4 +188,7 @@ def add_production(self, left: Any, right: Any, prod: Any): prod : any The production used in the rule """ + left = to_variable(left) + right = to_variable(right) + prod = to_terminal(prod) self._rules.append(ProductionRule(left, right, prod)) From 13d87ab20c985f4b940b5831bb5e06e799bd6966 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Thu, 5 Dec 2024 12:59:39 +0300 Subject: [PATCH 03/12] correct indexed grammar annotations and imports --- pyformlang/fst/fst.py | 4 +- pyformlang/indexed_grammar/__init__.py | 9 +- .../indexed_grammar/consumption_rule.py | 27 ++-- .../indexed_grammar/duplication_rule.py | 41 +++-- pyformlang/indexed_grammar/end_rule.py | 23 ++- pyformlang/indexed_grammar/indexed_grammar.py | 148 +++--------------- pyformlang/indexed_grammar/production_rule.py | 29 ++-- pyformlang/indexed_grammar/reduced_rule.py | 13 +- pyformlang/indexed_grammar/rules.py | 50 +++--- pyformlang/indexed_grammar/utils.py | 102 ++++++++++++ pyformlang/rsa/recursive_automaton.py | 2 +- 11 files changed, 227 insertions(+), 221 deletions(-) create mode 100644 pyformlang/indexed_grammar/utils.py diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index b8dcc1d..ebda44b 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -422,12 +422,12 @@ def _add_transitions_to(self, idx: int) -> None: for head, transition in self._delta.items(): s_from, input_symbol = head - for s_to, output_symbol in transition: + for s_to, output_symbols in transition: union_fst.add_transition( state_renaming.get_renamed_state(s_from, idx), input_symbol, state_renaming.get_renamed_state(s_to, idx), - output_symbol) + output_symbols) def _add_extremity_states_to(self, union_fst: "FST", diff --git a/pyformlang/indexed_grammar/__init__.py b/pyformlang/indexed_grammar/__init__.py index 14da624..00e6f18 100644 --- a/pyformlang/indexed_grammar/__init__.py +++ b/pyformlang/indexed_grammar/__init__.py @@ -23,16 +23,23 @@ """ from .rules import Rules +from .reduced_rule import ReducedRule from .consumption_rule import ConsumptionRule from .end_rule import EndRule from .production_rule import ProductionRule from .duplication_rule import DuplicationRule from .indexed_grammar import IndexedGrammar +from ..objects.cfg_objects import CFGObject, Variable, Terminal, Epsilon __all__ = ["Rules", + "ReducedRule", "ConsumptionRule", "EndRule", "ProductionRule", "DuplicationRule", - "IndexedGrammar"] + "IndexedGrammar", + "CFGObject", + "Variable", + "Terminal", + "Epsilon"] diff --git a/pyformlang/indexed_grammar/consumption_rule.py b/pyformlang/indexed_grammar/consumption_rule.py index eecc043..b59609f 100644 --- a/pyformlang/indexed_grammar/consumption_rule.py +++ b/pyformlang/indexed_grammar/consumption_rule.py @@ -3,13 +3,12 @@ the stack """ -from typing import List, Set, Any +from typing import List, Set, Hashable, Any -from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.utils import to_variable, to_terminal -from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class ConsumptionRule(ReducedRule): @@ -27,14 +26,10 @@ class ConsumptionRule(ReducedRule): The non terminal on the right (here B) """ - @property - def production(self) -> Terminal: - raise NotImplementedError - def __init__(self, - f_param: Any, - left_term: Any, - right_term: Any) -> None: + f_param: Hashable, + left_term: Hashable, + right_term: Hashable) -> None: self._f = to_terminal(f_param) self._left_term = to_variable(left_term) self._right_term = to_variable(right_term) @@ -50,6 +45,10 @@ def f_parameter(self) -> Terminal: """ return self._f + @property + def production(self) -> Terminal: + raise NotImplementedError + @property def left_term(self) -> Variable: """Gets the symbol on the left of the rule @@ -97,12 +96,12 @@ def terminals(self) -> Set[Terminal]: """ return {self._f} - def __repr__(self) -> str: - return f"{self._left_term} [ {self._f} ] -> {self._right_term}" - def __eq__(self, other: Any) -> bool: if not isinstance(other, ConsumptionRule): return False return other.left_term == self.left_term \ and other.right_term == self.right_term \ and other.f_parameter == self.f_parameter + + def __repr__(self) -> str: + return f"{self._left_term} [ {self._f} ] -> {self._right_term}" diff --git a/pyformlang/indexed_grammar/duplication_rule.py b/pyformlang/indexed_grammar/duplication_rule.py index 5b81a32..81ab340 100644 --- a/pyformlang/indexed_grammar/duplication_rule.py +++ b/pyformlang/indexed_grammar/duplication_rule.py @@ -2,13 +2,12 @@ A representation of a duplication rule, i.e. a rule that duplicates the stack """ -from typing import List, Set, Any +from typing import List, Set, Hashable, Any -from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.utils import to_variable -from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable class DuplicationRule(ReducedRule): @@ -25,6 +24,14 @@ class DuplicationRule(ReducedRule): The second non-terminal on the right of the rule (C here) """ + def __init__(self, + left_term: Hashable, + right_term0: Hashable, + right_term1: Hashable) -> None: + self._left_term = to_variable(left_term) + self._right_terms = (to_variable(right_term0), + to_variable(right_term1)) + @property def f_parameter(self) -> Terminal: raise NotImplementedError @@ -33,18 +40,6 @@ def f_parameter(self) -> Terminal: def production(self) -> Terminal: raise NotImplementedError - @property - def right_term(self) -> CFGObject: - raise NotImplementedError - - def __init__(self, - left_term: Any, - right_term0: Any, - right_term1: Any) -> None: - self._left_term = to_variable(left_term) - self._right_terms = (to_variable(right_term0), - to_variable(right_term1)) - @property def left_term(self) -> Variable: """Gives the non-terminal on the left of the rule @@ -56,6 +51,10 @@ def left_term(self) -> Variable: """ return self._left_term + @property + def right_term(self) -> CFGObject: + raise NotImplementedError + @property def right_terms(self) -> List[CFGObject]: """Gives the non-terminals on the right of the rule @@ -89,13 +88,13 @@ def terminals(self) -> Set[Terminal]: """ return set() - def __repr__(self) -> str: - """Gives a string representation of the rule, ignoring the sigmas""" - return f"{self._left_term} -> \ - {self._right_terms[0]} {self._right_terms[1]}" - def __eq__(self, other: Any) -> bool: if not isinstance(other, DuplicationRule): return False return other.left_term == self._left_term \ and other.right_terms == self.right_terms + + def __repr__(self) -> str: + """Gives a string representation of the rule, ignoring the sigmas""" + return f"{self._left_term} -> \ + {self._right_terms[0]} {self._right_terms[1]}" diff --git a/pyformlang/indexed_grammar/end_rule.py b/pyformlang/indexed_grammar/end_rule.py index 78ed7a6..1433ca9 100644 --- a/pyformlang/indexed_grammar/end_rule.py +++ b/pyformlang/indexed_grammar/end_rule.py @@ -2,13 +2,12 @@ Represents a end rule, i.e. a rule which give only a terminal """ -from typing import List, Set, Any +from typing import List, Set, Hashable, Any -from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.utils import to_variable, to_terminal -from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class EndRule(ReducedRule): @@ -23,6 +22,10 @@ class EndRule(ReducedRule): The terminal on the right, "a" here """ + def __init__(self, left_term: Hashable, right_term: Hashable) -> None: + self._left_term = to_variable(left_term) + self._right_term = to_terminal(right_term) + @property def f_parameter(self) -> Terminal: raise NotImplementedError @@ -31,10 +34,6 @@ def f_parameter(self) -> Terminal: def production(self) -> Terminal: raise NotImplementedError - def __init__(self, left_term: Any, right_term: Any) -> None: - self._left_term = to_variable(left_term) - self._right_term = to_terminal(right_term) - @property def left_term(self) -> Variable: """Gets the non-terminal on the left of the rule @@ -90,12 +89,12 @@ def terminals(self) -> Set[Terminal]: """ return {self._right_term} - def __repr__(self) -> str: - """Gets the string representation of the rule""" - return f"{self._left_term} -> {self._right_term}" - def __eq__(self, other: Any) -> bool: if not isinstance(other, EndRule): return False return other.left_term == self.left_term \ and other.right_term == self.right_term + + def __repr__(self) -> str: + """Gets the string representation of the rule""" + return f"{self._left_term} -> {self._right_term}" diff --git a/pyformlang/indexed_grammar/indexed_grammar.py b/pyformlang/indexed_grammar/indexed_grammar.py index 762557a..ab5004e 100644 --- a/pyformlang/indexed_grammar/indexed_grammar.py +++ b/pyformlang/indexed_grammar/indexed_grammar.py @@ -4,19 +4,18 @@ # pylint: disable=cell-var-from-loop -from typing import Callable, Dict, List, \ - Set, FrozenSet, Tuple, Iterable, Any +from typing import Dict, List, Set, FrozenSet, Tuple, Hashable, Any -from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.utils import to_variable -from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.cfg import CFGObject, Variable, Terminal from pyformlang.finite_automaton import FiniteAutomaton from pyformlang.regular_expression import Regex +from .rules import Rules from .duplication_rule import DuplicationRule from .production_rule import ProductionRule from .end_rule import EndRule -from .rules import Rules +from .utils import exists, addrec_bis +from ..objects.cfg_objects.utils import to_variable class IndexedGrammar: @@ -32,7 +31,7 @@ class IndexedGrammar: def __init__(self, rules: Rules, - start_variable: Any = "S") -> None: + start_variable: Hashable = "S") -> None: self._rules = rules self._start_variable = to_variable(start_variable) # Precompute all non-terminals @@ -71,9 +70,7 @@ def non_terminals(self) -> Set[Variable]: terminals : iterable of any The non-terminals used in the grammar """ - non_terminals = self._rules.non_terminals - non_terminals.add(self._start_variable) - return non_terminals + return {self.start_variable} | self._rules.non_terminals @property def terminals(self) -> Set[Terminal]: @@ -260,8 +257,8 @@ def get_generating_non_terminals(self) -> Set[Variable]: The generating symbols from the start state """ # Preprocess - generating_from = {} - duplication_pointer = {} + generating_from: Dict[Variable, Set[Variable]] = {} + duplication_pointer: Dict[CFGObject, List[Tuple[Variable, int]]] = {} generating = set() to_process = [] self._preprocess_rules_generating(duplication_pointer, generating, @@ -274,17 +271,17 @@ def get_generating_non_terminals(self) -> Set[Variable]: if symbol not in generating: generating.add(symbol) to_process.append(symbol) - for duplication in duplication_pointer.get(current, []): - duplication[1] -= 1 - if duplication[1] == 0: - if duplication[0] not in generating: - generating.add(duplication[0]) - to_process.append(duplication[0]) + for symbol, pointer in duplication_pointer.get(current, []): + pointer -= 1 + if pointer == 0: + if symbol not in generating: + generating.add(symbol) + to_process.append(symbol) return generating def _preprocess_consumption_rules_generating( self, - generating_from: Dict[CFGObject, Set[Variable]]) \ + generating_from: Dict[Variable, Set[Variable]]) \ -> None: for key in self._rules.consumption_rules: for rule in self._rules.consumption_rules[key]: @@ -297,9 +294,9 @@ def _preprocess_consumption_rules_generating( def _preprocess_rules_generating( self, - duplication_pointer: Dict[CFGObject, List], + duplication_pointer: Dict[CFGObject, List[Tuple[Variable, int]]], generating: Set[Variable], - generating_from: Dict[CFGObject, Set[Variable]], + generating_from: Dict[Variable, Set[Variable]], to_process: List[Variable]) \ -> None: for rule in self._rules.rules: @@ -307,15 +304,9 @@ def _preprocess_rules_generating( left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] - temp = [left, 2] - if right0 in duplication_pointer: - duplication_pointer[right0].append(temp) - else: - duplication_pointer[right0] = [temp] - if right1 in duplication_pointer: - duplication_pointer[right1].append(temp) - else: - duplication_pointer[right1] = [temp] + temp = (left, 2) + duplication_pointer.setdefault(right0, []).append(temp) + duplication_pointer.setdefault(right1, []).append(temp) if isinstance(rule, ProductionRule): left = rule.left_term right = rule.right_term @@ -418,100 +409,3 @@ def __and__(self, other: Any) -> "IndexedGrammar": The indexed grammar which useless rules """ return self.intersection(other) - - -def exists(list_elements: List[Any], - check_function: Callable[[Any], bool]) -> bool: - """exists - Check whether at least an element x of l is True for f(x) - :param list_elements: A list of elements to test - :param check_function: The checking function (takes one parameter and \ - return a boolean) - """ - for element in list_elements: - if check_function(element): - return True - return False - - -def addrec_bis(l_sets: Iterable[Any], - marked_left: Set[Any], - marked_right: Set[Any]) -> bool: - """addrec_bis - Optimized version of addrec - :param l_sets: a list containing tuples (C, M) where: - * C is a non-terminal on the left of a consumption rule - * M is the set of the marked set for the right non-terminal in the - production rule - :param marked_left: Sets which are marked for the non-terminal on the - left of the production rule - :param marked_right: Sets which are marked for the non-terminal on the - right of the production rule - """ - was_modified = False - for marked in list(marked_right): - l_temp = [x for x in l_sets if x[0] in marked] - s_temp = [x[0] for x in l_temp] - # At least one symbol to consider - if frozenset(s_temp) == marked and len(marked) > 0: - was_modified |= addrec_ter(l_temp, marked_left) - return was_modified - - -def addrec_ter(l_sets: List[Any], marked_left: Set[Any]) -> bool: - """addrec - Explores all possible combination of consumption rules to mark a - production rule. - :param l_sets: a list containing tuples (C, M) where: - * C is a non-terminal on the left of a consumption rule - * M is the set of the marked set for the right non-terminal in the - production rule - :param marked_left: Sets which are marked for the non-terminal on the - left of the production rule - :return Whether an element was actually marked - """ - # End condition, nothing left to process - temp_in = [x[0] for x in l_sets] - exists_after = [ - exists(l_sets[index + 1:], lambda x: x[0] == l_sets[index][0]) - for index in range(len(l_sets))] - exists_before = [l_sets[index][0] in temp_in[:index] - for index in range(len(l_sets))] - marked_sets = [l_sets[index][1] for index in range(len(l_sets))] - marked_sets = [sorted(x, key=lambda x: -len(x)) for x in marked_sets] - # Try to optimize by having an order of the sets - sorted_zip = sorted(zip(exists_after, exists_before, marked_sets), - key=lambda x: -len(x[2])) - exists_after, exists_before, marked_sets = \ - zip(*sorted_zip) - res = False - # contains tuples of index, temp_set - to_process = [(0, frozenset())] - done = set() - while to_process: - index, new_temp = to_process.pop() - if index >= len(l_sets): - # Check if at least one non-terminal was considered, then if the - # set of non-terminals considered is marked of the right - # non-terminal in the production rule, then if a new set is - # marked or not - if new_temp not in marked_left: - marked_left.add(new_temp) - res = True - continue - if exists_before[index] or exists_after[index]: - to_append = (index + 1, new_temp) - to_process.append(to_append) - if not exists_before[index]: - # For all sets which were marked for the current consumption rule - for marked_set in marked_sets[index]: - if marked_set <= new_temp: - to_append = (index + 1, new_temp) - elif new_temp <= marked_set: - to_append = (index + 1, marked_set) - else: - to_append = (index + 1, new_temp.union(marked_set)) - if to_append not in done: - done.add(to_append) - to_process.append(to_append) - return res diff --git a/pyformlang/indexed_grammar/production_rule.py b/pyformlang/indexed_grammar/production_rule.py index 4321ea7..a8bfc47 100644 --- a/pyformlang/indexed_grammar/production_rule.py +++ b/pyformlang/indexed_grammar/production_rule.py @@ -2,13 +2,12 @@ Represents a production rule, i.e. a rule that pushed on the stack """ -from typing import List, Set, Any +from typing import List, Set, Hashable, Any -from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.utils import to_variable, to_terminal -from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.cfg import CFGObject, Variable, Terminal from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class ProductionRule(ReducedRule): @@ -25,18 +24,18 @@ class ProductionRule(ReducedRule): The terminal used in the rule, "r" here """ - @property - def f_parameter(self) -> Terminal: - raise NotImplementedError - def __init__(self, - left_term: Any, - right_term: Any, - production: Any) -> None: + left_term: Hashable, + right_term: Hashable, + production: Hashable) -> None: self._left_term = to_variable(left_term) self._right_term = to_variable(right_term) self._production = to_terminal(production) + @property + def f_parameter(self) -> Terminal: + raise NotImplementedError + @property def production(self) -> Terminal: """Gets the terminal used in the production @@ -103,13 +102,13 @@ def terminals(self) -> Set[Terminal]: """ return {self._production} - def __repr__(self) -> str: - """Gets the string representation of the rule""" - return f"{self._left_term} -> {self._right_term} [ {self._production} ]" - def __eq__(self, other: Any) -> bool: if not isinstance(other, ProductionRule): return False return other.left_term == self.left_term \ and other.right_term == self.right_term \ and other.production == self.production + + def __repr__(self) -> str: + """Gets the string representation of the rule""" + return f"{self._left_term} -> {self._right_term} [ {self._production} ]" diff --git a/pyformlang/indexed_grammar/reduced_rule.py b/pyformlang/indexed_grammar/reduced_rule.py index 20ae7dd..39c6c50 100644 --- a/pyformlang/indexed_grammar/reduced_rule.py +++ b/pyformlang/indexed_grammar/reduced_rule.py @@ -2,11 +2,10 @@ Representation of a reduced rule """ -from typing import List, Set +from typing import List, Set, Any from abc import abstractmethod -from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.cfg_object import CFGObject +from pyformlang.cfg import CFGObject, Variable, Terminal class ReducedRule: @@ -97,3 +96,11 @@ def terminals(self) -> Set[Terminal]: The terminals used in the rule """ raise NotImplementedError + + @abstractmethod + def __eq__(self, other: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def __repr__(self) -> str: + raise NotImplementedError diff --git a/pyformlang/indexed_grammar/rules.py b/pyformlang/indexed_grammar/rules.py index 9e431d7..2397ea2 100644 --- a/pyformlang/indexed_grammar/rules.py +++ b/pyformlang/indexed_grammar/rules.py @@ -2,15 +2,15 @@ Representations of rules in a indexed grammar """ -from typing import Dict, List, Set, Tuple, Iterable, Any +from typing import Dict, List, Set, Tuple, Iterable, Hashable from pyformlang.cfg import Variable, Terminal -from pyformlang.cfg.utils import to_variable, to_terminal +from .reduced_rule import ReducedRule from .production_rule import ProductionRule from .consumption_rule import ConsumptionRule from .rule_ordering import RuleOrdering -from .reduced_rule import ReducedRule +from ..objects.cfg_objects.utils import to_variable, to_terminal class Rules: @@ -122,13 +122,13 @@ def non_terminals(self) -> Set[Variable]: non_terminals : iterable of any The non terminals used in the rule """ - terminals = set() + non_terminals = set() for temp_rule in self._consumption_rules.values(): for rule in temp_rule: - terminals = terminals.union(rule.non_terminals) + non_terminals.update(rule.non_terminals) for rule in self._rules: - terminals = terminals.union(rule.non_terminals) - return terminals + non_terminals.update(rule.non_terminals) + return non_terminals @property def terminals(self) -> Set[Terminal]: @@ -142,16 +142,16 @@ def terminals(self) -> Set[Terminal]: terminals = set() for rules in self._consumption_rules.values(): for rule in rules: - terminals = terminals.union(rule.terminals) + terminals.update(rule.terminals) for rule in self._rules: - terminals = terminals.union(rule.terminals) + terminals.update(rule.terminals) return terminals - def remove_production(self, - left: Any, - right: Any, - prod: Any) -> None: - """Remove the production rule: + def add_production(self, + left: Hashable, + right: Hashable, + prod: Hashable) -> None: + """Add the production rule: left[sigma] -> right[prod sigma] Parameters @@ -166,17 +166,13 @@ def remove_production(self, left = to_variable(left) right = to_variable(right) prod = to_terminal(prod) - self._rules = list(filter(lambda x: not (isinstance(x, ProductionRule) - and x.left_term == left - and x.right_term == right - and x.production == prod), - self._rules)) + self._rules.append(ProductionRule(left, right, prod)) - def add_production(self, - left: Any, - right: Any, - prod: Any) -> None: - """Add the production rule: + def remove_production(self, + left: Hashable, + right: Hashable, + prod: Hashable) -> None: + """Remove the production rule: left[sigma] -> right[prod sigma] Parameters @@ -191,4 +187,8 @@ def add_production(self, left = to_variable(left) right = to_variable(right) prod = to_terminal(prod) - self._rules.append(ProductionRule(left, right, prod)) + self._rules = list(filter(lambda x: not (isinstance(x, ProductionRule) + and x.left_term == left + and x.right_term == right + and x.production == prod), + self._rules)) diff --git a/pyformlang/indexed_grammar/utils.py b/pyformlang/indexed_grammar/utils.py new file mode 100644 index 0000000..ce8c103 --- /dev/null +++ b/pyformlang/indexed_grammar/utils.py @@ -0,0 +1,102 @@ +""" Utility for indexed grammars """ + +# pylint: disable=cell-var-from-loop + +from typing import Callable, List, Set, Iterable, Any + + +def exists(list_elements: List[Any], + check_function: Callable[[Any], bool]) -> bool: + """exists + Check whether at least an element x of l is True for f(x) + :param list_elements: A list of elements to test + :param check_function: The checking function (takes one parameter and \ + return a boolean) + """ + for element in list_elements: + if check_function(element): + return True + return False + + +def addrec_bis(l_sets: Iterable[Any], + marked_left: Set[Any], + marked_right: Set[Any]) -> bool: + """addrec_bis + Optimized version of addrec + :param l_sets: a list containing tuples (C, M) where: + * C is a non-terminal on the left of a consumption rule + * M is the set of the marked set for the right non-terminal in the + production rule + :param marked_left: Sets which are marked for the non-terminal on the + left of the production rule + :param marked_right: Sets which are marked for the non-terminal on the + right of the production rule + """ + was_modified = False + for marked in list(marked_right): + l_temp = [x for x in l_sets if x[0] in marked] + s_temp = [x[0] for x in l_temp] + # At least one symbol to consider + if frozenset(s_temp) == marked and len(marked) > 0: + was_modified |= addrec_ter(l_temp, marked_left) + return was_modified + + +def addrec_ter(l_sets: List[Any], marked_left: Set[Any]) -> bool: + """addrec + Explores all possible combination of consumption rules to mark a + production rule. + :param l_sets: a list containing tuples (C, M) where: + * C is a non-terminal on the left of a consumption rule + * M is the set of the marked set for the right non-terminal in the + production rule + :param marked_left: Sets which are marked for the non-terminal on the + left of the production rule + :return Whether an element was actually marked + """ + # End condition, nothing left to process + temp_in = [x[0] for x in l_sets] + exists_after = [ + exists(l_sets[index + 1:], lambda x: x[0] == l_sets[index][0]) + for index in range(len(l_sets))] + exists_before = [l_sets[index][0] in temp_in[:index] + for index in range(len(l_sets))] + marked_sets = [l_sets[index][1] for index in range(len(l_sets))] + marked_sets = [sorted(x, key=lambda x: -len(x)) for x in marked_sets] + # Try to optimize by having an order of the sets + sorted_zip = sorted(zip(exists_after, exists_before, marked_sets), + key=lambda x: -len(x[2])) + exists_after, exists_before, marked_sets = \ + zip(*sorted_zip) + res = False + # contains tuples of index, temp_set + to_process = [(0, frozenset())] + done = set() + while to_process: + index, new_temp = to_process.pop() + if index >= len(l_sets): + # Check if at least one non-terminal was considered, then if the + # set of non-terminals considered is marked of the right + # non-terminal in the production rule, then if a new set is + # marked or not + if new_temp not in marked_left: + marked_left.add(new_temp) + res = True + continue + if exists_before[index] or exists_after[index]: + to_append = (index + 1, new_temp) + to_process.append(to_append) + if not exists_before[index]: + # For all sets which were marked for the current consumption rule + for marked_set in marked_sets[index]: + if marked_set <= new_temp: + to_append = (index + 1, new_temp) + elif new_temp <= marked_set: + to_append = (index + 1, marked_set) + else: + to_append = (index + 1, new_temp.union(marked_set)) + if to_append not in done: + done.add(to_append) + to_process.append(to_append) + return res diff --git a/pyformlang/rsa/recursive_automaton.py b/pyformlang/rsa/recursive_automaton.py index 703bbfa..4c9edbe 100644 --- a/pyformlang/rsa/recursive_automaton.py +++ b/pyformlang/rsa/recursive_automaton.py @@ -100,7 +100,7 @@ def from_regex(cls, regex: Regex, start_nonterminal: Hashable) \ return RecursiveAutomaton(box, {box}) @classmethod - def from_ebnf(cls, text: str, start_nonterminal: Hashable = Symbol("S")) \ + def from_ebnf(cls, text: str, start_nonterminal: Hashable = "S") \ -> "RecursiveAutomaton": """ Create a recursive automaton from ebnf \ (ebnf = Extended Backus-Naur Form) From 2477c274e4a6f61a6e013731b8aa28bd1238d702 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Thu, 5 Dec 2024 13:49:27 +0300 Subject: [PATCH 04/12] refactor indexed grammar intersection --- pyformlang/fst/fst.py | 134 +---------------- pyformlang/fst/tests/test_fst.py | 31 ---- pyformlang/indexed_grammar/indexed_grammar.py | 138 ++++++++++++++++-- .../tests/test_indexed_grammar.py | 34 ++++- 4 files changed, 158 insertions(+), 179 deletions(-) diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index ebda44b..e4f04bd 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -7,10 +7,6 @@ from networkx import MultiDiGraph from networkx.drawing.nx_pydot import write_dot -from pyformlang.indexed_grammar import IndexedGrammar, Rules, \ - DuplicationRule, ProductionRule, EndRule, ConsumptionRule -from pyformlang.indexed_grammar.reduced_rule import ReducedRule - from .utils import StateRenaming from ..objects.finite_automaton_objects import State, Symbol, Epsilon from ..objects.finite_automaton_objects.utils import to_state, to_symbol @@ -216,7 +212,7 @@ def translate(self, The translation of the input word """ # (remaining in the input, generated so far, current_state) - input_word = [to_symbol(symbol) for symbol in input_word] + input_word = [to_symbol(x) for x in input_word if x != Epsilon()] to_process: List[Tuple[List[Symbol], List[Symbol], State]] = [] seen_by_state = {state: [] for state in self.states} for start_state in self._start_states: @@ -244,134 +240,6 @@ def translate(self, generated + list(output_symbols), next_state)) - def intersection(self, indexed_grammar: IndexedGrammar) -> IndexedGrammar: - """ Compute the intersection with an other object - - Equivalent to: - >> fst and indexed_grammar - """ - rules = indexed_grammar.rules - new_rules: List[ReducedRule] = [EndRule("T", str(Epsilon()))] - self._extract_consumption_rules_intersection(rules, new_rules) - self._extract_indexed_grammar_rules_intersection(rules, new_rules) - self._extract_terminals_intersection(rules, new_rules) - self._extract_epsilon_transitions_intersection(new_rules) - self._extract_fst_delta_intersection(new_rules) - self._extract_fst_epsilon_intersection(new_rules) - self._extract_fst_duplication_rules_intersection(new_rules) - rules = Rules(new_rules, rules.optim) - return IndexedGrammar(rules).remove_useless_rules() - - def _extract_fst_duplication_rules_intersection( - self, - new_rules: List[ReducedRule]) \ - -> None: - for state_p in self._final_states: - for start_state in self._start_states: - new_rules.append(DuplicationRule( - "S", - str((start_state, "S", state_p)), - "T")) - - def _extract_fst_epsilon_intersection( - self, - new_rules: List[ReducedRule]) \ - -> None: - for state_p in self._states: - new_rules.append(EndRule( - str((state_p, Epsilon(), state_p)), str(Epsilon()))) - - def _extract_fst_delta_intersection( - self, - new_rules:List[ReducedRule]) \ - -> None: - for key, pair in self._delta.items(): - state_p = key[0] - terminal = key[1] - for transition in pair: - state_q = transition[0] - symbol = transition[1] - new_rules.append(EndRule(str((state_p, terminal, state_q)), - symbol)) - - def _extract_epsilon_transitions_intersection( - self, - new_rules: List[ReducedRule]) \ - -> None: - for state_p in self._states: - for state_q in self._states: - for state_r in self._states: - new_rules.append(DuplicationRule( - str((state_p, Epsilon(), state_q)), - str((state_p, Epsilon(), state_r)), - str((state_r, Epsilon(), state_q)))) - - def _extract_indexed_grammar_rules_intersection( - self, - rules: Rules, - new_rules: List[ReducedRule]) \ - -> None: - for rule in rules.rules: - if isinstance(rule, DuplicationRule): - for state_p in self._states: - for state_q in self._states: - for state_r in self._states: - new_rules.append(DuplicationRule( - str((state_p, rule.left_term, state_q)), - str((state_p, rule.right_terms[0], state_r)), - str((state_r, rule.right_terms[1], state_q)))) - elif isinstance(rule, ProductionRule): - for state_p in self._states: - for state_q in self._states: - new_rules.append(ProductionRule( - str((state_p, rule.left_term, state_q)), - str((state_p, rule.right_term, state_q)), - str(rule.production))) - elif isinstance(rule, EndRule): - for state_p in self._states: - for state_q in self._states: - new_rules.append(DuplicationRule( - str((state_p, rule.left_term, state_q)), - str((state_p, rule.right_term, state_q)), - "T")) - - def _extract_terminals_intersection( - self, - rules: Rules, - new_rules: List[ReducedRule]) \ - -> None: - terminals = rules.terminals - for terminal in terminals: - for state_p in self._states: - for state_q in self._states: - for state_r in self._states: - new_rules.append(DuplicationRule( - str((state_p, terminal, state_q)), - str((state_p, Epsilon(), state_r)), - str((state_r, terminal, state_q)))) - new_rules.append(DuplicationRule( - str((state_p, terminal, state_q)), - str((state_p, terminal, state_r)), - str((state_r, Epsilon(), state_q)))) - - def _extract_consumption_rules_intersection( - self, - rules: Rules, - new_rules: List[ReducedRule]) \ - -> None: - consumptions = rules.consumption_rules - for consumption_rule in consumptions: - for consumption in consumptions[consumption_rule]: - for state_r in self._states: - for state_s in self._states: - new_rules.append(ConsumptionRule( - consumption.f_parameter, - str((state_r, consumption.left_term, state_s)), - str((state_r, consumption.right_term, state_s)))) - - def __and__(self, other: IndexedGrammar) -> IndexedGrammar: - return self.intersection(other) - def union(self, other_fst: "FST") -> "FST": """ Makes the union of two fst diff --git a/pyformlang/fst/tests/test_fst.py b/pyformlang/fst/tests/test_fst.py index ec4e4da..a80d2fc 100644 --- a/pyformlang/fst/tests/test_fst.py +++ b/pyformlang/fst/tests/test_fst.py @@ -5,9 +5,6 @@ import pytest from pyformlang.fst import FST -from pyformlang.indexed_grammar import ( - DuplicationRule, ProductionRule, EndRule, - ConsumptionRule, IndexedGrammar, Rules) @pytest.fixture @@ -94,34 +91,6 @@ def test_translate(self): assert ["b", "c"] in translation assert ["b"] + ["c"] * 9 in translation - def test_intersection_indexed_grammar(self): - """ Test the intersection with indexed grammar """ - l_rules = [] - rules = Rules(l_rules) - indexed_grammar = IndexedGrammar(rules) - fst = FST() - intersection = fst & indexed_grammar - assert intersection.is_empty() - - l_rules.append(ProductionRule("S", "D", "f")) - l_rules.append(DuplicationRule("D", "A", "B")) - l_rules.append(ConsumptionRule("f", "A", "Afinal")) - l_rules.append(ConsumptionRule("f", "B", "Bfinal")) - l_rules.append(EndRule("Afinal", "a")) - l_rules.append(EndRule("Bfinal", "b")) - - rules = Rules(l_rules) - indexed_grammar = IndexedGrammar(rules) - intersection = fst.intersection(indexed_grammar) - assert intersection.is_empty() - - fst.add_start_state("q0") - fst.add_final_state("final") - fst.add_transition("q0", "a", "q1", ["a"]) - fst.add_transition("q1", "b", "final", ["b"]) - intersection = fst.intersection(indexed_grammar) - assert not intersection.is_empty() - def test_union(self, fst0, fst1): """ Tests the union""" fst_union = fst0.union(fst1) diff --git a/pyformlang/indexed_grammar/indexed_grammar.py b/pyformlang/indexed_grammar/indexed_grammar.py index ab5004e..1bf553d 100644 --- a/pyformlang/indexed_grammar/indexed_grammar.py +++ b/pyformlang/indexed_grammar/indexed_grammar.py @@ -4,15 +4,16 @@ # pylint: disable=cell-var-from-loop -from typing import Dict, List, Set, FrozenSet, Tuple, Hashable, Any +from typing import Dict, List, Set, FrozenSet, Tuple, Hashable from pyformlang.cfg import CFGObject, Variable, Terminal -from pyformlang.finite_automaton import FiniteAutomaton -from pyformlang.regular_expression import Regex +from pyformlang.fst import FST from .rules import Rules +from .reduced_rule import ReducedRule from .duplication_rule import DuplicationRule from .production_rule import ProductionRule +from .consumption_rule import ConsumptionRule from .end_rule import EndRule from .utils import exists, addrec_bis from ..objects.cfg_objects.utils import to_variable @@ -47,7 +48,7 @@ def __init__(self, # Mark all end symbols for non_terminal_a in non_terminals: if exists(self._rules.rules, - lambda x: x.is_end_rule() + lambda x: isinstance(x, EndRule) and x.left_term == non_terminal_a): self._marked[non_terminal_a].add(frozenset()) @@ -363,7 +364,7 @@ def remove_useless_rules(self) -> "IndexedGrammar": rules = Rules(l_rules, self._rules.optim) return IndexedGrammar(rules) - def intersection(self, other: Any) -> "IndexedGrammar": + def intersection(self, other: FST) -> "IndexedGrammar": """ Computes the intersection of the current indexed grammar with the \ other object @@ -387,14 +388,18 @@ def intersection(self, other: Any) -> "IndexedGrammar": When trying to intersection with something else than a regular expression or a finite automaton """ - if isinstance(other, Regex): - other = other.to_epsilon_nfa() - if isinstance(other, FiniteAutomaton): - fst = other.to_fst() - return fst.intersection(self) - raise NotImplementedError - - def __and__(self, other: Any) -> "IndexedGrammar": + new_rules: List[ReducedRule] = [EndRule("T", "epsilon")] + self._extract_consumption_rules_intersection(other, new_rules) + self._extract_indexed_grammar_rules_intersection(other, new_rules) + self._extract_terminals_intersection(other, new_rules) + self._extract_epsilon_transitions_intersection(other, new_rules) + self._extract_fst_delta_intersection(other, new_rules) + self._extract_fst_epsilon_intersection(other, new_rules) + self._extract_fst_duplication_rules_intersection(other, new_rules) + rules = Rules(new_rules, self.rules.optim) + return IndexedGrammar(rules).remove_useless_rules() + + def __and__(self, other: FST) -> "IndexedGrammar": """ Computes the intersection of the current indexed grammar with the other object @@ -409,3 +414,110 @@ def __and__(self, other: Any) -> "IndexedGrammar": The indexed grammar which useless rules """ return self.intersection(other) + + def _extract_fst_duplication_rules_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for final_state in other.final_states: + for start_state in other.start_states: + new_rules.append(DuplicationRule( + "S", + (start_state, "S", final_state), + "T")) + + def _extract_fst_epsilon_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for state in other.states: + new_rules.append(EndRule( + (state, "epsilon", state), + "epsilon")) + + def _extract_fst_delta_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for (s_from, symb_from), (s_to, symb_to) in other: + new_rules.append(EndRule( + (s_from, symb_from, s_to), + symb_to)) + + def _extract_epsilon_transitions_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for state_p in other.states: + for state_q in other.states: + for state_r in other.states: + new_rules.append(DuplicationRule( + (state_p, "epsilon", state_q), + (state_p, "epsilon", state_r), + (state_r, "epsilon", state_q))) + + def _extract_indexed_grammar_rules_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for rule in self.rules.rules: + if isinstance(rule, DuplicationRule): + for state_p in other.states: + for state_q in other.states: + for state_r in other.states: + new_rules.append(DuplicationRule( + (state_p, rule.left_term, state_q), + (state_p, rule.right_terms[0], state_r), + (state_r, rule.right_terms[1], state_q))) + elif isinstance(rule, ProductionRule): + for state_p in other.states: + for state_q in other.states: + new_rules.append(ProductionRule( + (state_p, rule.left_term, state_q), + (state_p, rule.right_term, state_q), + rule.production)) + elif isinstance(rule, EndRule): + for state_p in other.states: + for state_q in other.states: + new_rules.append(DuplicationRule( + (state_p, rule.left_term, state_q), + (state_p, rule.right_term, state_q), + "T")) + + def _extract_terminals_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + for terminal in self.rules.terminals: + for state_p in other.states: + for state_q in other.states: + for state_r in other.states: + new_rules.append(DuplicationRule( + (state_p, terminal, state_q), + (state_p, "epsilon", state_r), + (state_r, terminal, state_q))) + new_rules.append(DuplicationRule( + (state_p, terminal, state_q), + (state_p, terminal, state_r), + (state_r, "epsilon", state_q))) + + def _extract_consumption_rules_intersection( + self, + other: FST, + new_rules: List[ReducedRule]) \ + -> None: + consumptions = self.rules.consumption_rules + for terminal in consumptions: + for consumption in consumptions[terminal]: + for state_r in other.states: + for state_s in other.states: + new_rules.append(ConsumptionRule( + consumption.f_parameter, + (state_r, consumption.left_term, state_s), + (state_r, consumption.right_term, state_s))) diff --git a/pyformlang/indexed_grammar/tests/test_indexed_grammar.py b/pyformlang/indexed_grammar/tests/test_indexed_grammar.py index 9b184be..6f4bb01 100644 --- a/pyformlang/indexed_grammar/tests/test_indexed_grammar.py +++ b/pyformlang/indexed_grammar/tests/test_indexed_grammar.py @@ -8,6 +8,7 @@ from pyformlang.indexed_grammar import DuplicationRule from pyformlang.indexed_grammar import IndexedGrammar from pyformlang.regular_expression import Regex +from pyformlang.fst import FST class TestIndexedGrammar: @@ -338,7 +339,7 @@ def test_removal_useless(self): assert i_grammar2.non_terminals == \ i_grammar2.get_reachable_non_terminals() - def test_intersection(self): + def test_intersection0(self): """ Tests the intersection of indexed grammar with regex Long to run! """ @@ -349,9 +350,38 @@ def test_intersection(self): EndRule("Bfinal", "b")] rules = Rules(l_rules, 6) indexed_grammar = IndexedGrammar(rules) - i_inter = indexed_grammar.intersection(Regex("a.b")) + fst = Regex("a.b").to_epsilon_nfa().to_fst() + i_inter = indexed_grammar.intersection(fst) assert i_inter + def test_intersection1(self): + """ Test the intersection with fst """ + l_rules = [] + rules = Rules(l_rules) + indexed_grammar = IndexedGrammar(rules) + fst = FST() + intersection = indexed_grammar & fst + assert intersection.is_empty() + + l_rules.append(ProductionRule("S", "D", "f")) + l_rules.append(DuplicationRule("D", "A", "B")) + l_rules.append(ConsumptionRule("f", "A", "Afinal")) + l_rules.append(ConsumptionRule("f", "B", "Bfinal")) + l_rules.append(EndRule("Afinal", "a")) + l_rules.append(EndRule("Bfinal", "b")) + + rules = Rules(l_rules) + indexed_grammar = IndexedGrammar(rules) + intersection = indexed_grammar.intersection(fst) + assert intersection.is_empty() + + fst.add_start_state("q0") + fst.add_final_state("final") + fst.add_transition("q0", "a", "q1", ["a"]) + fst.add_transition("q1", "b", "final", ["b"]) + intersection = indexed_grammar.intersection(fst) + assert not intersection.is_empty() + def get_example_rules(): """ Duplicate example of rules """ From e04b70b5794afee5be8b066ff7b9fd6a78a45bc2 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Thu, 5 Dec 2024 21:16:08 +0300 Subject: [PATCH 05/12] correct fst networkx transitions --- pyformlang/cfg/tests/test_terminal.py | 2 ++ pyformlang/cfg/tests/test_variable.py | 2 ++ .../finite_automaton/finite_automaton.py | 4 +-- .../finite_automaton/tests/test_state.py | 1 + .../finite_automaton/tests/test_symbol.py | 1 + pyformlang/fst/fst.py | 27 +++++++++---------- pyformlang/fst/tests/test_fst.py | 14 +++++----- .../indexed_grammar/duplication_rule.py | 4 +-- .../indexed_grammar/tests/test_rules.py | 2 +- pyformlang/pda/pda.py | 4 +-- pyformlang/pda/tests/test_pda.py | 3 +++ 11 files changed, 36 insertions(+), 28 deletions(-) diff --git a/pyformlang/cfg/tests/test_terminal.py b/pyformlang/cfg/tests/test_terminal.py index 7cd9a0e..53fa5a0 100644 --- a/pyformlang/cfg/tests/test_terminal.py +++ b/pyformlang/cfg/tests/test_terminal.py @@ -24,6 +24,8 @@ def test_creation(self): assert epsilon.to_text() == "epsilon" assert Terminal("C").to_text() == '"TER:C"' assert repr(Epsilon()) == "epsilon" + assert str(terminal0) == "0" + assert repr(terminal0) == "Terminal(0)" def test_eq(self): assert "epsilon" == Epsilon() diff --git a/pyformlang/cfg/tests/test_variable.py b/pyformlang/cfg/tests/test_variable.py index 56c186e..f3ed904 100644 --- a/pyformlang/cfg/tests/test_variable.py +++ b/pyformlang/cfg/tests/test_variable.py @@ -20,3 +20,5 @@ def test_creation(self): assert str(variable0) == str(variable3) assert str(variable0) != str(variable1) assert "A" == Variable("A") + assert str(variable1) == "1" + assert repr(variable1) == "Variable(1)" diff --git a/pyformlang/finite_automaton/finite_automaton.py b/pyformlang/finite_automaton/finite_automaton.py index 3017abb..30ce648 100644 --- a/pyformlang/finite_automaton/finite_automaton.py +++ b/pyformlang/finite_automaton/finite_automaton.py @@ -700,10 +700,10 @@ def __try_add(set_to_add_to: Set[Any], element_to_add: Any) -> bool: @staticmethod def __add_start_state_to_graph(graph: MultiDiGraph, state: State) -> None: """ Adds a starting node to a given graph """ - graph.add_node("starting_" + str(state.value), + graph.add_node("starting_" + str(state), label="", shape=None, height=.0, width=.0) - graph.add_edge("starting_" + str(state.value), + graph.add_edge("starting_" + str(state), state.value) diff --git a/pyformlang/finite_automaton/tests/test_state.py b/pyformlang/finite_automaton/tests/test_state.py index 8046f88..6305e50 100644 --- a/pyformlang/finite_automaton/tests/test_state.py +++ b/pyformlang/finite_automaton/tests/test_state.py @@ -22,6 +22,7 @@ def test_repr(self): assert str(state1) == "ABC" state2 = State(1) assert str(state2) == "1" + assert repr(state1) == "State(ABC)" def test_eq(self): """ Tests the equality of states diff --git a/pyformlang/finite_automaton/tests/test_symbol.py b/pyformlang/finite_automaton/tests/test_symbol.py index fcb114c..51e666b 100644 --- a/pyformlang/finite_automaton/tests/test_symbol.py +++ b/pyformlang/finite_automaton/tests/test_symbol.py @@ -22,6 +22,7 @@ def test_repr(self): assert str(symbol1) == "ABC" symbol2 = Symbol(1) assert str(symbol2) == "1" + assert repr(symbol2) == "Symbol(1)" def test_eq(self): """ Tests equality of symbols diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index e4f04bd..26c00fd 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -2,7 +2,6 @@ from typing import Dict, List, Set, Tuple, Iterator, Iterable, Hashable from copy import deepcopy -from json import dumps, loads from networkx import MultiDiGraph from networkx.drawing.nx_pydot import write_dot @@ -417,11 +416,11 @@ def to_networkx(self) -> MultiDiGraph: """ graph = MultiDiGraph() for state in self._states: - graph.add_node(state, + graph.add_node(state.value, is_start=state in self.start_states, is_final=state in self.final_states, peripheries=2 if state in self.final_states else 1, - label=state) + label=state.value) if state in self.start_states: graph.add_node("starting_" + str(state), label="", @@ -429,14 +428,14 @@ def to_networkx(self) -> MultiDiGraph: height=.0, width=.0) graph.add_edge("starting_" + str(state), - state) - for s_from, input_symbol in self._delta: - for s_to, output_symbols in self._delta[(s_from, input_symbol)]: - graph.add_edge( - s_from, - s_to, - label=(dumps(input_symbol) + " -> " + - dumps(output_symbols))) + state.value) + for (s_from, input_symbol), (s_to, output_symbols) in self: + input_symbol = input_symbol.value + output_symbols = tuple(map(lambda x: x.value, output_symbols)) + graph.add_edge( + s_from.value, + s_to.value, + label=(input_symbol, output_symbols)) return graph @classmethod @@ -465,10 +464,8 @@ def from_networkx(cls, graph: MultiDiGraph) -> "FST": for s_to in graph[s_from]: for transition in graph[s_from][s_to].values(): if "label" in transition: - in_symbol, out_symbols = transition["label"].split( - " -> ") - in_symbol = loads(in_symbol) - out_symbols = loads(out_symbols) + label = transition["label"] + in_symbol, out_symbols = label fst.add_transition(s_from, in_symbol, s_to, diff --git a/pyformlang/fst/tests/test_fst.py b/pyformlang/fst/tests/test_fst.py index a80d2fc..05f70f7 100644 --- a/pyformlang/fst/tests/test_fst.py +++ b/pyformlang/fst/tests/test_fst.py @@ -179,12 +179,14 @@ def test_paper(self): (2, "alone", 3, ["seul"])]) fst.add_start_state(0) fst.add_final_state(3) - assert list(fst.translate(["I", "am", "alone"])) == \ - [['Je', 'suis', 'seul'], - ['Je', 'suis', 'tout', 'seul']] + translation = list(fst.translate(["I", "am", "alone"])) + assert ['Je', 'suis', 'seul'] in translation + assert ['Je', 'suis', 'tout', 'seul'] in translation + assert len(translation) == 2 fst = FST.from_networkx(fst.to_networkx()) - assert list(fst.translate(["I", "am", "alone"])) == \ - [['Je', 'suis', 'seul'], - ['Je', 'suis', 'tout', 'seul']] + translation = list(fst.translate(["I", "am", "alone"])) + assert ['Je', 'suis', 'seul'] in translation + assert ['Je', 'suis', 'tout', 'seul'] in translation + assert len(translation) == 2 fst.write_as_dot("fst.dot") assert path.exists("fst.dot") diff --git a/pyformlang/indexed_grammar/duplication_rule.py b/pyformlang/indexed_grammar/duplication_rule.py index 81ab340..de9238f 100644 --- a/pyformlang/indexed_grammar/duplication_rule.py +++ b/pyformlang/indexed_grammar/duplication_rule.py @@ -96,5 +96,5 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: """Gives a string representation of the rule, ignoring the sigmas""" - return f"{self._left_term} -> \ - {self._right_terms[0]} {self._right_terms[1]}" + return f"{self._left_term} -> " \ + + f"{self._right_terms[0]} {self._right_terms[1]}" diff --git a/pyformlang/indexed_grammar/tests/test_rules.py b/pyformlang/indexed_grammar/tests/test_rules.py index 8a4daaf..d8a355d 100644 --- a/pyformlang/indexed_grammar/tests/test_rules.py +++ b/pyformlang/indexed_grammar/tests/test_rules.py @@ -41,7 +41,7 @@ def test_production_rules(self): """ Tests the production rules """ produ = ProductionRule("S", "C", "end") assert produ.terminals == {"end"} - assert str(produ) == "S -> C[ end ]" + assert str(produ) == "S -> C [ end ]" def test_rules(self): """ Tests the rules """ diff --git a/pyformlang/pda/pda.py b/pyformlang/pda/pda.py index f170d02..8d74bae 100644 --- a/pyformlang/pda/pda.py +++ b/pyformlang/pda/pda.py @@ -789,12 +789,12 @@ def __copy__(self) -> "PDA": def __add_start_state_to_graph(graph: MultiDiGraph, state: State) -> None: """ Adds a starting node to a given graph """ - graph.add_node("starting_" + str(state.value), + graph.add_node("starting_" + str(state), label="", shape=None, height=.0, width=.0) - graph.add_edge("starting_" + str(state.value), + graph.add_edge("starting_" + str(state), state.value) @staticmethod diff --git a/pyformlang/pda/tests/test_pda.py b/pyformlang/pda/tests/test_pda.py index 03173cd..5399041 100644 --- a/pyformlang/pda/tests/test_pda.py +++ b/pyformlang/pda/tests/test_pda.py @@ -87,11 +87,14 @@ def test_represent(self): symb = Symbol("S") assert repr(symb) == "Symbol(S)" state = State("T") + assert str(state) == "T" assert repr(state) == "State(T)" stack_symb = StackSymbol("U") assert repr(stack_symb) == "StackSymbol(U)" assert repr(Epsilon()) == "epsilon" assert str(Epsilon()) == "epsilon" + assert str(StackSymbol(12)) == "12" + assert repr(StackSymbol(12)) == "StackSymbol(12)" def test_transition(self): """ Tests the creation of transition """ From 52f3f926e8f24e37373e1dc824103e2896ccca96 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Fri, 6 Dec 2024 12:14:02 +0300 Subject: [PATCH 06/12] debug indexed grammar intersection and generating variables obtaining --- pyformlang/fst/fst.py | 13 ++++ pyformlang/fst/tests/test_fst.py | 26 +++++++ pyformlang/fst/utils.py | 2 +- pyformlang/indexed_grammar/indexed_grammar.py | 75 +++++++++++-------- pyformlang/indexed_grammar/rules.py | 4 +- pyformlang/pda/transition_function.py | 3 +- 6 files changed, 85 insertions(+), 38 deletions(-) diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index 26c00fd..fd84dac 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -146,6 +146,19 @@ def add_transitions(self, transitions: Iterable[InputTransition]) -> None: s_to, output_symbols) + def remove_transition(self, + s_from: Hashable, + input_symbol: Hashable, + s_to: Hashable, + output_symbols: Iterable[Hashable]) -> None: + """ Removes the given transition from the FST """ + s_from = to_state(s_from) + input_symbol = to_symbol(input_symbol) + s_to = to_state(s_to) + output_symbols = tuple(to_symbol(x) for x in output_symbols) + head = (s_from, input_symbol) + self._delta.get(head, set()).discard((s_to, output_symbols)) + def add_start_state(self, start_state: Hashable) -> None: """ Add a start state diff --git a/pyformlang/fst/tests/test_fst.py b/pyformlang/fst/tests/test_fst.py index 05f70f7..6878d93 100644 --- a/pyformlang/fst/tests/test_fst.py +++ b/pyformlang/fst/tests/test_fst.py @@ -190,3 +190,29 @@ def test_paper(self): assert len(translation) == 2 fst.write_as_dot("fst.dot") assert path.exists("fst.dot") + + def test_contains(self, fst0: FST): + """ Tests the containment of transition in the FST """ + assert ("q0", "a", "q1", ["b"]) in fst0 + assert ("a", "b", "c", "d") not in fst0 + fst0.add_transition("a", "b", "c", "d") + assert ("a", "b", "c", "d") in fst0 + + def test_iter(self, fst0: FST): + """ Tests the iteration of FST transitions """ + fst0.add_transition("q1", "A", "q2", ["B"]) + fst0.add_transition("q1", "A", "q2", ["C", "D"]) + transitions = list(iter(fst0)) + assert (("q0", "a"), ("q1", tuple("b"))) in transitions + assert (("q1", "A"), ("q2", tuple("B"))) in transitions + assert (("q1", "A"), ("q2", ("C", "D"))) in transitions + assert len(transitions) == 3 + + def test_remove_transition(self, fst0: FST): + """ Tests the removal of transition from the FST """ + assert ("q0", "a", "q1", "b") in fst0 + fst0.remove_transition("q0", "a", "q1", "b") + assert ("q0", "a", "q1", "b") not in fst0 + fst0.remove_transition("q0", "a", "q1", "b") + assert ("q0", "a", "q1", "b") not in fst0 + assert fst0.get_number_transitions() == 0 diff --git a/pyformlang/fst/utils.py b/pyformlang/fst/utils.py index 6315edf..0a6c243 100644 --- a/pyformlang/fst/utils.py +++ b/pyformlang/fst/utils.py @@ -1,4 +1,4 @@ -""" Class for renaming the states in FST """ +""" Utility for FST """ from typing import Dict, Set, Iterable, Tuple diff --git a/pyformlang/indexed_grammar/indexed_grammar.py b/pyformlang/indexed_grammar/indexed_grammar.py index 1bf553d..8a617a9 100644 --- a/pyformlang/indexed_grammar/indexed_grammar.py +++ b/pyformlang/indexed_grammar/indexed_grammar.py @@ -259,7 +259,7 @@ def get_generating_non_terminals(self) -> Set[Variable]: """ # Preprocess generating_from: Dict[Variable, Set[Variable]] = {} - duplication_pointer: Dict[CFGObject, List[Tuple[Variable, int]]] = {} + duplication_pointer: Dict[CFGObject, List[List]] = {} generating = set() to_process = [] self._preprocess_rules_generating(duplication_pointer, generating, @@ -272,12 +272,12 @@ def get_generating_non_terminals(self) -> Set[Variable]: if symbol not in generating: generating.add(symbol) to_process.append(symbol) - for symbol, pointer in duplication_pointer.get(current, []): - pointer -= 1 - if pointer == 0: - if symbol not in generating: - generating.add(symbol) - to_process.append(symbol) + for duplication in duplication_pointer.get(current, []): + duplication[1] -= 1 + if duplication[1] == 0: + if duplication[0] not in generating: + generating.add(duplication[0]) + to_process.append(duplication[0]) return generating def _preprocess_consumption_rules_generating( @@ -295,7 +295,7 @@ def _preprocess_consumption_rules_generating( def _preprocess_rules_generating( self, - duplication_pointer: Dict[CFGObject, List[Tuple[Variable, int]]], + duplication_pointer: Dict[CFGObject, List[List]], generating: Set[Variable], generating_from: Dict[Variable, Set[Variable]], to_process: List[Variable]) \ @@ -305,7 +305,7 @@ def _preprocess_rules_generating( left = rule.left_term right0 = rule.right_terms[0] right1 = rule.right_terms[1] - temp = (left, 2) + temp = [left, 2] duplication_pointer.setdefault(right0, []).append(temp) duplication_pointer.setdefault(right1, []).append(temp) if isinstance(rule, ProductionRule): @@ -424,7 +424,7 @@ def _extract_fst_duplication_rules_intersection( for start_state in other.start_states: new_rules.append(DuplicationRule( "S", - (start_state, "S", final_state), + (start_state.value, "S", final_state.value), "T")) def _extract_fst_epsilon_intersection( @@ -434,7 +434,7 @@ def _extract_fst_epsilon_intersection( -> None: for state in other.states: new_rules.append(EndRule( - (state, "epsilon", state), + (state.value, "epsilon", state.value), "epsilon")) def _extract_fst_delta_intersection( @@ -444,8 +444,8 @@ def _extract_fst_delta_intersection( -> None: for (s_from, symb_from), (s_to, symb_to) in other: new_rules.append(EndRule( - (s_from, symb_from, s_to), - symb_to)) + (s_from.value, symb_from.value, s_to.value), + tuple(map(lambda x: x.value, symb_to)))) def _extract_epsilon_transitions_intersection( self, @@ -456,9 +456,9 @@ def _extract_epsilon_transitions_intersection( for state_q in other.states: for state_r in other.states: new_rules.append(DuplicationRule( - (state_p, "epsilon", state_q), - (state_p, "epsilon", state_r), - (state_r, "epsilon", state_q))) + (state_p.value, "epsilon", state_q.value), + (state_p.value, "epsilon", state_r.value), + (state_r.value, "epsilon", state_q.value))) def _extract_indexed_grammar_rules_intersection( self, @@ -471,22 +471,29 @@ def _extract_indexed_grammar_rules_intersection( for state_q in other.states: for state_r in other.states: new_rules.append(DuplicationRule( - (state_p, rule.left_term, state_q), - (state_p, rule.right_terms[0], state_r), - (state_r, rule.right_terms[1], state_q))) + (state_p.value, rule.left_term.value, + state_q.value), + (state_p.value, rule.right_terms[0].value, + state_r.value), + (state_r.value, rule.right_terms[1].value, + state_q.value))) elif isinstance(rule, ProductionRule): for state_p in other.states: for state_q in other.states: new_rules.append(ProductionRule( - (state_p, rule.left_term, state_q), - (state_p, rule.right_term, state_q), - rule.production)) + (state_p.value, rule.left_term.value, + state_q.value), + (state_p.value, rule.right_term.value, + state_q.value), + rule.production.value)) elif isinstance(rule, EndRule): for state_p in other.states: for state_q in other.states: new_rules.append(DuplicationRule( - (state_p, rule.left_term, state_q), - (state_p, rule.right_term, state_q), + (state_p.value, rule.left_term.value, + state_q.value), + (state_p.value, rule.right_term.value, + state_q.value), "T")) def _extract_terminals_intersection( @@ -499,13 +506,13 @@ def _extract_terminals_intersection( for state_q in other.states: for state_r in other.states: new_rules.append(DuplicationRule( - (state_p, terminal, state_q), - (state_p, "epsilon", state_r), - (state_r, terminal, state_q))) + (state_p.value, terminal.value, state_q.value), + (state_p.value, "epsilon", state_r.value), + (state_r.value, terminal.value, state_q.value))) new_rules.append(DuplicationRule( - (state_p, terminal, state_q), - (state_p, terminal, state_r), - (state_r, "epsilon", state_q))) + (state_p.value, terminal.value, state_q.value), + (state_p.value, terminal.value, state_r.value), + (state_r.value, "epsilon", state_q.value))) def _extract_consumption_rules_intersection( self, @@ -518,6 +525,8 @@ def _extract_consumption_rules_intersection( for state_r in other.states: for state_s in other.states: new_rules.append(ConsumptionRule( - consumption.f_parameter, - (state_r, consumption.left_term, state_s), - (state_r, consumption.right_term, state_s))) + consumption.f_parameter.value, + (state_r.value, consumption.left_term.value, + state_s.value), + (state_r.value, consumption.right_term.value, + state_s.value))) diff --git a/pyformlang/indexed_grammar/rules.py b/pyformlang/indexed_grammar/rules.py index 2397ea2..f94aa85 100644 --- a/pyformlang/indexed_grammar/rules.py +++ b/pyformlang/indexed_grammar/rules.py @@ -123,8 +123,8 @@ def non_terminals(self) -> Set[Variable]: The non terminals used in the rule """ non_terminals = set() - for temp_rule in self._consumption_rules.values(): - for rule in temp_rule: + for rules in self._consumption_rules.values(): + for rule in rules: non_terminals.update(rule.non_terminals) for rule in self._rules: non_terminals.update(rule.non_terminals) diff --git a/pyformlang/pda/transition_function.py b/pyformlang/pda/transition_function.py index fa52bec..d717400 100644 --- a/pyformlang/pda/transition_function.py +++ b/pyformlang/pda/transition_function.py @@ -64,8 +64,7 @@ def remove_transition(self, stack_to: Tuple[StackSymbol, ...]) -> None: """ Remove the given transition from the function """ key = (s_from, input_symbol, stack_from) - if key in self._transitions: - self._transitions[key].discard((s_to, stack_to)) + self._transitions.get(key, set()).discard((s_to, stack_to)) def copy(self) -> "TransitionFunction": """ Copy the current transition function From a5d499dd874cd555154143bfaec0d88aa8890fc1 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Fri, 6 Dec 2024 14:26:37 +0300 Subject: [PATCH 07/12] add proper fst initialization, add copying of the fst --- pyformlang/fst/__init__.py | 3 +- pyformlang/fst/fst.py | 107 ++++++++++++++------------ pyformlang/fst/tests/test_fst.py | 51 +++++++++--- pyformlang/fst/transition_function.py | 78 +++++++++++++++++++ pyformlang/pda/pda.py | 40 +++++----- pyformlang/pda/transition_function.py | 45 ++++++----- 6 files changed, 223 insertions(+), 101 deletions(-) create mode 100644 pyformlang/fst/transition_function.py diff --git a/pyformlang/fst/__init__.py b/pyformlang/fst/__init__.py index c881874..b7fdb5f 100644 --- a/pyformlang/fst/__init__.py +++ b/pyformlang/fst/__init__.py @@ -12,10 +12,11 @@ """ -from .fst import FST, State, Symbol, Epsilon +from .fst import FST, TransitionFunction, State, Symbol, Epsilon __all__ = ["FST", + "TransitionFunction", "State", "Symbol", "Epsilon"] diff --git a/pyformlang/fst/fst.py b/pyformlang/fst/fst.py index fd84dac..9d5c14e 100644 --- a/pyformlang/fst/fst.py +++ b/pyformlang/fst/fst.py @@ -1,36 +1,38 @@ """ Finite State Transducer """ -from typing import Dict, List, Set, Tuple, Iterator, Iterable, Hashable -from copy import deepcopy +from typing import Dict, List, Set, AbstractSet, \ + Tuple, Iterator, Iterable, Hashable from networkx import MultiDiGraph from networkx.drawing.nx_pydot import write_dot +from .transition_function import TransitionFunction +from .transition_function import TransitionKey, TransitionValues, Transition from .utils import StateRenaming from ..objects.finite_automaton_objects import State, Symbol, Epsilon from ..objects.finite_automaton_objects.utils import to_state, to_symbol -TransitionKey = Tuple[State, Symbol] -TransitionValue = Tuple[State, Tuple[Symbol, ...]] -TransitionValues = Set[TransitionValue] -TransitionFunction = Dict[TransitionKey, TransitionValues] - InputTransition = Tuple[Hashable, Hashable, Hashable, Iterable[Hashable]] -Transition = Tuple[TransitionKey, TransitionValue] class FST(Iterable[Transition]): """ Representation of a Finite State Transducer""" - def __init__(self) -> None: - self._states: Set[State] = set() # Set of states - self._input_symbols: Set[Symbol] = set() # Set of input symbols - self._output_symbols: Set[Symbol] = set() # Set of output symbols - # Dict from _states x _input_symbols U {epsilon} into a subset of - # _states X _output_symbols* - self._delta: TransitionFunction = {} - self._start_states: Set[State] = set() - self._final_states: Set[State] = set() # _final_states is final states + def __init__(self, + states: AbstractSet[Hashable] = None, + input_symbols: AbstractSet[Hashable] = None, + output_symbols: AbstractSet[Hashable] = None, + transition_function: TransitionFunction = None, + start_states: AbstractSet[Hashable] = None, + final_states: AbstractSet[Hashable] = None) -> None: + self._states = {to_state(x) for x in states or set()} + self._input_symbols = {to_symbol(x) for x in input_symbols or set()} + self._output_symbols = {to_symbol(x) for x in output_symbols or set()} + self._transition_function = transition_function or TransitionFunction() + self._start_states = {to_state(x) for x in start_states or set()} + self._states.update(self._start_states) + self._final_states = {to_state(x) for x in final_states or set()} + self._states.update(self._final_states) @property def states(self) -> Set[State]: @@ -87,16 +89,6 @@ def final_states(self) -> Set[State]: """ return self._final_states - def get_number_transitions(self) -> int: - """ Get the number of transitions in the FST - - Returns - ---------- - n_transitions : int - The number of transitions - """ - return sum(len(x) for x in self._delta.values()) - def add_transition(self, s_from: Hashable, input_symbol: Hashable, @@ -125,11 +117,10 @@ def add_transition(self, if input_symbol != Epsilon(): self._input_symbols.add(input_symbol) self._output_symbols.update(output_symbols) - head = (s_from, input_symbol) - if head in self._delta: - self._delta[head].add((s_to, output_symbols)) - else: - self._delta[head] = {(s_to, output_symbols)} + self._transition_function.add_transition(s_from, + input_symbol, + s_to, + output_symbols) def add_transitions(self, transitions: Iterable[InputTransition]) -> None: """ @@ -156,8 +147,20 @@ def remove_transition(self, input_symbol = to_symbol(input_symbol) s_to = to_state(s_to) output_symbols = tuple(to_symbol(x) for x in output_symbols) - head = (s_from, input_symbol) - self._delta.get(head, set()).discard((s_to, output_symbols)) + self._transition_function.remove_transition(s_from, + input_symbol, + s_to, + output_symbols) + + def get_number_transitions(self) -> int: + """ Get the number of transitions in the FST + + Returns + ---------- + n_transitions : int + The number of transitions + """ + return self._transition_function.get_number_transitions() def add_start_state(self, start_state: Hashable) -> None: """ Add a start state @@ -188,7 +191,7 @@ def __call__(self, s_from: Hashable, input_symbol: Hashable) \ """ Calls the transition function of the FST """ s_from = to_state(s_from) input_symbol = to_symbol(input_symbol) - return self._delta.get((s_from, input_symbol), set()) + return self._transition_function(s_from, input_symbol) def __contains__(self, transition: InputTransition) -> bool: """ Whether the given transition is present in the FST """ @@ -201,9 +204,7 @@ def __contains__(self, transition: InputTransition) -> bool: def __iter__(self) -> Iterator[Transition]: """ Gets an iterator of transitions of the FST """ - for key, values in self._delta.items(): - for value in values: - yield key, value + yield from self._transition_function def translate(self, input_word: Iterable[Hashable], @@ -300,14 +301,12 @@ def _add_transitions_to(self, union_fst: "FST", state_renaming: StateRenaming, idx: int) -> None: - for head, transition in self._delta.items(): - s_from, input_symbol = head - for s_to, output_symbols in transition: - union_fst.add_transition( - state_renaming.get_renamed_state(s_from, idx), - input_symbol, - state_renaming.get_renamed_state(s_to, idx), - output_symbols) + for (s_from, input_symbol), (s_to, output_symbols) in self: + union_fst.add_transition( + state_renaming.get_renamed_state(s_from, idx), + input_symbol, + state_renaming.get_renamed_state(s_to, idx), + output_symbols) def _add_extremity_states_to(self, union_fst: "FST", @@ -502,6 +501,18 @@ def write_as_dot(self, filename: str) -> None: """ write_dot(self.to_networkx(), filename) - def to_dict(self) -> TransitionFunction: + def copy(self) -> "FST": + """ Copies the FST """ + return FST(states=self.states, + input_symbols=self.input_symbols, + output_symbols=self.output_symbols, + transition_function=self._transition_function.copy(), + start_states=self.start_states, + final_states=self.final_states) + + def __copy__(self) -> "FST": + return self.copy() + + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: """Gives the transitions as a dictionary""" - return deepcopy(self._delta) + return self._transition_function.to_dict() diff --git a/pyformlang/fst/tests/test_fst.py b/pyformlang/fst/tests/test_fst.py index 6878d93..7914389 100644 --- a/pyformlang/fst/tests/test_fst.py +++ b/pyformlang/fst/tests/test_fst.py @@ -4,7 +4,7 @@ import pytest -from pyformlang.fst import FST +from pyformlang.fst import FST, TransitionFunction, State, Symbol @pytest.fixture @@ -194,9 +194,9 @@ def test_paper(self): def test_contains(self, fst0: FST): """ Tests the containment of transition in the FST """ assert ("q0", "a", "q1", ["b"]) in fst0 - assert ("a", "b", "c", "d") not in fst0 - fst0.add_transition("a", "b", "c", "d") - assert ("a", "b", "c", "d") in fst0 + assert ("a", "b", "c", ["d"]) not in fst0 + fst0.add_transition("a", "b", "c", {"d"}) + assert ("a", "b", "c", ["d"]) in fst0 def test_iter(self, fst0: FST): """ Tests the iteration of FST transitions """ @@ -210,9 +210,42 @@ def test_iter(self, fst0: FST): def test_remove_transition(self, fst0: FST): """ Tests the removal of transition from the FST """ - assert ("q0", "a", "q1", "b") in fst0 - fst0.remove_transition("q0", "a", "q1", "b") - assert ("q0", "a", "q1", "b") not in fst0 - fst0.remove_transition("q0", "a", "q1", "b") - assert ("q0", "a", "q1", "b") not in fst0 + assert ("q0", "a", "q1", ["b"]) in fst0 + fst0.remove_transition("q0", "a", "q1", ["b"]) + assert ("q0", "a", "q1", ["b"]) not in fst0 + fst0.remove_transition("q0", "a", "q1", ["b"]) + assert ("q0", "a", "q1", ["b"]) not in fst0 assert fst0.get_number_transitions() == 0 + + def test_initialization(self): + """ Tests the initialization of the FST """ + fst = FST(states={0}, + input_symbols={"a", "b"}, + output_symbols={"c"}, + start_states={1}, + final_states={2}) + assert fst.states == {0, 1, 2} + assert fst.input_symbols == {"a", "b"} + assert fst.output_symbols == {"c"} + assert fst.get_number_transitions() == 0 + assert not list(iter(fst)) + + function = TransitionFunction() + function.add_transition(State(1), Symbol("a"), State(2), (Symbol("b"),)) + function.add_transition(State(1), Symbol("a"), State(2), (Symbol("c"),)) + fst = FST(transition_function=function) + assert fst.get_number_transitions() == 2 + assert (1, "a", 2, ["b"]) in fst + assert (1, "a", 2, ["c"]) in fst + assert fst(1, "a") == {(2, tuple("b")), (2, tuple("c"))} + + def test_copy(self, fst0: FST): + """ Tests the copying of the FST """ + fst_copy = fst0.copy() + assert fst_copy.states == fst0.states + assert fst_copy.input_symbols == fst0.input_symbols + assert fst_copy.output_symbols == fst0.output_symbols + assert fst_copy.start_states == fst0.start_states + assert fst_copy.final_states == fst0.final_states + assert fst_copy.to_dict() == fst0.to_dict() + assert fst_copy is not fst0 diff --git a/pyformlang/fst/transition_function.py b/pyformlang/fst/transition_function.py new file mode 100644 index 0000000..9f75805 --- /dev/null +++ b/pyformlang/fst/transition_function.py @@ -0,0 +1,78 @@ +""" The transition function of Finite State Transducer """ + +from typing import Dict, Set, Tuple, Iterator, Iterable +from copy import deepcopy + +from ..objects.finite_automaton_objects import State, Symbol + +TransitionKey = Tuple[State, Symbol] +TransitionValue = Tuple[State, Tuple[Symbol, ...]] +TransitionValues = Set[TransitionValue] +Transition = Tuple[TransitionKey, TransitionValue] + + +class TransitionFunction(Iterable[Transition]): + """ The transition function of Finite State Transducer """ + + def __init__(self) -> None: + self._transitions: Dict[TransitionKey, TransitionValues] = {} + + def add_transition(self, + s_from: State, + input_symbol: Symbol, + s_to: State, + output_symbols: Tuple[Symbol, ...]) -> None: + """ Adds given transition to the function """ + key = (s_from, input_symbol) + value = (s_to, output_symbols) + self._transitions.setdefault(key, set()).add(value) + + def remove_transition(self, + s_from: State, + input_symbol: Symbol, + s_to: State, + output_symbols: Tuple[Symbol, ...]) -> None: + """ Removes given transition from the function """ + key = (s_from, input_symbol) + value = (s_to, output_symbols) + self._transitions.get(key, set()).discard(value) + + def get_number_transitions(self) -> int: + """ Gets the number of transitions in the function + + Returns + ---------- + n_transitions : int + The number of transitions + """ + return sum(len(x) for x in self._transitions.values()) + + def __call__(self, s_from: State, input_symbol: Symbol) \ + -> TransitionValues: + """ Calls the transition function """ + return self._transitions.get((s_from, input_symbol), set()) + + def __contains__(self, transition: Transition) -> bool: + """ Whether the given transition is present in the function """ + key, value = transition + return value in self(*key) + + def __iter__(self) -> Iterator[Transition]: + """ Gets an iterator of transitions of the function """ + for key, values in self._transitions.items(): + for value in values: + yield key, value + + def copy(self) -> "TransitionFunction": + """ Copies the transition function """ + new_tf = TransitionFunction() + for key, value in self: + new_tf.add_transition(*key, *value) + return new_tf + + def __copy__(self) -> "TransitionFunction": + return self.copy() + + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: + """ Gives the transition function as a dictionary """ + return deepcopy(self._transitions) diff --git a/pyformlang/pda/pda.py b/pyformlang/pda/pda.py index 8d74bae..3f5e867 100644 --- a/pyformlang/pda/pda.py +++ b/pyformlang/pda/pda.py @@ -189,16 +189,6 @@ def add_final_state(self, state: Hashable) -> None: state = to_state(state) self._final_states.add(state) - def get_number_transitions(self) -> int: - """ Gets the number of transitions in the PDA - - Returns - ---------- - n_transitions : int - The number of transitions - """ - return self._transition_function.get_number_transitions() - def add_transition(self, s_from: Hashable, input_symbol: Hashable, @@ -271,6 +261,16 @@ def remove_transition(self, s_to, stack_to) + def get_number_transitions(self) -> int: + """ Gets the number of transitions in the PDA + + Returns + ---------- + n_transitions : int + The number of transitions + """ + return self._transition_function.get_number_transitions() + def __call__(self, s_from: Hashable, input_symbol: Hashable, @@ -665,16 +665,6 @@ def __and__(self, other: DeterministicFiniteAutomaton) -> "PDA": """ return self.intersection(other) - def to_dict(self) -> Dict[TransitionKey, TransitionValues]: - """ - Get the transitions of the PDA as a dictionary - Returns - ------- - transitions : dict - The transitions - """ - return self._transition_function.to_dict() - def to_networkx(self) -> MultiDiGraph: """ Transform the current pda into a networkx graph @@ -785,6 +775,16 @@ def copy(self) -> "PDA": def __copy__(self) -> "PDA": return self.copy() + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: + """ + Get the transitions of the PDA as a dictionary + Returns + ------- + transitions : dict + The transitions + """ + return self._transition_function.to_dict() + @staticmethod def __add_start_state_to_graph(graph: MultiDiGraph, state: State) -> None: diff --git a/pyformlang/pda/transition_function.py b/pyformlang/pda/transition_function.py index d717400..4801376 100644 --- a/pyformlang/pda/transition_function.py +++ b/pyformlang/pda/transition_function.py @@ -1,7 +1,7 @@ """ A transition function in a pushdown automaton """ -from copy import deepcopy from typing import Dict, Set, Iterator, Iterable, Tuple +from copy import deepcopy from ..objects.pda_objects import State, Symbol, StackSymbol @@ -17,16 +17,6 @@ class TransitionFunction(Iterable[Transition]): def __init__(self) -> None: self._transitions: Dict[TransitionKey, TransitionValues] = {} - def get_number_transitions(self) -> int: - """ Gets the number of transitions - - Returns - ---------- - n_transitions : int - The number of transitions - """ - return sum(len(x) for x in self._transitions.values()) - # pylint: disable=too-many-arguments def add_transition(self, s_from: State, @@ -66,22 +56,15 @@ def remove_transition(self, key = (s_from, input_symbol, stack_from) self._transitions.get(key, set()).discard((s_to, stack_to)) - def copy(self) -> "TransitionFunction": - """ Copy the current transition function + def get_number_transitions(self) -> int: + """ Gets the number of transitions Returns ---------- - new_tf : :class:`~pyformlang.pda.TransitionFunction` - The copy of the transition function + n_transitions : int + The number of transitions """ - new_tf = TransitionFunction() - for temp_in, transition in self._transitions.items(): - for temp_out in transition: - new_tf.add_transition(*temp_in, *temp_out) - return new_tf - - def __copy__(self) -> "TransitionFunction": - return self.copy() + return sum(len(x) for x in self._transitions.values()) def __call__(self, s_from: State, @@ -98,6 +81,22 @@ def __iter__(self) -> Iterator[Transition]: for value in values: yield key, value + def copy(self) -> "TransitionFunction": + """ Copy the current transition function + + Returns + ---------- + new_tf : :class:`~pyformlang.pda.TransitionFunction` + The copy of the transition function + """ + new_tf = TransitionFunction() + for temp_in, temp_out in self: + new_tf.add_transition(*temp_in, *temp_out) + return new_tf + + def __copy__(self) -> "TransitionFunction": + return self.copy() + def to_dict(self) -> Dict[TransitionKey, TransitionValues]: """Get the dictionary representation of the transitions""" return deepcopy(self._transitions) From 5c4598c905ed4477e24b1c8eb671d8976fcd93ac Mon Sep 17 00:00:00 2001 From: bygu4 Date: Fri, 6 Dec 2024 15:49:56 +0300 Subject: [PATCH 08/12] simplify constructors of enfa, pda and cfg, use hashable in fcfg constructor --- pyformlang/cfg/cfg.py | 14 +++------ pyformlang/finite_automaton/epsilon_nfa.py | 22 ++++--------- pyformlang/pda/pda.py | 36 ++++++++-------------- 3 files changed, 24 insertions(+), 48 deletions(-) diff --git a/pyformlang/cfg/cfg.py b/pyformlang/cfg/cfg.py index 1517440..1b5762e 100644 --- a/pyformlang/cfg/cfg.py +++ b/pyformlang/cfg/cfg.py @@ -47,16 +47,12 @@ def __init__(self, start_symbol: Hashable = None, productions: Iterable[Production] = None) -> None: super().__init__() - if variables is not None: - variables = {to_variable(x) for x in variables} - self._variables = variables or set() - if terminals is not None: - terminals = {to_terminal(x) for x in terminals} - self._terminals = terminals or set() + self._variables = {to_variable(x) for x in variables or set()} + self._terminals = {to_terminal(x) for x in terminals or set()} + self._start_symbol = None if start_symbol is not None: - start_symbol = to_variable(start_symbol) - self._variables.add(start_symbol) - self._start_symbol = start_symbol + self._start_symbol = to_variable(start_symbol) + self._variables.add(self._start_symbol) self._productions = set() for production in productions or set(): self.add_production(production) diff --git a/pyformlang/finite_automaton/epsilon_nfa.py b/pyformlang/finite_automaton/epsilon_nfa.py index e80a125..09af9e7 100644 --- a/pyformlang/finite_automaton/epsilon_nfa.py +++ b/pyformlang/finite_automaton/epsilon_nfa.py @@ -65,24 +65,14 @@ def __init__( start_states: AbstractSet[Hashable] = None, final_states: AbstractSet[Hashable] = None) -> None: super().__init__() - if states is not None: - states = {to_state(x) for x in states} - self._states = states or set() - if input_symbols is not None: - input_symbols = {to_symbol(x) for x in input_symbols} - self._input_symbols = input_symbols or set() + self._states = {to_state(x) for x in states or set()} + self._input_symbols = {to_symbol(x) for x in input_symbols or set()} self._transition_function = transition_function \ or NondeterministicTransitionFunction() - if start_states is not None: - start_states = {to_state(x) for x in start_states} - self._start_states = start_states or set() - if final_states is not None: - final_states = {to_state(x) for x in final_states} - self._final_states = final_states or set() - for state in self._final_states: - self._states.add(state) - for state in self._start_states: - self._states.add(state) + self._start_states = {to_state(x) for x in start_states or set()} + self._states.update(self._start_states) + self._final_states = {to_state(x) for x in final_states or set()} + self._states.update(self._final_states) def _get_next_states_iterable( self, diff --git a/pyformlang/pda/pda.py b/pyformlang/pda/pda.py index 3f5e867..8a492ed 100644 --- a/pyformlang/pda/pda.py +++ b/pyformlang/pda/pda.py @@ -69,33 +69,23 @@ def __init__(self, transition_function: TransitionFunction = None, start_state: Hashable = None, start_stack_symbol: Hashable = None, - final_states: AbstractSet[Hashable] = None): + final_states: AbstractSet[Hashable] = None) -> None: # pylint: disable=too-many-arguments - if states is not None: - states = {to_state(x) for x in states} - if input_symbols is not None: - input_symbols = {to_symbol(x) for x in input_symbols} - if stack_alphabet is not None: - stack_alphabet = {to_stack_symbol(x) for x in stack_alphabet} - if start_state is not None: - start_state = to_state(start_state) - if start_stack_symbol is not None: - start_stack_symbol = to_stack_symbol(start_stack_symbol) - if final_states is not None: - final_states = {to_state(x) for x in final_states} - self._states: Set[State] = states or set() - self._input_symbols: Set[PDASymbol] = input_symbols or set() - self._stack_alphabet: Set[StackSymbol] = stack_alphabet or set() + self._states = {to_state(x) for x in states or set()} + self._input_symbols = {to_symbol(x) for x in input_symbols or set()} + self._stack_alphabet = {to_stack_symbol(x) + for x in stack_alphabet or set()} self._transition_function = transition_function or TransitionFunction() - self._start_state: Optional[State] = start_state + self._start_state = None if start_state is not None: - self._states.add(start_state) - self._start_stack_symbol: Optional[StackSymbol] = start_stack_symbol + self._start_state = to_state(start_state) + self._states.add(self._start_state) + self._start_stack_symbol = None if start_stack_symbol is not None: - self._stack_alphabet.add(start_stack_symbol) - self._final_states: Set[State] = final_states or set() - for state in self._final_states: - self._states.add(state) + self._start_stack_symbol = to_stack_symbol(start_stack_symbol) + self._stack_alphabet.add(self._start_stack_symbol) + self._final_states = {to_state(x) for x in final_states or set()} + self._states.update(self._final_states) @property def states(self) -> Set[State]: From b3a59c6f634d9343c8818f0bb029c321bce727cd Mon Sep 17 00:00:00 2001 From: bygu4 Date: Sun, 6 Oct 2024 13:02:04 +0300 Subject: [PATCH 09/12] split CI into different workflows --- .github/workflows/ci_extra.yml | 37 +++++++++++++ .github/workflows/ci_feature.yml | 52 +++++++++++++++++++ .../{python-package.yml => ci_master.yml} | 21 ++++---- 3 files changed, 101 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/ci_extra.yml create mode 100644 .github/workflows/ci_feature.yml rename .github/workflows/{python-package.yml => ci_master.yml} (74%) diff --git a/.github/workflows/ci_extra.yml b/.github/workflows/ci_extra.yml new file mode 100644 index 0000000..dd81ad0 --- /dev/null +++ b/.github/workflows/ci_extra.yml @@ -0,0 +1,37 @@ +# This workflow is for any branch. It runs additional tests for several python versions. + +name: Build extra + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10"] + + steps: + + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Lint with pylint + run: | + pylint pyformlang || true + - name: Lint with pycodestyle + run: | + pycodestyle pyformlang || true + - name: Test with pytest + run: | + pytest --showlocals -v pyformlang diff --git a/.github/workflows/ci_feature.yml b/.github/workflows/ci_feature.yml new file mode 100644 index 0000000..a8c011f --- /dev/null +++ b/.github/workflows/ci_feature.yml @@ -0,0 +1,52 @@ +# This workflow is for feature branches. It sets up python, lints with several analyzers, +# runs tests, collects test coverage and makes a coverage comment. + +name: Build feature + +on: + push: + branches-ignore: "master" + pull_request: + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8"] + + steps: + + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Lint with pylint + run: | + pylint pyformlang || true + - name: Lint with pycodestyle + run: | + pycodestyle pyformlang || true + - name: Test with pytest + run: | + pytest --showlocals -v pyformlang + + - name: Build coverage file + run: | + pytest pyformlang --junitxml=pytest.xml --cov=pyformlang | tee ./pytest-coverage.txt + - name: Make coverage comment + uses: MishaKav/pytest-coverage-comment@main + id: coverageComment + with: + pytest-coverage-path: ./pytest-coverage.txt + junitxml-path: ./pytest.xml + default-branch: master diff --git a/.github/workflows/python-package.yml b/.github/workflows/ci_master.yml similarity index 74% rename from .github/workflows/python-package.yml rename to .github/workflows/ci_master.yml index caf70a6..1c5a06d 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/ci_master.yml @@ -1,9 +1,11 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions +# This workflow is for master branch only. It sets up python, lints with several analyzers, +# runs tests, collects test coverage, makes a coverage comment and creates a coverage badge. -name: Python package +name: Build master -on: [push, pull_request] +on: + push: + branches: "master" jobs: build: @@ -12,9 +14,10 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8"] steps: + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 @@ -25,6 +28,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with pylint run: | pylint pyformlang || true @@ -34,20 +38,19 @@ jobs: - name: Test with pytest run: | pytest --showlocals -v pyformlang + - name: Build coverage file - if: ${{ matrix.python-version == '3.8'}} run: | pytest pyformlang --junitxml=pytest.xml --cov=pyformlang | tee ./pytest-coverage.txt - - name: Pytest coverage comment - if: ${{ matrix.python-version == '3.8'}} + - name: Make coverage comment uses: MishaKav/pytest-coverage-comment@main id: coverageComment with: pytest-coverage-path: ./pytest-coverage.txt junitxml-path: ./pytest.xml default-branch: master + - name: Create coverage Badge - if: ${{ github.ref_name == 'master' && matrix.python-version == '3.8'}} uses: schneegans/dynamic-badges-action@v1.0.0 with: auth: ${{ secrets.GIST_SECRET }} From f438eb249e1cdf5beee98b39500f5bc4f1bf3b43 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Thu, 10 Oct 2024 21:41:40 +0300 Subject: [PATCH 10/12] simplify CI dependencies setup --- .github/workflows/ci_extra.yml | 3 +-- .github/workflows/ci_feature.yml | 3 +-- .github/workflows/ci_master.yml | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci_extra.yml b/.github/workflows/ci_extra.yml index dd81ad0..fd63829 100644 --- a/.github/workflows/ci_extra.yml +++ b/.github/workflows/ci_extra.yml @@ -23,8 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r requirements.txt - name: Lint with pylint run: | diff --git a/.github/workflows/ci_feature.yml b/.github/workflows/ci_feature.yml index a8c011f..42f1165 100644 --- a/.github/workflows/ci_feature.yml +++ b/.github/workflows/ci_feature.yml @@ -27,8 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r requirements.txt - name: Lint with pylint run: | diff --git a/.github/workflows/ci_master.yml b/.github/workflows/ci_master.yml index 1c5a06d..cc14bef 100644 --- a/.github/workflows/ci_master.yml +++ b/.github/workflows/ci_master.yml @@ -26,8 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r requirements.txt - name: Lint with pylint run: | From a6104d32f5b41fbe5b9670dffa4dc0c8549e1db5 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Fri, 6 Dec 2024 16:43:24 +0300 Subject: [PATCH 11/12] add pyright to the CI config --- .github/workflows/ci_extra.yml | 3 +++ .github/workflows/ci_feature.yml | 3 +++ .github/workflows/ci_master.yml | 3 +++ requirements.txt | 1 + 4 files changed, 10 insertions(+) diff --git a/.github/workflows/ci_extra.yml b/.github/workflows/ci_extra.yml index fd63829..7309c12 100644 --- a/.github/workflows/ci_extra.yml +++ b/.github/workflows/ci_extra.yml @@ -31,6 +31,9 @@ jobs: - name: Lint with pycodestyle run: | pycodestyle pyformlang || true + - name: Check with pyright + run: | + pyright --stats pyformlang - name: Test with pytest run: | pytest --showlocals -v pyformlang diff --git a/.github/workflows/ci_feature.yml b/.github/workflows/ci_feature.yml index 42f1165..edf0626 100644 --- a/.github/workflows/ci_feature.yml +++ b/.github/workflows/ci_feature.yml @@ -35,6 +35,9 @@ jobs: - name: Lint with pycodestyle run: | pycodestyle pyformlang || true + - name: Check with pyright + run: | + pyright --stats pyformlang - name: Test with pytest run: | pytest --showlocals -v pyformlang diff --git a/.github/workflows/ci_master.yml b/.github/workflows/ci_master.yml index cc14bef..2ab30a8 100644 --- a/.github/workflows/ci_master.yml +++ b/.github/workflows/ci_master.yml @@ -34,6 +34,9 @@ jobs: - name: Lint with pycodestyle run: | pycodestyle pyformlang || true + - name: Check with pyright + run: | + pyright --stats pyformlang - name: Test with pytest run: | pytest --showlocals -v pyformlang diff --git a/requirements.txt b/requirements.txt index 3179fd5..c65ce0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ sphinx_rtd_theme numpy pylint pycodestyle +pyright pydot pygments>=2.7.4 # not directly required, pinned by Snyk to avoid a vulnerability pylint>=2.7.0 # not directly required, pinned by Snyk to avoid a vulnerability From e63bd9d745feb057f332b74a2e0a2121b1b2d663 Mon Sep 17 00:00:00 2001 From: bygu4 Date: Sun, 29 Dec 2024 23:03:21 +0300 Subject: [PATCH 12/12] update indexed grammars tests and utils --- pyformlang/indexed_grammar/indexed_grammar.py | 8 ++++---- .../indexed_grammar/tests/test_rules.py | 19 +++++++++---------- pyformlang/indexed_grammar/utils.py | 18 ++---------------- 3 files changed, 15 insertions(+), 30 deletions(-) diff --git a/pyformlang/indexed_grammar/indexed_grammar.py b/pyformlang/indexed_grammar/indexed_grammar.py index 8a617a9..0037222 100644 --- a/pyformlang/indexed_grammar/indexed_grammar.py +++ b/pyformlang/indexed_grammar/indexed_grammar.py @@ -15,7 +15,7 @@ from .production_rule import ProductionRule from .consumption_rule import ConsumptionRule from .end_rule import EndRule -from .utils import exists, addrec_bis +from .utils import addrec_bis from ..objects.cfg_objects.utils import to_variable @@ -47,9 +47,9 @@ def __init__(self, self._marked[non_terminal_a].add(temp) # Mark all end symbols for non_terminal_a in non_terminals: - if exists(self._rules.rules, - lambda x: isinstance(x, EndRule) - and x.left_term == non_terminal_a): + if any(map(lambda x: isinstance(x, EndRule) + and x.left_term == non_terminal_a, + self._rules.rules)): self._marked[non_terminal_a].add(frozenset()) @property diff --git a/pyformlang/indexed_grammar/tests/test_rules.py b/pyformlang/indexed_grammar/tests/test_rules.py index d8a355d..e0b3ce7 100644 --- a/pyformlang/indexed_grammar/tests/test_rules.py +++ b/pyformlang/indexed_grammar/tests/test_rules.py @@ -17,18 +17,17 @@ class TestIndexedGrammar: def test_consumption_rules(self): """ Tests the consumption rules """ - conso = ConsumptionRule("end", "C", "T") - terminals = conso.terminals + consumption = ConsumptionRule("end", "C", "T") + terminals = consumption.terminals assert terminals == {"end"} - representation = str(conso) + representation = str(consumption) assert representation == "C [ end ] -> T" def test_duplication_rules(self): """ Tests the duplication rules """ - dupli = DuplicationRule("B0", "A0", "C") - assert dupli.terminals == set() - assert str(dupli) == \ - "B0 -> A0 C" + duplication = DuplicationRule("B0", "A0", "C") + assert duplication.terminals == set() + assert str(duplication) == "B0 -> A0 C" def test_end_rule(self): """ Tests the end rules """ @@ -39,9 +38,9 @@ def test_end_rule(self): def test_production_rules(self): """ Tests the production rules """ - produ = ProductionRule("S", "C", "end") - assert produ.terminals == {"end"} - assert str(produ) == "S -> C [ end ]" + production = ProductionRule("S", "C", "end") + assert production.terminals == {"end"} + assert str(production) == "S -> C [ end ]" def test_rules(self): """ Tests the rules """ diff --git a/pyformlang/indexed_grammar/utils.py b/pyformlang/indexed_grammar/utils.py index ce8c103..2fce4df 100644 --- a/pyformlang/indexed_grammar/utils.py +++ b/pyformlang/indexed_grammar/utils.py @@ -2,21 +2,7 @@ # pylint: disable=cell-var-from-loop -from typing import Callable, List, Set, Iterable, Any - - -def exists(list_elements: List[Any], - check_function: Callable[[Any], bool]) -> bool: - """exists - Check whether at least an element x of l is True for f(x) - :param list_elements: A list of elements to test - :param check_function: The checking function (takes one parameter and \ - return a boolean) - """ - for element in list_elements: - if check_function(element): - return True - return False +from typing import List, Set, Iterable, Any def addrec_bis(l_sets: Iterable[Any], @@ -58,7 +44,7 @@ def addrec_ter(l_sets: List[Any], marked_left: Set[Any]) -> bool: # End condition, nothing left to process temp_in = [x[0] for x in l_sets] exists_after = [ - exists(l_sets[index + 1:], lambda x: x[0] == l_sets[index][0]) + any(map(lambda x: x[0] == l_sets[index][0], l_sets[index + 1:])) for index in range(len(l_sets))] exists_before = [l_sets[index][0] in temp_in[:index] for index in range(len(l_sets))]