@@ -1832,25 +1832,73 @@ def __init__(
18321832 if order not in valid_orders :
18331833 raise ValueError (f"order must be one of { valid_orders } , got { order } " )
18341834 self .order = order
1835+ # Use tracks functionality to pre-filter nodes, if it's a single Op or Op type
1836+ tracks = node_rewriter .tracks ()
1837+ self .tracks = tracks [0 ] if (tracks is not None and len (tracks ) == 1 ) else None
18351838 super ().__init__ (node_rewriter , ignore_newtrees , failure_callback )
18361839
18371840 def apply (self , fgraph , start_from = None ):
18381841 if start_from is None :
18391842 start_from = fgraph .outputs
18401843 callback_before = fgraph .execute_callbacks_time
1841- nb_nodes_start = len (fgraph .apply_nodes )
1844+ apply_nodes = fgraph .apply_nodes
1845+ nb_nodes_start = len (apply_nodes )
18421846 t0 = time .perf_counter ()
1843- q = deque (
1847+ if (tracks := self .tracks ) is not None :
1848+ # Pre-filter nodes to consider based on tracks
1849+ if isinstance (tracks , Op ):
1850+ # Equality
1851+ candidate_nodes = {
1852+ node for node in fgraph .apply_nodes if node .op == tracks
1853+ }
1854+ else :
1855+ # isinstance
1856+ candidate_nodes = {
1857+ node for node in fgraph .apply_nodes if isinstance (node .op , tracks )
1858+ }
1859+ if not candidate_nodes :
1860+ # Abort early
1861+ return (
1862+ self ,
1863+ 0 , # nodes changed
1864+ nb_nodes_start ,
1865+ nb_nodes_start , # nb_nodes_end
1866+ time .perf_counter () - t0 , # io_t
1867+ 0 , # loop_t
1868+ 0 , # callback_time
1869+ self .node_rewriter ,
1870+ )
1871+
1872+ if isinstance (tracks , Op ):
1873+
1874+ def importer (node ):
1875+ if node is not current_node and node .op == tracks :
1876+ q .append (node )
1877+
1878+ else :
1879+
1880+ def importer (node ):
1881+ if node is not current_node and isinstance (node .op , tracks ):
1882+ q .append (node )
1883+ else :
1884+ # Otherwise, we will call the node_rewriter on every node in the graph
1885+ candidate_nodes = None
1886+
1887+ def importer (node ):
1888+ if node is not current_node :
1889+ q .append (node )
1890+
1891+ node_iterator = (
18441892 apply_ancestors (start_from )
18451893 if (self .order == "dfs" )
18461894 else toposort (start_from )
18471895 )
1896+ if candidate_nodes :
1897+ q = deque (node for node in node_iterator if node in candidate_nodes )
1898+ else :
1899+ q = deque (node_iterator )
18481900 io_t = time .perf_counter () - t0
18491901
1850- def importer (node ):
1851- if node is not current_node :
1852- q .append (node )
1853-
18541902 u = self .attach_updater (
18551903 fgraph , importer , None , name = getattr (self , "name" , None )
18561904 )
0 commit comments