44from IPython .display import display
55
66from pytensor .graph import FunctionGraph , Variable , rewrite_graph
7- from pytensor .graph .features import FullHistory
7+ from pytensor .graph .features import AlreadyThere , FullHistory
88
99
1010class CodeBlockWidget (anywidget .AnyWidget ):
@@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget):
4545
4646class InteractiveRewrite :
4747 """
48- A class that wraps a graph history object with interactive widgets
49- to navigate through history and display the graph at each step.
50-
51- Includes an option to display the reason for the last change.
48+ Visualize a graph history through a series of rewrites.
5249 """
5350
54- def __init__ (self , fg , display_reason = True ):
51+ def __init__ (
52+ self ,
53+ fg ,
54+ display_reason = True ,
55+ rewrite_options : dict | None = None ,
56+ dprint_options : dict | None = None ,
57+ ):
5558 """
56- Initialize with a history object that has a goto method
57- and tracks a FunctionGraph.
58-
5959 Parameters:
6060 -----------
6161 fg : FunctionGraph (or Variables)
6262 The function graph to track
6363 display_reason : bool, optional
6464 Whether to display the reason for each rewrite
65+ rewrite_options : dict, optional
66+ Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)}
67+ print_options : dict, optional
68+ Print options passed to `debugprint` used to generate the text representation of the graph.
69+ Useful options are {'print_shape': True, 'print_op_info': True}
6570 """
71+ self .dprint_options = dprint_options or {}
72+ self .rewrite_options = rewrite_options or dict (
73+ include = ("fast_run" ,), exclude = ("inplace" ,)
74+ )
6675 self .history = FullHistory (callback = self ._history_callback )
6776 if not isinstance (fg , FunctionGraph ):
6877 outs = [fg ] if isinstance (fg , Variable ) else fg
6978 fg = FunctionGraph (outputs = outs )
70- fg .attach_feature (self .history )
79+ try :
80+ fg .attach_feature (self .history )
81+ except AlreadyThere :
82+ self .history .end ()
7183
7284 self .updating_from_callback = False # Flag to prevent recursion
7385 self .code_widget = CodeBlockWidget (content = "" )
@@ -163,7 +175,7 @@ def _update_display(self):
163175 reason = ""
164176 else :
165177 reason = self .history .fw [self .history .pointer ].reason
166- reason = getattr (reason , "name" , str (reason ) )
178+ reason = getattr (reason , "name" , None ) or str (reason )
167179
168180 self .reason_label .value = f"""
169181 <div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
@@ -172,7 +184,9 @@ def _update_display(self):
172184 """
173185
174186 # Update the graph display
175- self .code_widget .content = self .history .fg .dprint (file = "str" )
187+ self .code_widget .content = self .history .fg .dprint (
188+ file = "str" , ** self .dprint_options
189+ )
176190
177191 # Update slider range if history length has changed
178192 history_len = len (self .history .fw ) + 1
@@ -189,14 +203,13 @@ def _update_display(self):
189203 f"History: { self .history .pointer + 1 } /{ history_len - 1 } "
190204 )
191205
192- def rewrite (self , * args , include = ( "fast_run" ,), exclude = ( "inplace" ,), ** kwargs ):
206+ def rewrite (self , * args , ** kwargs ):
193207 """Apply rewrites to the current graph"""
194208 rewrite_graph (
195209 self .history .fg ,
196210 * args ,
197- include = include ,
198- exclude = exclude ,
199211 ** kwargs ,
212+ ** self .rewrite_options ,
200213 clone = False ,
201214 )
202215 self ._update_display ()
0 commit comments