@@ -2008,25 +2008,89 @@ def __init__(
20082008 if order not in valid_orders :
20092009 raise ValueError (f"order must be one of { valid_orders } , got { order } " )
20102010 self .order = order
2011+ # Use tracks functionality to pre-filter nodes, if it's a single Op or Op type
2012+ tracks = node_rewriter .tracks ()
2013+ self .tracks = tracks [0 ] if (tracks is not None and len (tracks ) == 1 ) else None
20112014 super ().__init__ (node_rewriter , ignore_newtrees , failure_callback )
20122015
20132016 def apply (self , fgraph , start_from = None ):
20142017 if start_from is None :
20152018 start_from = fgraph .outputs
20162019 callback_before = fgraph .execute_callbacks_time
2017- nb_nodes_start = len (fgraph .apply_nodes )
2020+ apply_nodes = fgraph .apply_nodes
2021+ nb_nodes_start = len (apply_nodes )
20182022 t0 = time .perf_counter ()
2019- q = deque (
2023+ if (tracks := self .tracks ) is not None :
2024+ # Pre-filter nodes to consider based on tracks
2025+ if isinstance (tracks , Op ):
2026+ # Equality
2027+ candidate_nodes = {
2028+ node for node in fgraph .apply_nodes if node .op == tracks
2029+ }
2030+ elif isinstance (tracks , OpPattern ):
2031+ candidate_nodes = {
2032+ node
2033+ for node in fgraph .apply_nodes
2034+ if tracks .match_op (node .op ) is not False
2035+ }
2036+ else :
2037+ # isinstance
2038+ candidate_nodes = {
2039+ node for node in fgraph .apply_nodes if isinstance (node .op , tracks )
2040+ }
2041+
2042+ if not candidate_nodes :
2043+ # Abort early
2044+ return (
2045+ self ,
2046+ 0 , # nodes changed
2047+ nb_nodes_start ,
2048+ nb_nodes_start , # nb_nodes_end
2049+ time .perf_counter () - t0 , # io_t
2050+ 0 , # loop_t
2051+ 0 , # callback_time
2052+ self .node_rewriter ,
2053+ )
2054+
2055+ if isinstance (tracks , Op ):
2056+
2057+ def importer (node ):
2058+ if node is not current_node and node .op == tracks :
2059+ q .append (node )
2060+
2061+ elif isinstance (tracks , OpPattern ):
2062+
2063+ def importer (node ):
2064+ if (
2065+ node is not current_node
2066+ and tracks .match_op (node .op ) is not False
2067+ ):
2068+ q .append (node )
2069+
2070+ else :
2071+
2072+ def importer (node ):
2073+ if node is not current_node and isinstance (node .op , tracks ):
2074+ q .append (node )
2075+ else :
2076+ # Otherwise, we will call the node_rewriter on every node in the graph
2077+ candidate_nodes = None
2078+
2079+ def importer (node ):
2080+ if node is not current_node :
2081+ q .append (node )
2082+
2083+ node_iterator = (
20202084 apply_ancestors (start_from )
20212085 if (self .order == "dfs" )
20222086 else toposort (start_from )
20232087 )
2088+ if candidate_nodes :
2089+ q = deque (node for node in node_iterator if node in candidate_nodes )
2090+ else :
2091+ q = deque (node_iterator )
20242092 io_t = time .perf_counter () - t0
20252093
2026- def importer (node ):
2027- if node is not current_node :
2028- q .append (node )
2029-
20302094 u = self .attach_updater (
20312095 fgraph , importer , None , name = getattr (self , "name" , None )
20322096 )
0 commit comments