1414from symbolic_pymc .tensorflow .meta import TFlowMetaOp
1515
1616
17+ class DepthExceededException (Exception ):
18+ pass
19+
20+
1721class TFlowPrinter (object ):
1822 """A printer that indents and keeps track of already printed subgraphs."""
1923
20- def __init__ (self , formatter , buffer ):
24+ def __init__ (self , formatter , buffer , depth_lower_idx = 0 , depth_upper_idx = sys .maxsize ):
25+ # The buffer to which results are printed
2126 self .buffer = buffer
27+ # A function used to pre-process printed results
2228 self .formatter = formatter
23- self .indentation = ""
29+
30+ self .depth_count = 0
31+ self .depth_lower_idx , self .depth_upper_idx = depth_lower_idx , depth_upper_idx
32+
33+ # This is the current indentation string
34+ if self .depth_lower_idx > 0 :
35+ self .indentation = "... "
36+ else :
37+ self .indentation = ""
38+
39+ # The set of graphs that have already been printed
2440 self .printed_subgraphs = set ()
2541
2642 @contextmanager
2743 def indented (self , indent ):
2844 pre_indentation = self .indentation
29- if isinstance (indent , int ):
30- self .indentation += " " * indent
31- else :
45+
46+ self .depth_count += 1
47+
48+ if self .depth_lower_idx < self .depth_count <= self .depth_upper_idx :
3249 self .indentation += indent
50+
3351 try :
3452 yield
53+ except DepthExceededException :
54+ pass
3555 finally :
3656 self .indentation = pre_indentation
57+ self .depth_count -= 1
3758
3859 def format (self , obj ):
3960 return self .indentation + self .formatter (obj )
4061
41- def print (self , obj ):
42- self .buffer .write (self .format (obj ))
43- self .buffer .flush ()
44-
45- def println (self , obj ):
46- self .buffer .write (self .format (obj ) + "\n " )
47- self .buffer .flush ()
62+ def print (self , obj , suffix = "" ):
63+ if self .depth_lower_idx <= self .depth_count < self .depth_upper_idx :
64+ self .buffer .write (self .format (obj ) + suffix )
65+ self .buffer .flush ()
66+ elif self .depth_count == self .depth_upper_idx :
67+ # Only print the cut-off indicator at the first occurrence
68+ self .buffer .write (self .format (f"...{ suffix } " ))
69+ self .buffer .flush ()
4870
71+ # Prevent the caller from traversing at this level or higher
72+ raise DepthExceededException ()
4973
50- def tf_dprint (obj , printer = None ):
51- """Print a textual representation of a TF graph.
74+ def println (self , obj ):
75+ self .print (obj , suffix = "\n " )
76+
77+ def subgraph_add (self , obj ):
78+ if self .depth_lower_idx <= self .depth_count < self .depth_upper_idx :
79+ # Only track printed subgraphs when they're actually printed
80+ self .printed_subgraphs .add (obj )
81+
82+ def __repr__ (self ): # pragma: no cover
83+ return (
84+ "TFlowPrinter\n "
85+ f"\t depth_lower_idx={ self .depth_lower_idx } ,\t depth_upper_idx={ self .depth_upper_idx } \n "
86+ f"\t indentation='{ self .indentation } ',\t depth_count={ self .depth_count } "
87+ )
88+
89+
90+ def tf_dprint (obj , depth_lower = 0 , depth_upper = 10 , printer = None ):
91+ """Print a textual representation of a TF graph. The output roughly follows the format of `theano.printing.debugprint`.
92+
93+ Parameters
94+ ----------
95+ obj : Tensorflow object
96+ Tensorflow graph object to be represented.
97+ depth_lower : int
98+ Used to index specific portions of the graph.
99+ depth_upper : int
100+ Used to index specific portions of the graph.
101+ printer : optional
102+ Backend used to display the output.
52103
53- The output roughly follows the format of `theano.printing.debugprint`.
54104 """
105+
55106 if isinstance (obj , tf .Tensor ):
56107 try :
57108 obj .op
@@ -63,7 +114,7 @@ def tf_dprint(obj, printer=None):
63114 )
64115
65116 if printer is None :
66- printer = TFlowPrinter (str , sys .stdout )
117+ printer = TFlowPrinter (str , sys .stdout , depth_lower , depth_upper )
67118
68119 _tf_dprint (obj , printer )
69120
@@ -75,28 +126,22 @@ def _tf_dprint(obj, printer):
75126
76127@_tf_dprint .register (tf .Tensor )
77128@_tf_dprint .register (TFlowMetaTensor )
78- def _ (obj , printer ):
129+ def _tf_dprint_TFlowMetaTensor (obj , printer ):
79130
80131 try :
81132 shape_str = str (obj .shape .as_list ())
82133 except (ValueError , AttributeError ):
83134 shape_str = "Unknown"
84135
85136 prefix = f'Tensor({ getattr (obj .op , "type" , obj .op )} ):{ obj .value_index } ,\t dtype={ getattr (obj .dtype , "name" , obj .dtype )} ,\t shape={ shape_str } ,\t "{ obj .name } "'
137+
86138 _tf_dprint (prefix , printer )
87139
88140 if isvar (obj .op ):
89141 return
90142 elif isvar (obj .op .inputs ):
91143 with printer .indented ("| " ):
92144 _tf_dprint (f"{ obj .op .inputs } " , printer )
93- elif len (obj .op .inputs ) > 0 :
94- with printer .indented ("| " ):
95- if obj not in printer .printed_subgraphs :
96- printer .printed_subgraphs .add (obj )
97- _tf_dprint (obj .op , printer )
98- else :
99- _tf_dprint ("..." , printer )
100145 elif obj .op .type == "Const" :
101146 with printer .indented ("| " ):
102147 if isinstance (obj , tf .Tensor ):
@@ -110,10 +155,17 @@ def _(obj, printer):
110155 _tf_dprint (
111156 np .array2string (numpy_val , threshold = 20 , prefix = printer .indentation ), printer
112157 )
158+ elif len (obj .op .inputs ) > 0 :
159+ with printer .indented ("| " ):
160+ if obj in printer .printed_subgraphs :
161+ _tf_dprint ("..." , printer )
162+ else :
163+ printer .subgraph_add (obj )
164+ _tf_dprint (obj .op , printer )
113165
114166
115167@_tf_dprint .register (tf .Operation )
116168@_tf_dprint .register (TFlowMetaOp )
117- def _ (obj , printer ):
169+ def _tf_dprint_TFlowMetaOp (obj , printer ):
118170 for op_input in obj .inputs :
119171 _tf_dprint (op_input , printer )
0 commit comments