@@ -438,6 +438,169 @@ def revert(self, fgraph, checkpoint):
438438 self .history [fgraph ] = h
439439
440440
441+ class FullHistory (Feature ):
442+ """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states
443+
444+ .. testcode::
445+ import pytensor
446+ import pytensor.tensor as pt
447+ from pytensor.graph.fg import FunctionGraph
448+ from pytensor.graph.features import FullHistory
449+ from pytensor.graph.rewriting.utils import rewrite_graph
450+
451+ x = pt.scalar("x")
452+ out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
453+
454+ fg = FunctionGraph(outputs=[out])
455+ history = FullHistory()
456+ fg.attach_feature(history)
457+
458+ rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
459+
460+ # Replay rewrites
461+ history.start()
462+ pytensor.dprint(fg)
463+ with pytensor.config.change_flags(optimizer_verbose = True):
464+ for i in range(3):
465+ print(">> ", end="")
466+ pytensor.dprint(history.next())
467+
468+ .. testoutput::
469+ Log [id A] 4
470+ └─ True_div [id B] 3
471+ ├─ Exp [id C] 2
472+ │ └─ x [id D]
473+ └─ Sum{axes=None} [id E] 1
474+ └─ Exp [id F] 0
475+ └─ x [id D]
476+ >> MergeOptimizer
477+ Log [id A] 3
478+ └─ True_div [id B] 2
479+ ├─ Exp [id C] 0
480+ │ └─ x [id D]
481+ └─ Sum{axes=None} [id E] 1
482+ └─ Exp [id C] 0
483+ └─ ···
484+ >> local_mul_canonizer
485+ Log [id A] 1
486+ └─ Softmax{axis=None} [id B] 0
487+ └─ x [id C]
488+ >> local_logsoftmax
489+ LogSoftmax{axis=None} [id A] 0
490+ └─ x [id B]
491+
492+
493+ .. testcode::
494+ # Or in reverse
495+ with pytensor.config.change_flags(optimizer_verbose=True):
496+ for i in range(3):
497+ print(">> ", end="")
498+ pytensor.dprint(history.prev())
499+
500+ .. testoutput::
501+ >> local_logsoftmax
502+ Log [id A] 1
503+ └─ Softmax{axis=None} [id B] 0
504+ └─ x [id C]
505+ >> local_mul_canonizer
506+ Log [id A] 3
507+ └─ True_div [id B] 2
508+ ├─ Exp [id C] 0
509+ │ └─ x [id D]
510+ └─ Sum{axes=None} [id E] 1
511+ └─ Exp [id C] 0
512+ └─ ···
513+ >> MergeOptimizer
514+ Log [id A] 4
515+ └─ True_div [id B] 3
516+ ├─ Exp [id C] 2
517+ │ └─ x [id D]
518+ └─ Sum{axes=None} [id E] 1
519+ └─ Exp [id F] 0
520+ └─ x [id D]
521+
522+
523+ .. testcode::
524+ # Or go to any step
525+ pytensor.dprint(history.goto(2))
526+
527+ .. testoutput::
528+ Log [id A] 1
529+ └─ Softmax{axis=None} [id B] 0
530+ └─ x [id C]
531+
532+
533+ """
534+
535+ def __init__ (self ):
536+ self .fw = []
537+ self .bw = []
538+ self .pointer = - 1
539+ self .fg = None
540+
541+ def on_attach (self , fgraph ):
542+ if self .fg is not None :
543+ raise ValueError ("Full History already attached to another fgraph" )
544+ self .fg = fgraph
545+
546+ def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
547+ self .bw .append (LambdaExtract (fgraph , node , i , r , reason ))
548+ self .fw .append (LambdaExtract (fgraph , node , i , new_r , reason ))
549+ self .pointer += 1
550+
551+ def goto (self , checkpoint ):
552+ """
553+ Reverts the graph to whatever it was at the provided
554+ checkpoint (undoes all replacements). A checkpoint at any
555+ given time can be obtained using self.checkpoint().
556+
557+ """
558+ history_len = len (self .bw )
559+ pointer = self .pointer
560+ assert 0 <= checkpoint <= history_len
561+ verbose = config .optimizer_verbose
562+
563+ # Go backwards
564+ while pointer > checkpoint - 1 :
565+ reverse_fn = self .bw [pointer ]
566+ if verbose :
567+ print (reverse_fn .reason ) # noqa: T201
568+ reverse_fn ()
569+ pointer -= 1
570+
571+ # Go forward
572+ while pointer < checkpoint - 1 :
573+ pointer += 1
574+ forward_fn = self .fw [pointer ]
575+ if verbose :
576+ print (forward_fn .reason ) # noqa: T201
577+ forward_fn ()
578+
579+ # Remove history changes caused by the foward/backward!
580+ self .bw = self .bw [:history_len ]
581+ self .fw = self .fw [:history_len ]
582+ self .pointer = pointer
583+ return self .fg
584+
585+ def start (self ):
586+ return self .goto (0 )
587+
588+ def end (self ):
589+ return self .goto (len (self .bw ))
590+
591+ def prev (self ):
592+ if self .pointer < 0 :
593+ return self .fg
594+ else :
595+ return self .goto (self .pointer )
596+
597+ def next (self ):
598+ if self .pointer >= len (self .bw ) - 1 :
599+ return self .fg
600+ else :
601+ return self .goto (self .pointer + 2 )
602+
603+
441604class Validator (Feature ):
442605 pickle_rm_attr = ["validate" , "consistent" ]
443606
0 commit comments