1010from . import dtypes
1111from .visitor import _Visitor
1212
13+
1314def draw_graph (vars , filename = 'out.png' , prog = 'dot' , rankdir = 'LR' , approx = False ):
1415 gg = GraphGenerator (rankdir = rankdir , approx = approx )
1516 gg .draw (vars , filename , prog )
1617
18+
1719class GraphGenerator (_Visitor ):
20+
1821 def __init__ (self , rankdir = 'LR' , approx = False ):
1922 _Visitor .__init__ (self )
20-
23+
2124 try :
2225 import pygraphviz as pgv
2326 except :
2427 raise ImportError ('Graph generator requires Pygraphviz.' )
2528
2629 self .graph = pgv .AGraph (directed = True , rankdir = rankdir )
27-
30+
2831 self .approx = approx
2932 self .tmp_count = 0
3033
@@ -38,22 +41,22 @@ def draw(self, vars, filename='out.png', prog='dot'):
3841
3942 self .graph .add_subgraph (self .input_nodes , rank = 'same' )
4043
41- for rank , nodes in sorted (self .ranks .items (), key = lambda x :x [0 ]):
44+ for rank , nodes in sorted (self .ranks .items (), key = lambda x : x [0 ]):
4245 self .graph .add_subgraph (nodes , rank = 'same' )
43-
46+
4447 self .graph .add_subgraph (self .output_nodes , rank = 'same' )
45-
48+
4649 self .graph .write ('out.dot' )
4750 self .graph .layout (prog = prog )
4851 self .graph .draw (filename )
4952
5053 def _set_rank (self , rank , node ):
5154 self .ranks [rank ].append (node )
52-
55+
5356 def _add_output (self , node , src ):
5457 if node ._has_output ():
5558 outobj = str (node .output_data )
56- label_data = [ outobj , str (node .width ) ]
59+ label_data = [outobj , str (node .width )]
5760 if node .point > 0 :
5861 label_data .append (str (node .point ))
5962 label = ':' .join (label_data )
@@ -68,11 +71,12 @@ def _add_gap(self, node, mark=''):
6871
6972 if node .start_stage is None or node .end_stage is None :
7073 return node
71-
74+
7275 prev = node
7376 for i in range (node .end_stage - node .start_stage - 1 ):
7477 tmp = self ._get_tmp ()
75- self .graph .add_node (tmp , label = mark , shape = 'box' , color = 'lightgray' , style = 'filled' )
78+ self .graph .add_node (tmp , label = mark , shape = 'box' ,
79+ color = 'lightgray' , style = 'filled' )
7680 self .graph .add_edge (prev , tmp )
7781 self ._set_rank (node .start_stage + 2 + i , tmp )
7882 prev = tmp
@@ -90,38 +94,42 @@ def _get_mark(self, obj):
9094 def _get_tmp (self ):
9195 v = self .tmp_count
9296 self .tmp_count += 1
93- return hash ( (id (self ), v ) )
94-
97+ return hash ((id (self ), v ))
98+
9599 def visit__BinaryOperator (self , node ):
96100 mark = self ._get_mark (node .op )
97101 self .graph .add_node (node , label = mark , shape = 'circle' )
98-
102+
99103 left = self .visit (node .left )
100104 right = self .visit (node .right )
101105 self .graph .add_edge (left , node , label = 'L' )
102106 self .graph .add_edge (right , node , label = 'R' )
103107
104108 if node .start_stage is not None :
105109 self ._set_rank (node .start_stage + 1 , node )
106-
110+
107111 prev = self ._add_gap (node , mark )
108112 self ._add_output (node , prev )
109113 return prev
110-
114+
111115 def visit__UnaryOperator (self , node ):
112116 if self .approx and isinstance (node , dtypes ._Delay ):
113117 prev = self .visit (node .parent_value )
114118 self ._add_output (node , prev )
115119 return prev
116-
120+
117121 mark = ('delay' if isinstance (node , dtypes ._Delay ) else
118122 'prev' if isinstance (node , dtypes ._Prev ) else
119- self ._get_mark (node .op ) )
120- shape = 'box' if isinstance (node , (dtypes ._Delay , dtypes ._Prev )) else 'circle'
121- color = 'lightgray' if isinstance (node , (dtypes ._Delay , dtypes ._Prev )) else 'black'
122- style = 'filled' if isinstance (node , (dtypes ._Delay , dtypes ._Prev )) else None
123- self .graph .add_node (node , label = mark , shape = shape , color = color , style = style )
124-
123+ self ._get_mark (node .op ))
124+ shape = 'box' if isinstance (
125+ node , (dtypes ._Delay , dtypes ._Prev )) else 'circle'
126+ color = 'lightgray' if isinstance (
127+ node , (dtypes ._Delay , dtypes ._Prev )) else 'black'
128+ style = 'filled' if isinstance (
129+ node , (dtypes ._Delay , dtypes ._Prev )) else None
130+ self .graph .add_node (node , label = mark , shape = shape ,
131+ color = color , style = style )
132+
125133 right = self .visit (node .right )
126134 self .graph .add_edge (right , node , label = 'R' )
127135
@@ -131,31 +139,31 @@ def visit__UnaryOperator(self, node):
131139 self ._set_rank (node .start_stage , node )
132140 else :
133141 self ._set_rank (node .start_stage + 1 , node )
134-
142+
135143 prev = self ._add_gap (node , mark )
136144 self ._add_output (node , prev )
137145 return prev
138146
139147 def visit__SpecialOperator (self , node ):
140148 mark = self ._get_mark (node .op )
141149 self .graph .add_node (node , label = mark , shape = 'ellipse' )
142-
150+
143151 for i , arg in enumerate (node .args ):
144152 a = self .visit (arg )
145153 self .graph .add_edge (a , node , label = str (i ))
146-
154+
147155 if node .start_stage is not None :
148156 self ._set_rank (node .start_stage + 1 , node )
149-
157+
150158 prev = self ._add_gap (node , mark )
151159 self ._add_output (node , prev )
152160 return prev
153161
154162 def visit__Accumulator (self , node ):
155- mark = (' ' .join ([ self ._get_mark (op ) for op in node .ops ])
163+ mark = (' ' .join ([self ._get_mark (op ) for op in node .ops ])
156164 if node .label is None else node .label )
157165 self .graph .add_node (node , label = mark , shape = 'box' , style = 'rounded' )
158-
166+
159167 right = self .visit (node .right )
160168 initval = self .visit (node .initval )
161169 if node .enable is not None :
@@ -168,24 +176,24 @@ def visit__Accumulator(self, node):
168176 self .graph .add_edge (enable , node , label = 'enable' )
169177 if node .reset is not None :
170178 self .graph .add_edge (reset , node , label = 'reset' )
171-
179+
172180 if node .start_stage is not None :
173181 self ._set_rank (node .start_stage + 1 , node )
174-
182+
175183 prev = self ._add_gap (node , mark )
176184 self ._add_output (node , prev )
177185 return prev
178-
186+
179187 def visit__ParameterVariable (self , node ):
180188 inobj = str (node .input_data )
181- label_data = [ inobj , str (node .width ) ]
189+ label_data = [inobj , str (node .width )]
182190 if node .point > 0 :
183191 label_data .append (str (node .point ))
184192 label = ':' .join (label_data )
185-
193+
186194 self .graph .add_node (node , label = label , shape = '' ,
187195 color = 'lightblue' , style = 'rounded,filled' , peripheries = 2 )
188-
196+
189197 self .input_nodes .append (node )
190198 self ._add_output (node , node )
191199 return node
@@ -196,27 +204,29 @@ def visit__Variable(self, node):
196204 return input_data
197205
198206 inobj = str (node .input_data )
199- label_data = [ inobj , str (node .width ) ]
207+ label_data = [inobj , str (node .width )]
200208 if node .point > 0 :
201209 label_data .append (str (node .point ))
202210 label = ':' .join (label_data )
203-
211+
204212 self .graph .add_node (node , label = label , shape = 'box' ,
205213 color = 'lightblue' , style = 'filled' , peripheries = 2 )
206-
214+
207215 self .input_nodes .append (node )
208216 self ._add_output (node , node )
209217 return node
210218
211219 def visit__Constant (self , node ):
212220 if isinstance (node , dtypes .FixedPoint ):
213- value = "%f:%d" % (((1.0 * node .value ) / (2.0 ** node .point )), node .point )
221+ value = "%f:%d" % (
222+ ((1.0 * node .value ) / (2.0 ** node .point )), node .point )
214223 elif isinstance (node , dtypes .Float ):
215224 value = "%f" % node .value
216225 else :
217226 value = str (node .value )
218-
219- self .graph .add_node (node , label = value , shape = '' , color = 'lightblue' , style = 'filled' )
220-
227+
228+ self .graph .add_node (node , label = value , shape = '' ,
229+ color = 'lightblue' , style = 'filled' )
230+
221231 self ._add_output (node , node )
222232 return node
0 commit comments