Skip to content

Commit 6dc0b8a

Browse files
committed
Use single tracks in WalkingGraphRewriter
1 parent 0f426c3 commit 6dc0b8a

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,25 +1832,73 @@ def __init__(
18321832
if order not in valid_orders:
18331833
raise ValueError(f"order must be one of {valid_orders}, got {order}")
18341834
self.order = order
1835+
# Use tracks functionality to pre-filter nodes, if it's a single Op or Op type
1836+
tracks = node_rewriter.tracks()
1837+
self.tracks = tracks[0] if (tracks is not None and len(tracks) == 1) else None
18351838
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
18361839

18371840
def apply(self, fgraph, start_from=None):
18381841
if start_from is None:
18391842
start_from = fgraph.outputs
18401843
callback_before = fgraph.execute_callbacks_time
1841-
nb_nodes_start = len(fgraph.apply_nodes)
1844+
apply_nodes = fgraph.apply_nodes
1845+
nb_nodes_start = len(apply_nodes)
18421846
t0 = time.perf_counter()
1843-
q = deque(
1847+
if (tracks := self.tracks) is not None:
1848+
# Pre-filter nodes to consider based on tracks
1849+
if isinstance(tracks, Op):
1850+
# Equality
1851+
candidate_nodes = {
1852+
node for node in fgraph.apply_nodes if node.op == tracks
1853+
}
1854+
else:
1855+
# isinstance
1856+
candidate_nodes = {
1857+
node for node in fgraph.apply_nodes if isinstance(node.op, tracks)
1858+
}
1859+
if not candidate_nodes:
1860+
# Abort early
1861+
return (
1862+
self,
1863+
0, # nodes changed
1864+
nb_nodes_start,
1865+
nb_nodes_start, # nb_nodes_end
1866+
time.perf_counter() - t0, # io_t
1867+
0, # loop_t
1868+
0, # callback_time
1869+
self.node_rewriter,
1870+
)
1871+
1872+
if isinstance(tracks, Op):
1873+
1874+
def importer(node):
1875+
if node is not current_node and node.op == tracks:
1876+
q.append(node)
1877+
1878+
else:
1879+
1880+
def importer(node):
1881+
if node is not current_node and isinstance(node.op, tracks):
1882+
q.append(node)
1883+
else:
1884+
# Otherwise, we will call the node_rewriter on every node in the graph
1885+
candidate_nodes = None
1886+
1887+
def importer(node):
1888+
if node is not current_node:
1889+
q.append(node)
1890+
1891+
node_iterator = (
18441892
apply_ancestors(start_from)
18451893
if (self.order == "dfs")
18461894
else toposort(start_from)
18471895
)
1896+
if candidate_nodes:
1897+
q = deque(node for node in node_iterator if node in candidate_nodes)
1898+
else:
1899+
q = deque(node_iterator)
18481900
io_t = time.perf_counter() - t0
18491901

1850-
def importer(node):
1851-
if node is not current_node:
1852-
q.append(node)
1853-
18541902
u = self.attach_updater(
18551903
fgraph, importer, None, name=getattr(self, "name", None)
18561904
)

0 commit comments

Comments
 (0)