Skip to content

Commit cd213d1

Browse files
committed
Use single tracks in WalkingGraphRewriter
1 parent 538a5e7 commit cd213d1

File tree

1 file changed

+70
-6
lines changed

1 file changed

+70
-6
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,25 +2008,89 @@ def __init__(
20082008
if order not in valid_orders:
20092009
raise ValueError(f"order must be one of {valid_orders}, got {order}")
20102010
self.order = order
2011+
# Use tracks functionality to pre-filter nodes, if it's a single Op or Op type
2012+
tracks = node_rewriter.tracks()
2013+
self.tracks = tracks[0] if (tracks is not None and len(tracks) == 1) else None
20112014
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
20122015

20132016
def apply(self, fgraph, start_from=None):
20142017
if start_from is None:
20152018
start_from = fgraph.outputs
20162019
callback_before = fgraph.execute_callbacks_time
2017-
nb_nodes_start = len(fgraph.apply_nodes)
2020+
apply_nodes = fgraph.apply_nodes
2021+
nb_nodes_start = len(apply_nodes)
20182022
t0 = time.perf_counter()
2019-
q = deque(
2023+
if (tracks := self.tracks) is not None:
2024+
# Pre-filter nodes to consider based on tracks
2025+
if isinstance(tracks, Op):
2026+
# Equality
2027+
candidate_nodes = {
2028+
node for node in fgraph.apply_nodes if node.op == tracks
2029+
}
2030+
elif isinstance(tracks, OpPattern):
2031+
candidate_nodes = {
2032+
node
2033+
for node in fgraph.apply_nodes
2034+
if tracks.match_op(node.op) is not False
2035+
}
2036+
else:
2037+
# isinstance
2038+
candidate_nodes = {
2039+
node for node in fgraph.apply_nodes if isinstance(node.op, tracks)
2040+
}
2041+
2042+
if not candidate_nodes:
2043+
# Abort early
2044+
return (
2045+
self,
2046+
0, # nodes changed
2047+
nb_nodes_start,
2048+
nb_nodes_start, # nb_nodes_end
2049+
time.perf_counter() - t0, # io_t
2050+
0, # loop_t
2051+
0, # callback_time
2052+
self.node_rewriter,
2053+
)
2054+
2055+
if isinstance(tracks, Op):
2056+
2057+
def importer(node):
2058+
if node is not current_node and node.op == tracks:
2059+
q.append(node)
2060+
2061+
elif isinstance(tracks, OpPattern):
2062+
2063+
def importer(node):
2064+
if (
2065+
node is not current_node
2066+
and tracks.match_op(node.op) is not False
2067+
):
2068+
q.append(node)
2069+
2070+
else:
2071+
2072+
def importer(node):
2073+
if node is not current_node and isinstance(node.op, tracks):
2074+
q.append(node)
2075+
else:
2076+
# Otherwise, we will call the node_rewriter on every node in the graph
2077+
candidate_nodes = None
2078+
2079+
def importer(node):
2080+
if node is not current_node:
2081+
q.append(node)
2082+
2083+
node_iterator = (
20202084
apply_ancestors(start_from)
20212085
if (self.order == "dfs")
20222086
else toposort(start_from)
20232087
)
2088+
if candidate_nodes:
2089+
q = deque(node for node in node_iterator if node in candidate_nodes)
2090+
else:
2091+
q = deque(node_iterator)
20242092
io_t = time.perf_counter() - t0
20252093

2026-
def importer(node):
2027-
if node is not current_node:
2028-
q.append(node)
2029-
20302094
u = self.attach_updater(
20312095
fgraph, importer, None, name=getattr(self, "name", None)
20322096
)

0 commit comments

Comments
 (0)