|
2 | 2 |
|
3 | 3 | from typing import Dict, List, Set, Tuple, Iterator, Iterable, Hashable |
4 | 4 | from copy import deepcopy |
5 | | -from json import dumps, loads |
6 | 5 |
|
7 | 6 | from networkx import MultiDiGraph |
8 | 7 | from networkx.drawing.nx_pydot import write_dot |
@@ -417,26 +416,26 @@ def to_networkx(self) -> MultiDiGraph: |
417 | 416 | """ |
418 | 417 | graph = MultiDiGraph() |
419 | 418 | for state in self._states: |
420 | | - graph.add_node(state, |
| 419 | + graph.add_node(state.value, |
421 | 420 | is_start=state in self.start_states, |
422 | 421 | is_final=state in self.final_states, |
423 | 422 | peripheries=2 if state in self.final_states else 1, |
424 | | - label=state) |
| 423 | + label=state.value) |
425 | 424 | if state in self.start_states: |
426 | 425 | graph.add_node("starting_" + str(state), |
427 | 426 | label="", |
428 | 427 | shape=None, |
429 | 428 | height=.0, |
430 | 429 | width=.0) |
431 | 430 | graph.add_edge("starting_" + str(state), |
432 | | - state) |
433 | | - for s_from, input_symbol in self._delta: |
434 | | - for s_to, output_symbols in self._delta[(s_from, input_symbol)]: |
435 | | - graph.add_edge( |
436 | | - s_from, |
437 | | - s_to, |
438 | | - label=(dumps(input_symbol) + " -> " + |
439 | | - dumps(output_symbols))) |
| 431 | + state.value) |
| 432 | + for (s_from, input_symbol), (s_to, output_symbols) in self: |
| 433 | + input_symbol = input_symbol.value |
| 434 | + output_symbols = tuple(map(lambda x: x.value, output_symbols)) |
| 435 | + graph.add_edge( |
| 436 | + s_from.value, |
| 437 | + s_to.value, |
| 438 | + label=(input_symbol, output_symbols)) |
440 | 439 | return graph |
441 | 440 |
|
442 | 441 | @classmethod |
@@ -465,10 +464,8 @@ def from_networkx(cls, graph: MultiDiGraph) -> "FST": |
465 | 464 | for s_to in graph[s_from]: |
466 | 465 | for transition in graph[s_from][s_to].values(): |
467 | 466 | if "label" in transition: |
468 | | - in_symbol, out_symbols = transition["label"].split( |
469 | | - " -> ") |
470 | | - in_symbol = loads(in_symbol) |
471 | | - out_symbols = loads(out_symbols) |
| 467 | + label = transition["label"] |
| 468 | + in_symbol, out_symbols = label |
472 | 469 | fst.add_transition(s_from, |
473 | 470 | in_symbol, |
474 | 471 | s_to, |
|
0 commit comments