From 3c3ca0b071b26b366fc7c46131b3a69f65ba95c1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 4 Sep 2025 19:25:51 +0200 Subject: [PATCH 01/33] Benchmark another FusionOptimizer graph --- pytensor/tensor/rewriting/elemwise.py | 12 ++++--- tests/tensor/rewriting/test_elemwise.py | 42 +++++++++++++++++++++---- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e2d420f361..0eb2900729 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -569,8 +569,6 @@ def elemwise_to_scalar(inputs, outputs): return scalar_inputs, scalar_outputs def apply(self, fgraph): - nb_replacement = 0 - if fgraph.profile: validate_before = fgraph.profile.validate_time callbacks_before = fgraph.execute_callbacks_times.copy() @@ -925,6 +923,8 @@ def update_fuseable_mappings_after_fg_replace( starting_nodes=starting_nodes, ) + nb_fused = 0 + nb_replacement = 0 for inputs, outputs in find_next_fuseable_subgraph(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( @@ -943,11 +943,13 @@ def update_fuseable_mappings_after_fg_replace( if old_out.name: composite_out.name = old_out.name + starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( list(zip(outputs, composite_outputs, strict=True)), reason=self.__class__.__name__, ) - nb_replacement += 1 + nb_fused += 1 + nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -965,7 +967,7 @@ def update_fuseable_mappings_after_fg_replace( return ( self, - 1, # nb_iter + nb_fused, nb_replacement, 0, # nb_inconsintency_replace validate_time, @@ -978,7 +980,7 @@ def update_fuseable_mappings_after_fg_replace( def print_profile(stream, prof, level=0): blanc = " " * level print(blanc, "FusionOptimizer", file=stream) - print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_fused", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " validate_time", prof[4], file=stream) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index c23d0ac23a..3c549788e1 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) - def large_fuseable_graph(self, n): + @staticmethod + def large_fuseable_graph(n): factors = [] sd = dscalar() means = dvector() @@ -296,6 +297,24 @@ def large_fuseable_graph(self, n): dlogp = [pytensor.grad(logp, v) for v in vars] return vars, dlogp + @staticmethod + def deep_small_kernels(n): + x = pt.matrix("x") + out = x + for _ in range(n): + out = pt.sin(out.T) + pt.cos(out) + + return [x], [out] + + @staticmethod + def diamond_graph(n): + a = pt.matrix("a") + b = pt.exp(a) + c = pt.log(b) + d = pt.sin(c) + e = c + d + return [a], [e] + @pytest.mark.parametrize( "case", [ @@ -1347,16 +1366,27 @@ def test_eval_benchmark(self, benchmark): benchmark(func) @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_rewrite_benchmark(self, benchmark): - inps, outs = self.large_fuseable_graph(n=25) + @pytest.mark.parametrize( + "graph_fn, n, expected_n_repl", + [ + # ("diamond_graph", None, (1, 4)), + ("deep_small_kernels", 20, (20, 60)), + ("large_fuseable_graph", 25, (103, 876)), + ], + ) + def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): + inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) opt = FusionOptimizer() def rewrite_func(): - nb_replacement = opt.apply(fg.clone())[2] - return nb_replacement + fg_clone = fg.clone() + _, nb_fused, nb_replacement, *_ = opt.apply(fg_clone) + # fg_clone.dprint() + return nb_fused, nb_replacement - assert benchmark(rewrite_func) == 103 + assert rewrite_func() == expected_n_repl + benchmark.pedantic(rewrite_func, rounds=7, iterations=5) def test_no_warning_from_old_client(self): # There used to be a warning issued when creating fuseable mapping From 71197705a6ecbe83d5bfc61e10f82ac8a2a83d56 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 20 Sep 2025 10:19:40 +0200 Subject: [PATCH 02/33] Short-circuit `as_scalar` common cases faster --- pytensor/scalar/basic.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f28a1122c8..eef7f15b14 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: - from pytensor.tensor.basic import scalar_from_tensor - from pytensor.tensor.type import TensorType + if isinstance(x, ScalarVariable): + return x + + if isinstance(x, Variable): + from pytensor.tensor.basic import scalar_from_tensor + from pytensor.tensor.type import TensorType + + if isinstance(x.type, TensorType) and x.type.ndim == 0: + return scalar_from_tensor(x) + else: + raise TypeError(f"Cannot convert {x} to a scalar type") if isinstance(x, Apply): + # FIXME: Why do we support calling this with Apply? + # Also, if we do, why can't we support multiple outputs? if len(x.outputs) != 1: raise ValueError( "It is ambiguous which output of a multi-output" " Op has to be fetched.", x, ) - else: - x = x.outputs[0] - if isinstance(x, Variable): - if isinstance(x, ScalarVariable): - return x - elif isinstance(x.type, TensorType) and x.type.ndim == 0: - return scalar_from_tensor(x) - else: - raise TypeError(f"Cannot convert {x} to a scalar type") + return as_scalar(x.outputs[0]) return constant(x) From b944c9f0b19f24d58085f65cfd5dba633d83c45d Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 19 Sep 2025 01:01:55 +0200 Subject: [PATCH 03/33] Speedup supports c_code Not using `__call__` avoids the test_value computation --- pytensor/scalar/basic.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index eef7f15b14..78d7c044ba 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1333,32 +1333,26 @@ def supports_c_code(self, inputs, outputs): the given Elemwise inputs, outputs. """ - try: - tmp_s_input = [] - # To keep the same aliasing between inputs - mapping = dict() - for ii in inputs: - if ii in mapping: - tmp_s_input.append(mapping[ii]) - else: - tmp = get_scalar_type(ii.dtype).make_variable() - tmp_s_input.append(tmp) - mapping[ii] = tmp_s_input[-1] - - with config.change_flags(compute_test_value="ignore"): - s_op = self(*tmp_s_input, return_list=True) + tmp_s_input = [] + # To keep the same aliasing between inputs + mapping = {} + for ii in inputs: + if ii in mapping: + tmp_s_input.append(mapping[ii]) + else: + tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable() + tmp_s_input.append(tmp) - # if the scalar_op don't have a c implementation, - # we skip its fusion to allow the fusion of the - # other ops. + try: self.c_code( - s_op[0].owner, + self.make_node(*tmp_s_input), "test_presence_of_c_code", + # FIXME: Shouldn't this be a unique name per unique variable? ["x" for x in inputs], ["z" for z in outputs], {"fail": "%(fail)s"}, ) - except (MethodNotDefined, NotImplementedError): + except (NotImplementedError, MethodNotDefined): return False return True From 0e5c76078d8b6491bbfc3852450923f19ef914a8 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 13:33:36 +0200 Subject: [PATCH 04/33] Speedup FusionOptimizer.elemwise_to_scalar --- pytensor/scalar/basic.py | 8 ++-- pytensor/tensor/rewriting/elemwise.py | 55 +++++++++------------------ 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 78d7c044ba..fc43ae411b 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -779,9 +779,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: This caches objects to save allocation and run time. """ - if dtype not in cache: - cache[dtype] = ScalarType(dtype=dtype) - return cache[dtype] + try: + return cache[dtype] + except KeyError: + cache[dtype] = res = ScalarType(dtype=dtype) + return res # Register C code for ViewOp on Scalars. diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 0eb2900729..1eb3d7c037 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -28,7 +28,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors +from pytensor.graph.traversal import ancestors, toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -530,43 +530,24 @@ def add_requirements(self, fgraph): @staticmethod def elemwise_to_scalar(inputs, outputs): - replace_inputs = [(inp, inp.clone()) for inp in inputs] - outputs = clone_replace(outputs, replace=replace_inputs) - - inputs = [inp for _, inp in replace_inputs] - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) - middle_inputs = [] - - scalar_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs - ] - middle_scalar_inputs = [] - - for node in fg.toposort(): - node_scalar_inputs = [] - for inp in node.inputs: - if inp in inputs: - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) - elif inp in middle_inputs: - node_scalar_inputs.append( - middle_scalar_inputs[middle_inputs.index(inp)] + replacement = { + inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + } + for node in toposort(outputs, blockers=inputs): + scalar_inputs = [replacement[inp] for inp in node.inputs] + replacement.update( + dict( + zip( + node.outputs, + node.op.scalar_op.make_node(*scalar_inputs).outputs, ) - else: - new_scalar_input = ps.get_scalar_type( - inp.type.dtype - ).make_variable() - node_scalar_inputs.append(new_scalar_input) - middle_scalar_inputs.append(new_scalar_input) - middle_inputs.append(inp) - - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) - middle_scalar_inputs.append(new_scalar_node.outputs[0]) - middle_inputs.append(node.outputs[0]) - - scalar_outputs = [ - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs - ] - return scalar_inputs, scalar_outputs + ) + ) + + return ( + [replacement[inp] for inp in inputs], + [replacement[out] for out in outputs], + ) def apply(self, fgraph): if fgraph.profile: From 62de41971b44bfd0fe560d7ff5ffa634d56a5eec Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 20 Sep 2025 10:05:06 +0200 Subject: [PATCH 05/33] Avoid double cloning of Composite Ops created by FusionOptimizer --- pytensor/scalar/basic.py | 19 ++++++++++++------- pytensor/tensor/rewriting/elemwise.py | 13 +++++++------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index fc43ae411b..9b27c369f3 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -13,7 +13,6 @@ import builtins import math from collections.abc import Callable -from copy import copy from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -4094,12 +4093,12 @@ def __init__(self, *args, **kwargs): self.prepare_node_called = set() super().__init__(*args, **kwargs) - def _cleanup_graph(self, inputs, outputs): + def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): # TODO: We could convert to TensorVariable, optimize graph, # and then convert back to ScalarVariable. # This would introduce rewrites like `log(1 + x) -> log1p`. - fgraph = FunctionGraph(copy(inputs), copy(outputs)) + fgraph = FunctionGraph(inputs, outputs, clone=clone) # Validate node types for node in fgraph.apply_nodes: @@ -4282,7 +4281,9 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs, name="Composite"): + def __init__( + self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + ): self.name = name self._name = None # We need to clone the graph as sometimes its nodes already @@ -4300,10 +4301,13 @@ def __init__(self, inputs, outputs, name="Composite"): if len(outputs) > 1 or not any( isinstance(var.owner.op, Composite) for var in outputs ): - # No inner Composite - inputs, outputs = clone(inputs, outputs) + if clone_graph: + inputs, outputs = clone(inputs, outputs) + else: # Inner Composite that we need to flatten + # FIXME: There could be a composite in the middle of the graph, why is this here? + # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway. assert len(outputs) == 1 # 1. Create a new graph from inputs up to the # Composite @@ -4322,7 +4326,8 @@ def __init__(self, inputs, outputs, name="Composite"): assert res[0] != inputs inputs, outputs = res[0], res2[1] - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) + # We already cloned the graph, or the user told us there was no need for it + self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) self.inputs_type = tuple(input.type for input in self.inputs) self.outputs_type = tuple(output.type for output in self.outputs) self.nin = len(inputs) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 1eb3d7c037..42f4b6fc67 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -915,12 +915,13 @@ def update_fuseable_mappings_after_fg_replace( break scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) - composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))( - *inputs - ) - if not isinstance(composite_outputs, list): - composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs, strict=True): + composite_outputs = Elemwise( + # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables + ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False) + )(*inputs, return_list=True) + assert len(outputs) == len(composite_outputs) + for old_out, composite_out in zip(outputs, composite_outputs): + # Preserve any names on the original outputs if old_out.name: composite_out.name = old_out.name From 0337dce381eeff981ded9e788f627bf964465e04 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 18:06:06 +0200 Subject: [PATCH 06/33] Do not recompute toposort in every iteration of FusionOptimizer It's not really needed as we never expand on the new nodes --- pytensor/tensor/rewriting/elemwise.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 42f4b6fc67..689b47c28d 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -625,10 +625,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: def find_fuseable_subgraph( *, - fg: FunctionGraph, visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: KT = TypeVar("KT") VT = TypeVar("VT", list, set) @@ -648,8 +648,7 @@ def variables_depend_on( for a in ancestors(variables, blockers=stop_search_at) ) - toposort = fg.toposort() - for starting_node in toposort: + for starting_node in toposort_index: if starting_node in visited_nodes: continue @@ -791,7 +790,7 @@ def variables_depend_on( and inp.owner not in visited_nodes ) ), - key=lambda inp: toposort.index(inp.owner), + key=lambda inp: toposort_index[inp.owner], reverse=True, ): fuseable_nodes_to_visit.appendleft(inp.owner) @@ -803,7 +802,7 @@ def variables_depend_on( for node in fuseable_clients_temp.get(next_out, ()) if node not in visited_nodes ), - key=lambda node: toposort.index(node), + key=lambda node: toposort_index[node], ): fuseable_nodes_to_visit.append(next_node) @@ -877,20 +876,22 @@ def update_fuseable_mappings_after_fg_replace( # client (those that don't fit into 1)) fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) visited_nodes: set[Apply] = set() + toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} while True: - starting_nodes = fg.apply_nodes.copy() try: subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - fg=fg, visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + toposort_index=toposort_index, ) except ValueError: return else: # The caller is now expected to update fg in place, # by replacing the subgraph with a Composite Op + starting_nodes = fg.apply_nodes.copy() + yield subgraph_inputs, subgraph_outputs # This is where we avoid repeated work by using a stateful From 42de0ea828c5519da7c0a852d5c8b94b3361eb5e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 23:49:50 +0200 Subject: [PATCH 07/33] Cleanup FusionOptimizer code --- pytensor/tensor/rewriting/elemwise.py | 164 ++++++++++++-------------- 1 file changed, 76 insertions(+), 88 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 689b47c28d..d37f04feb5 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -5,7 +5,7 @@ from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce -from typing import TypeVar +from typing import Literal from warnings import warn import pytensor.scalar.basic as ps @@ -568,8 +568,7 @@ def find_next_fuseable_subgraph( This generator assumes that such subgraph is replaced by a single Elemwise Composite before being accessed again in the next iteration. """ - - FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] + FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] def initialize_fuseable_mappings( @@ -591,35 +590,33 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: # to ensure the rewrite remains deterministic. # This is not a problem from unfuseable ones, as they can never # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) + fuseable_clients: FUSEABLE_MAPPING = defaultdict(set) unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) for out, clients in fg.clients.items(): - # Old FunctionGraph nodes remain in the clients dictionary - # even after they are removed by rewrites - if not clients: - continue - out_maybe_fuseable = ( - out.owner + out.owner is not None and isinstance(out.owner.op, Elemwise) # and not isinstance(out.owner.op.scalar_op, ps.Composite) and len(out.owner.outputs) == 1 and elemwise_scalar_op_has_c_code(out.owner) ) - for client, _ in clients: - if ( - out_maybe_fuseable - and isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out.type.broadcastable - == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - if client not in fuseable_clients[out]: - fuseable_clients[out].append(client) - else: - unfuseable_clients[out].add(client) + if out_maybe_fuseable: + out_bcast = ( + out.type.broadcastable if out_maybe_fuseable else None + ) + for client, _ in clients: + if ( + isinstance(client.op, Elemwise) + # and not isinstance(client.op.scalar_op, ps.Composite) + and len(client.outputs) == 1 + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ): + fuseable_clients[out].add(client) + else: + unfuseable_clients[out].add(client) + else: + unfuseable_clients[out] = {client for client, _ in clients} return fuseable_clients, unfuseable_clients @@ -630,16 +627,6 @@ def find_fuseable_subgraph( unfuseable_clients: UNFUSEABLE_MAPPING, toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: - KT = TypeVar("KT") - VT = TypeVar("VT", list, set) - - def shallow_clone_defaultdict( - d: defaultdict[KT, VT], - ) -> defaultdict[KT, VT]: - new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory) - new_dict.update({k: v.copy() for k, v in d.items()}) - return new_dict - def variables_depend_on( variables, depend_on, stop_search_at=None ) -> bool: @@ -657,17 +644,19 @@ def variables_depend_on( visited_nodes.add(starting_node) continue - subgraph_inputs: list[Variable] = [] - subgraph_outputs: list[Variable] = [] + subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set + subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set unfuseable_clients_subgraph: set[Variable] = set() # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) - unfuseable_clients_clone = shallow_clone_defaultdict( - unfuseable_clients + fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) + fuseable_clients_clone.update( + {k: v.copy() for k, v in fuseable_clients.items()} + ) + unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set) + unfuseable_clients_clone.update( + {k: v.copy() for k, v in unfuseable_clients.items()} ) - - fuseable_nodes_to_visit = deque([starting_node]) # We now try to expand as much as possible towards the potentially # fuseable clients and ancestors to detect the largest possible @@ -676,6 +665,7 @@ def variables_depend_on( # some inputs or clients may depend on other nodes of the same # subgraph via a path that cannot be included in the Composite # (unfuseable) + fuseable_nodes_to_visit = deque([starting_node]) while fuseable_nodes_to_visit: next_node = fuseable_nodes_to_visit.popleft() visited_nodes.add(next_node) @@ -684,15 +674,14 @@ def variables_depend_on( # If the output variable of next_node has no fuseable clients # or has unfuseable clients, then next_node must become an output # if it is to be fused. - must_become_output = ( - next_out not in fuseable_clients_temp - or next_out in unfuseable_clients_clone - ) + must_become_output = not fuseable_clients_clone.get( + next_out + ) or unfuseable_clients_clone.get(next_out) # We have backtracked to this node, and it may no longer be a viable output, # so we remove it and check again as if we had never seen this node - if must_become_output and next_out in subgraph_outputs: - subgraph_outputs.remove(next_out) + if must_become_output: + subgraph_outputs.pop(next_out, None) required_unfuseable_inputs = [ inp @@ -744,18 +733,19 @@ def variables_depend_on( if ( inp.owner in visited_nodes # next_node could have the same input repeated - and next_node in fuseable_clients_temp[inp] + and next_node in fuseable_clients_clone[inp] ): - fuseable_clients_temp[inp].remove(next_node) + fuseable_clients_clone[inp].remove(next_node) unfuseable_clients_clone[inp].add(next_node) # This input must become an output of the subgraph, # because it can't be merged with next_node. # We will revisit it to make sure this is safe. fuseable_nodes_to_visit.appendleft(inp.owner) - for client in fuseable_clients_temp[next_out]: + # need to convert to tuple not to change set size during iteration + for client in tuple(fuseable_clients_clone[next_out]): if client in visited_nodes: - fuseable_clients_temp[next_out].remove(client) + fuseable_clients_clone[next_out].remove(client) unfuseable_clients_clone[next_out].add(client) # next_out must become an input of the subgraph. # We will revisit any of its clients currently @@ -771,74 +761,72 @@ def variables_depend_on( # mappings as if it next_node was part of it. # Useless inputs will be removed by the useless Composite rewrite for inp in new_required_unfuseable_inputs: - if inp not in subgraph_inputs: - subgraph_inputs.append(inp) + subgraph_inputs[inp] = None if must_become_output: - subgraph_outputs.append(next_out) + subgraph_outputs[next_out] = None unfuseable_clients_subgraph.update( new_implied_unfuseable_clients ) # Expand through unvisited fuseable ancestors - for inp in sorted( - ( - inp - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=lambda inp: toposort_index[inp.owner], - reverse=True, - ): - fuseable_nodes_to_visit.appendleft(inp.owner) + fuseable_nodes_to_visit.extendleft( + sorted( + ( + inp.owner + for inp in next_node.inputs + if ( + inp not in required_unfuseable_inputs + and inp.owner not in visited_nodes + ) + ), + key=toposort_index.get, # type: ignore[arg-type] + ) + ) # Expand through unvisited fuseable clients - for next_node in sorted( - ( - node - for node in fuseable_clients_temp.get(next_out, ()) - if node not in visited_nodes - ), - key=lambda node: toposort_index[node], - ): - fuseable_nodes_to_visit.append(next_node) + fuseable_nodes_to_visit.extend( + sorted( + ( + node + for node in fuseable_clients_clone.get(next_out, ()) + if node not in visited_nodes + ), + key=toposort_index.get, # type: ignore[arg-type] + ) + ) # Don't return if final subgraph is just the original Elemwise if len(subgraph_outputs) == 1 and set( - subgraph_outputs[0].owner.inputs + next(iter(subgraph_outputs)).owner.inputs ) == set(subgraph_inputs): # Update global fuseable mappings # No input was actually fuseable for inp in starting_node.inputs: - if starting_node in fuseable_clients.get(inp, ()): - fuseable_clients[inp].remove(starting_node) - unfuseable_clients[inp].add(starting_node) + fuseable_clients[inp].discard(starting_node) + unfuseable_clients[inp].add(starting_node) # No client was actually fuseable unfuseable_clients[starting_out].update( fuseable_clients.pop(starting_out, ()) ) continue - return subgraph_inputs, subgraph_outputs + return list(subgraph_inputs), list(subgraph_outputs) raise ValueError def update_fuseable_mappings_after_fg_replace( *, - fg: FunctionGraph, visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, starting_nodes: set[Apply], + updated_nodes: set[Apply], ) -> None: # Find new composite node and dropped intermediate nodes # by comparing the current fg.apply nodes with the cached # original nodes - next_nodes = fg.apply_nodes - (new_composite_node,) = next_nodes - starting_nodes - dropped_nodes = starting_nodes - next_nodes + (new_composite_node,) = updated_nodes - starting_nodes + dropped_nodes = starting_nodes - updated_nodes # Remove intermediate Composite nodes from mappings for dropped_node in dropped_nodes: @@ -850,11 +838,11 @@ def update_fuseable_mappings_after_fg_replace( # Update fuseable information for subgraph inputs for inp in subgraph_inputs: if inp in fuseable_clients: - new_fuseable_clients = [ + new_fuseable_clients = { client for client in fuseable_clients[inp] if client not in dropped_nodes - ] + } if new_fuseable_clients: fuseable_clients[inp] = new_fuseable_clients else: @@ -898,11 +886,11 @@ def update_fuseable_mappings_after_fg_replace( # generator. For large models (as in `TestFusion.test_big_fusion`) # this can provide huge speedups update_fuseable_mappings_after_fg_replace( - fg=fg, visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, starting_nodes=starting_nodes, + updated_nodes=fg.apply_nodes, ) nb_fused = 0 From ca607bfb2624a057c5a334e37d1c900af0f9bcdf Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 23:49:50 +0200 Subject: [PATCH 08/33] Copy on write in FusionOptimizer --- pytensor/tensor/rewriting/elemwise.py | 82 ++++++++++++++++++++------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index d37f04feb5..7d65ce5f95 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,6 +2,7 @@ import itertools import operator import sys +import typing from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce @@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int: return 1024 +class CopyOnWriteDictOfSets: + __slots__ = ("d", "d_copy") + + def __init__(self, d: dict[typing.Any, set]): + self.d = d + self.d_copy: dict[typing.Any, set] = {} + + def __getitem__(self, key): + try: + return self.d_copy[key] + except KeyError: + return self.d[key] + + def get(self, key, default=frozenset()): + try: + return self.d_copy[key] + except KeyError: + try: + return self.d[key] + except KeyError: + return default + + def remove_from_key(self, key, value): + try: + self.d_copy[key].remove(value) + except KeyError: + self.d_copy[key] = copied_value = self.d[key].copy() + copied_value.remove(value) + + def add_to_key(self, key, value): + try: + self.d_copy[key].add(value) + except KeyError: + self.d_copy[key] = copied_value = self.d[key].copy() + copied_value.add(value) + + class FusionOptimizer(GraphRewriter): """Graph optimizer that fuses consecutive Elemwise operations.""" @@ -648,15 +686,10 @@ def variables_depend_on( subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set unfuseable_clients_subgraph: set[Variable] = set() - # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) - fuseable_clients_clone.update( - {k: v.copy() for k, v in fuseable_clients.items()} - ) - unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set) - unfuseable_clients_clone.update( - {k: v.copy() for k, v in unfuseable_clients.items()} - ) + # If we need to manipulate the maps in place, we'll do a shallow copy later + # For now we query on the original ones + fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients) + unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients) # We now try to expand as much as possible towards the potentially # fuseable clients and ancestors to detect the largest possible @@ -686,7 +719,7 @@ def variables_depend_on( required_unfuseable_inputs = [ inp for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp, ()) + if next_node in unfuseable_clients_clone.get(inp) ] new_required_unfuseable_inputs = [ inp @@ -709,7 +742,7 @@ def variables_depend_on( if not must_backtrack: implied_unfuseable_clients = { c - for client in unfuseable_clients_clone.get(next_out, ()) + for client in unfuseable_clients_clone.get(next_out) if not isinstance(client.op, Output) for c in client.outputs } @@ -730,13 +763,15 @@ def variables_depend_on( if must_backtrack: for inp in next_node.inputs: - if ( - inp.owner in visited_nodes - # next_node could have the same input repeated - and next_node in fuseable_clients_clone[inp] - ): - fuseable_clients_clone[inp].remove(next_node) - unfuseable_clients_clone[inp].add(next_node) + if inp.owner in visited_nodes: + if next_node not in fuseable_clients_clone[inp]: + # This can happen when next node has repeated inputs + continue + fuseable_clients_clone.remove_from_key( + inp, next_node + ) + unfuseable_clients_clone.add_to_key(inp, next_node) + # This input must become an output of the subgraph, # because it can't be merged with next_node. # We will revisit it to make sure this is safe. @@ -745,8 +780,13 @@ def variables_depend_on( # need to convert to tuple not to change set size during iteration for client in tuple(fuseable_clients_clone[next_out]): if client in visited_nodes: - fuseable_clients_clone[next_out].remove(client) - unfuseable_clients_clone[next_out].add(client) + fuseable_clients_clone.remove_from_key( + next_out, client + ) + unfuseable_clients_clone.add_to_key( + next_out, client + ) + # next_out must become an input of the subgraph. # We will revisit any of its clients currently # in the subgraph to make sure this is safe. @@ -789,7 +829,7 @@ def variables_depend_on( sorted( ( node - for node in fuseable_clients_clone.get(next_out, ()) + for node in fuseable_clients_clone.get(next_out) if node not in visited_nodes ), key=toposort_index.get, # type: ignore[arg-type] From 4364beefdeb2f7cf454ce75a5a49b6ff5ab179a6 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 18:57:02 +0200 Subject: [PATCH 09/33] Use bitset to check ancestors more efficiently --- pytensor/tensor/rewriting/elemwise.py | 139 +++++++++++++------------- tests/test_printing.py | 14 +-- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 7d65ce5f95..77cf934705 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -6,6 +6,7 @@ from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce +from operator import or_ from typing import Literal from warnings import warn @@ -29,7 +30,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors, toposort +from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -663,16 +664,9 @@ def find_fuseable_subgraph( visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + ancestors_bitset: dict[Apply, int], toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: - def variables_depend_on( - variables, depend_on, stop_search_at=None - ) -> bool: - return any( - a in depend_on - for a in ancestors(variables, blockers=stop_search_at) - ) - for starting_node in toposort_index: if starting_node in visited_nodes: continue @@ -684,7 +678,8 @@ def variables_depend_on( subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set - unfuseable_clients_subgraph: set[Variable] = set() + subgraph_inputs_ancestors_bitset = 0 + unfuseable_clients_subgraph_bitset = 0 # If we need to manipulate the maps in place, we'll do a shallow copy later # For now we query on the original ones @@ -716,50 +711,32 @@ def variables_depend_on( if must_become_output: subgraph_outputs.pop(next_out, None) - required_unfuseable_inputs = [ - inp - for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp) - ] - new_required_unfuseable_inputs = [ - inp - for inp in required_unfuseable_inputs - if inp not in subgraph_inputs - ] - - must_backtrack = False - if new_required_unfuseable_inputs and subgraph_outputs: - # We need to check that any new inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - if variables_depend_on( - [next_out], - depend_on=unfuseable_clients_subgraph, - stop_search_at=subgraph_outputs, - ): - must_backtrack = True + # We need to check that any inputs required by this node + # do not depend on other outputs of the current subgraph, + # via an unfuseable path. + must_backtrack = ( + ancestors_bitset[next_node] + & unfuseable_clients_subgraph_bitset + ) if not must_backtrack: - implied_unfuseable_clients = { - c - for client in unfuseable_clients_clone.get(next_out) - if not isinstance(client.op, Output) - for c in client.outputs - } - - new_implied_unfuseable_clients = ( - implied_unfuseable_clients - unfuseable_clients_subgraph + implied_unfuseable_clients_bitset = reduce( + or_, + ( + 1 << toposort_index[client] + for client in unfuseable_clients_clone.get(next_out) + if not isinstance(client.op, Output) + ), + 0, ) - if new_implied_unfuseable_clients and subgraph_inputs: - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - if variables_depend_on( - subgraph_inputs, - depend_on=new_implied_unfuseable_clients, - ): - must_backtrack = True + # We need to check that any inputs of the current subgraph + # do not depend on other clients of this node, + # via an unfuseable path. + must_backtrack = ( + subgraph_inputs_ancestors_bitset + & implied_unfuseable_clients_bitset + ) if must_backtrack: for inp in next_node.inputs: @@ -800,29 +777,24 @@ def variables_depend_on( # immediate dependency problems. Update subgraph # mappings as if it next_node was part of it. # Useless inputs will be removed by the useless Composite rewrite - for inp in new_required_unfuseable_inputs: - subgraph_inputs[inp] = None - if must_become_output: subgraph_outputs[next_out] = None - unfuseable_clients_subgraph.update( - new_implied_unfuseable_clients + unfuseable_clients_subgraph_bitset |= ( + implied_unfuseable_clients_bitset ) - # Expand through unvisited fuseable ancestors - fuseable_nodes_to_visit.extendleft( - sorted( - ( - inp.owner - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=toposort_index.get, # type: ignore[arg-type] - ) - ) + for inp in sorted( + next_node.inputs, + key=lambda x: toposort_index.get(x.owner, -1), + ): + if next_node in unfuseable_clients_clone.get(inp, ()): + # input must become an input of the subgraph since it's unfuseable with new node + subgraph_inputs_ancestors_bitset |= ( + ancestors_bitset.get(inp.owner, 0) + ) + subgraph_inputs[inp] = None + elif inp.owner not in visited_nodes: + fuseable_nodes_to_visit.appendleft(inp.owner) # Expand through unvisited fuseable clients fuseable_nodes_to_visit.extend( @@ -859,6 +831,8 @@ def update_fuseable_mappings_after_fg_replace( visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + toposort_index: dict[Apply, int], + ancestors_bitset: dict[Apply, int], starting_nodes: set[Apply], updated_nodes: set[Apply], ) -> None: @@ -869,11 +843,25 @@ def update_fuseable_mappings_after_fg_replace( dropped_nodes = starting_nodes - updated_nodes # Remove intermediate Composite nodes from mappings + # And compute the ancestors bitset of the new composite node + # As well as the new toposort index for the new node + new_node_ancestor_bitset = 0 + new_node_toposort_index = len(toposort_index) for dropped_node in dropped_nodes: (dropped_out,) = dropped_node.outputs fuseable_clients.pop(dropped_out, None) unfuseable_clients.pop(dropped_out, None) visited_nodes.remove(dropped_node) + # The new composite ancestor bitset is the union + # of the ancestors of all the dropped nodes + new_node_ancestor_bitset |= ancestors_bitset[dropped_node] + # The new composite node can have the same order as the latest node that was absorbed into it + new_node_toposort_index = max( + new_node_toposort_index, toposort_index[dropped_node] + ) + + ancestors_bitset[new_composite_node] = new_node_ancestor_bitset + toposort_index[new_composite_node] = new_node_toposort_index # Update fuseable information for subgraph inputs for inp in subgraph_inputs: @@ -905,12 +893,23 @@ def update_fuseable_mappings_after_fg_replace( fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) visited_nodes: set[Apply] = set() toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} + # Create a bitset for each node of all its ancestors + # This allows to quickly check if a variable depends on a set + ancestors_bitset: dict[Apply, int] = {} + for node, index in toposort_index.items(): + node_ancestor_bitset = 1 << index + for inp in node.inputs: + if (inp_node := inp.owner) is not None: + node_ancestor_bitset |= ancestors_bitset[inp_node] + ancestors_bitset[node] = node_ancestor_bitset + while True: try: subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + ancestors_bitset=ancestors_bitset, toposort_index=toposort_index, ) except ValueError: @@ -929,6 +928,8 @@ def update_fuseable_mappings_after_fg_replace( visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + toposort_index=toposort_index, + ancestors_bitset=ancestors_bitset, starting_nodes=starting_nodes, updated_nodes=fg.apply_nodes, ) diff --git a/tests/test_printing.py b/tests/test_printing.py index 95c3c938cf..dbad8c063b 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -301,7 +301,8 @@ def test_debugprint(): Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv" exp_res = dedent( r""" - Composite{(i2 + (i0 - i1))} 4 + Composite{(i0 + (i1 - i2))} 4 + ├─ A ├─ ExpandDims{axis=0} v={0: [0]} 3 """ f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2" @@ -313,17 +314,16 @@ def test_debugprint(): │ ├─ B │ ├─ │ └─ 0.0 - ├─ D - └─ A + └─ D Inner graphs: - Composite{(i2 + (i0 - i1))} + Composite{(i0 + (i1 - i2))} ← add 'o0' - ├─ i2 - └─ sub ├─ i0 - └─ i1 + └─ sub + ├─ i1 + └─ i2 """ ).lstrip() From 6f5a3faf95c077986e28d957593af31daefc9f97 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 18 Sep 2025 09:36:13 +0200 Subject: [PATCH 10/33] Avoid backtracking in FusionOptimizer The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happen to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in https://github.com/pymc-devs/pytensor/issues/249 --- pytensor/tensor/rewriting/elemwise.py | 600 ++++++++++-------------- tests/tensor/rewriting/test_elemwise.py | 2 +- 2 files changed, 238 insertions(+), 364 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 77cf934705..a862711eab 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,12 +2,10 @@ import itertools import operator import sys -import typing -from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce +from heapq import heapify, heappop, heappush from operator import or_ -from typing import Literal from warnings import warn import pytensor.scalar.basic as ps @@ -524,43 +522,6 @@ def elemwise_max_operands_fct(node) -> int: return 1024 -class CopyOnWriteDictOfSets: - __slots__ = ("d", "d_copy") - - def __init__(self, d: dict[typing.Any, set]): - self.d = d - self.d_copy: dict[typing.Any, set] = {} - - def __getitem__(self, key): - try: - return self.d_copy[key] - except KeyError: - return self.d[key] - - def get(self, key, default=frozenset()): - try: - return self.d_copy[key] - except KeyError: - try: - return self.d[key] - except KeyError: - return default - - def remove_from_key(self, key, value): - try: - self.d_copy[key].remove(value) - except KeyError: - self.d_copy[key] = copied_value = self.d[key].copy() - copied_value.remove(value) - - def add_to_key(self, key, value): - try: - self.d_copy[key].add(value) - except KeyError: - self.d_copy[key] = copied_value = self.d[key].copy() - copied_value.add(value) - - class FusionOptimizer(GraphRewriter): """Graph optimizer that fuses consecutive Elemwise operations.""" @@ -596,353 +557,266 @@ def apply(self, fgraph): max_operands = elemwise_max_operands_fct(None) - def find_next_fuseable_subgraph( + def find_fuseable_subgraphs( fg: FunctionGraph, ) -> Generator[tuple[list[Variable], list[Variable]], None, None]: - """Find all subgraphs in a FunctionGraph that can be fused together + """Find subgraphs of Elemwise nodes that can be fused together. - Yields - ------- - List of inputs and outputs that determine subgraphs which can be fused. - This generator assumes that such subgraph is replaced by a single - Elemwise Composite before being accessed again in the next iteration. + In general there is no single solution, we try to find large subgraphs eagerly + + Any two consecutive Elemwise nodes that have the same broadcasting pattern, + and a C-implementation (historical accident that should be revisited), are potentially fuseable. + + However, we need to be careful about keeping the fused subgraph "convex", meaning that no two + nodes in the same subgraph are connected via a path that goes outside the subgraph, either because they + are connected via unfuseable nodes, or nodes that have been claimed by another subgraph. + + For example the graph add(sin(exp(x)), sum(exp(x)) cannot be fused into a single Elemwise, because the sum + node breaks the convexity of the subgraph {exp, sin, add}. However, we can fuse {exp, sin}, + and perhaps fuse add with somethnig else. """ - FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - - def initialize_fuseable_mappings( - *, fg: FunctionGraph - ) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: - @cache - def elemwise_scalar_op_has_c_code(node: Apply) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): - return True - else: - if config.optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." - ) - return False - - # Fuseable nodes have to be accessed in a deterministic manner - # to ensure the rewrite remains deterministic. - # This is not a problem from unfuseable ones, as they can never - # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(set) - unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) - for out, clients in fg.clients.items(): - out_maybe_fuseable = ( - out.owner is not None - and isinstance(out.owner.op, Elemwise) - # and not isinstance(out.owner.op.scalar_op, ps.Composite) - and len(out.owner.outputs) == 1 - and elemwise_scalar_op_has_c_code(out.owner) + + @cache + def elemwise_scalar_op_has_c_code( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." ) - if out_maybe_fuseable: - out_bcast = ( - out.type.broadcastable if out_maybe_fuseable else None - ) - for client, _ in clients: - if ( - isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out_bcast == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - fuseable_clients[out].add(client) - else: - unfuseable_clients[out].add(client) - else: - unfuseable_clients[out] = {client for client, _ in clients} - - return fuseable_clients, unfuseable_clients - - def find_fuseable_subgraph( - *, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - ancestors_bitset: dict[Apply, int], - toposort_index: dict[Apply, int], - ) -> tuple[list[Variable], list[Variable]]: - for starting_node in toposort_index: - if starting_node in visited_nodes: - continue + return False + + fuseable_clients: dict[Apply, set[Apply]] = {} + candidate_nodes = set() + fg_clients = fg.clients + for out, clients_and_indices in fg_clients.items(): + out_node = out.owner + + if not ( + out_node is not None + and len(out_node.outputs) == 1 + and isinstance(out_node.op, Elemwise) + and elemwise_scalar_op_has_c_code(out_node) + ): + continue - starting_out = starting_node.outputs[0] - if not fuseable_clients.get(starting_out): - visited_nodes.add(starting_node) - continue + candidate_nodes.add(out_node) + out_bcast = out.type.broadcastable + out_fuseable_clients = { + client + for client, _ in clients_and_indices + if ( + len(client.outputs) == 1 + and isinstance(client.op, Elemwise) + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ) + } + if out_fuseable_clients: + fuseable_clients[out_node] = out_fuseable_clients + + if not fuseable_clients: + return None + + # Create a bitset of ancestors for each node. + # Each node is represented by a bit flag of it's position in the toposort + # With two variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, + # the ancestors bit flags would be {A: 0b001, B: 0b010, C: 0b100} + # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} + # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND + # For example, to ask if B is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[B] != 0` + # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do + # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` + node_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} + ancestors_bitset = { + None: 0 + } # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None` + for node, node_bit in node_bitflags.items(): + # The bitset of each node is the union of the bitsets of its inputs, plus its own bit + ancestors_bitset[node] = reduce( + or_, (ancestors_bitset[inp.owner] for inp in node.inputs), node_bit + ) + # handle root and leaf nodes gracefully + node_bitflags[None] = ( + 0 # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None` + ) + out_bitflag = 1 << len( + node_bitflags + ) # Nothing ever depends on output nodes, so just use a new bit for all + for out in fg.outputs: + for client, _ in fg_clients[out]: + if isinstance(client.op, Output): + node_bitflags[client] = out_bitflag + + subgraphs: list[tuple[int, tuple[list[Variable], list[Variable]]]] = [] + all_subgraphs_bitset = 0 + # Start exploring from candidate sink nodes (backwards) + # These are Elemwise nodes with a C-implementation, that are not part of another subgraph + # And have no other fuseable clients (i.e., are sinks) + for starting_node, starting_bitflag in reversed(node_bitflags.items()): + if ( + starting_bitflag & all_subgraphs_bitset + or starting_node not in candidate_nodes + ): + continue - subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set - subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set - subgraph_inputs_ancestors_bitset = 0 - unfuseable_clients_subgraph_bitset = 0 - - # If we need to manipulate the maps in place, we'll do a shallow copy later - # For now we query on the original ones - fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients) - unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients) - - # We now try to expand as much as possible towards the potentially - # fuseable clients and ancestors to detect the largest possible - # subgraph that can be Composed together into a single `Op`. The - # largest issue to watch out is for cyclical dependencies, where - # some inputs or clients may depend on other nodes of the same - # subgraph via a path that cannot be included in the Composite - # (unfuseable) - fuseable_nodes_to_visit = deque([starting_node]) - while fuseable_nodes_to_visit: - next_node = fuseable_nodes_to_visit.popleft() - visited_nodes.add(next_node) - next_out = next_node.outputs[0] - - # If the output variable of next_node has no fuseable clients - # or has unfuseable clients, then next_node must become an output - # if it is to be fused. - must_become_output = not fuseable_clients_clone.get( - next_out - ) or unfuseable_clients_clone.get(next_out) - - # We have backtracked to this node, and it may no longer be a viable output, - # so we remove it and check again as if we had never seen this node - if must_become_output: - subgraph_outputs.pop(next_out, None) - - # We need to check that any inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - must_backtrack = ( - ancestors_bitset[next_node] - & unfuseable_clients_subgraph_bitset - ) - - if not must_backtrack: - implied_unfuseable_clients_bitset = reduce( - or_, - ( - 1 << toposort_index[client] - for client in unfuseable_clients_clone.get(next_out) - if not isinstance(client.op, Output) - ), - 0, - ) + if starting_node in fuseable_clients: + # Not a sink, + continue - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - must_backtrack = ( - subgraph_inputs_ancestors_bitset - & implied_unfuseable_clients_bitset - ) + # We keep an ordered queue for expanding the subgraph + # We always want to visit ancestors before clients + # For ancestors, we want to visit the later nodes first (those that have more dependencies) + # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) + # We negate the bitflag for ancestors to achieve this ordering + fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + heapify(fuseables_nodes_queue) + + # We keep 3 bitsets during the exploration: + # - the nodes that are part of the subgraph + # - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with any node in the subgraph) + # - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with any node in the subgraph) + # Whenever we visit a node, we check if unfuseable ancestors depend on it, or if it depends on an unfuseable client, + # in which case we can't fuse it. If we can fuse it, we then add its unfuseable ancestors/clients to the respective bitsets + # and add its fuseable ancestors/clients to the queue to explore later. This approach requires a visit in the order described above. + # Otherwise, we need to recompute target bitsets in every iteration and/or backtrack. + subgraph_nodes = [] + subgraph_bitset = 0 + unfuseable_ancestors_bitset = 0 + unfuseable_clients_bitset = 0 + + # print(f"\nStarting new subgraph exploration from {starting_node}") + while fuseables_nodes_queue: + node_bitflag, node = heappop(fuseables_nodes_queue) + is_ancestor = node_bitflag < 0 + if is_ancestor: + node_bitflag = -node_bitflag + # print(f"\t > Visiting {'ancestor' if is_ancestor else 'client'} {next_node}") + + if node_bitflag & subgraph_bitset: + # Already part of the subgraph + # print("\t - already in subgraph") + continue - if must_backtrack: - for inp in next_node.inputs: - if inp.owner in visited_nodes: - if next_node not in fuseable_clients_clone[inp]: - # This can happen when next node has repeated inputs - continue - fuseable_clients_clone.remove_from_key( - inp, next_node - ) - unfuseable_clients_clone.add_to_key(inp, next_node) - - # This input must become an output of the subgraph, - # because it can't be merged with next_node. - # We will revisit it to make sure this is safe. - fuseable_nodes_to_visit.appendleft(inp.owner) - - # need to convert to tuple not to change set size during iteration - for client in tuple(fuseable_clients_clone[next_out]): - if client in visited_nodes: - fuseable_clients_clone.remove_from_key( - next_out, client - ) - unfuseable_clients_clone.add_to_key( - next_out, client - ) - - # next_out must become an input of the subgraph. - # We will revisit any of its clients currently - # in the subgraph to make sure this is safe. - fuseable_nodes_to_visit.appendleft(client) - - # Revisit node at a later time - visited_nodes.remove(next_node) + if is_ancestor: + if node_bitflag & unfuseable_ancestors_bitset: + # An unfuseable ancestor depends on this node, can't fuse + # print("\t failed - unfuseable ancestor depends on it") continue + elif ancestors_bitset[node] & unfuseable_clients_bitset: + # This node depends on an unfuseable client, can't fuse + # print("\t failed - depends on unfuseable client") + continue - # Adding next_node to subgraph does not result in any - # immediate dependency problems. Update subgraph - # mappings as if it next_node was part of it. - # Useless inputs will be removed by the useless Composite rewrite - if must_become_output: - subgraph_outputs[next_out] = None - unfuseable_clients_subgraph_bitset |= ( - implied_unfuseable_clients_bitset + # print("\t succeeded - adding to subgraph") + subgraph_nodes.append(node) + subgraph_bitset |= node_bitflag + + # Expand through ancestors and client nodes + # A node can either be: + # - already part of the subgraph (skip) + # - fuseable (add to queue) + # - unfuseable (add to respective unfuseable bitset) + for ancestor in node.inputs: + ancestor_node = ancestor.owner + ancestor_bitflag = node_bitflags[ancestor_node] + if ancestor_bitflag & subgraph_bitset: + continue + if node in fuseable_clients.get(ancestor_node, ()): + heappush( + fuseables_nodes_queue, + (-ancestor_bitflag, ancestor_node), ) + else: + # If an ancestor is unfuseable, so are all its ancestors + unfuseable_ancestors_bitset |= ancestors_bitset[ + ancestor_node + ] + + next_fuseable_clients = fuseable_clients.get(node, ()) + for client, _ in fg_clients[node.outputs[0]]: + client_bitflag = node_bitflags[client] + if client_bitflag & subgraph_bitset: + continue + if client in next_fuseable_clients: + heappush(fuseables_nodes_queue, (client_bitflag, client)) + else: + # If a client is unfuseable, so are all its clients, but we don't need to keep track of those + # Any downstream client will also depend on this unfuseable client and will be rejected when visited + unfuseable_clients_bitset |= client_bitflag - for inp in sorted( - next_node.inputs, - key=lambda x: toposort_index.get(x.owner, -1), - ): - if next_node in unfuseable_clients_clone.get(inp, ()): - # input must become an input of the subgraph since it's unfuseable with new node - subgraph_inputs_ancestors_bitset |= ( - ancestors_bitset.get(inp.owner, 0) - ) - subgraph_inputs[inp] = None - elif inp.owner not in visited_nodes: - fuseable_nodes_to_visit.appendleft(inp.owner) - - # Expand through unvisited fuseable clients - fuseable_nodes_to_visit.extend( - sorted( - ( - node - for node in fuseable_clients_clone.get(next_out) - if node not in visited_nodes - ), - key=toposort_index.get, # type: ignore[arg-type] - ) - ) - - # Don't return if final subgraph is just the original Elemwise - if len(subgraph_outputs) == 1 and set( - next(iter(subgraph_outputs)).owner.inputs - ) == set(subgraph_inputs): - # Update global fuseable mappings - # No input was actually fuseable - for inp in starting_node.inputs: - fuseable_clients[inp].discard(starting_node) - unfuseable_clients[inp].add(starting_node) - # No client was actually fuseable - unfuseable_clients[starting_out].update( - fuseable_clients.pop(starting_out, ()) - ) - continue + # Finished exploring this subgraph + all_subgraphs_bitset |= subgraph_bitset - return list(subgraph_inputs), list(subgraph_outputs) - raise ValueError - - def update_fuseable_mappings_after_fg_replace( - *, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - toposort_index: dict[Apply, int], - ancestors_bitset: dict[Apply, int], - starting_nodes: set[Apply], - updated_nodes: set[Apply], - ) -> None: - # Find new composite node and dropped intermediate nodes - # by comparing the current fg.apply nodes with the cached - # original nodes - (new_composite_node,) = updated_nodes - starting_nodes - dropped_nodes = starting_nodes - updated_nodes - - # Remove intermediate Composite nodes from mappings - # And compute the ancestors bitset of the new composite node - # As well as the new toposort index for the new node - new_node_ancestor_bitset = 0 - new_node_toposort_index = len(toposort_index) - for dropped_node in dropped_nodes: - (dropped_out,) = dropped_node.outputs - fuseable_clients.pop(dropped_out, None) - unfuseable_clients.pop(dropped_out, None) - visited_nodes.remove(dropped_node) - # The new composite ancestor bitset is the union - # of the ancestors of all the dropped nodes - new_node_ancestor_bitset |= ancestors_bitset[dropped_node] - # The new composite node can have the same order as the latest node that was absorbed into it - new_node_toposort_index = max( - new_node_toposort_index, toposort_index[dropped_node] - ) + if subgraph_bitset == starting_bitflag: + # No fusion possible, single node subgraph + continue - ancestors_bitset[new_composite_node] = new_node_ancestor_bitset - toposort_index[new_composite_node] = new_node_toposort_index - - # Update fuseable information for subgraph inputs - for inp in subgraph_inputs: - if inp in fuseable_clients: - new_fuseable_clients = { - client - for client in fuseable_clients[inp] - if client not in dropped_nodes - } - if new_fuseable_clients: - fuseable_clients[inp] = new_fuseable_clients - else: - fuseable_clients.pop(inp) - unfuseable_clients[inp] = ( - unfuseable_clients[inp] - dropped_nodes - ) | {new_composite_node} - - # Update fuseable information for subgraph outputs - for out in new_composite_node.outputs: - unfuseable_clients[out] = {client for client, _ in fg.clients[out]} - - visited_nodes.add(new_composite_node) - return - - # We start by creating two maps, 1) from each node to each potentially - # fuseable client (both nodes must be single output Elemwise with same - # broadcast type) and 2) from each node to each certainly unfuseable - # client (those that don't fit into 1)) - fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) - visited_nodes: set[Apply] = set() - toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} - # Create a bitset for each node of all its ancestors - # This allows to quickly check if a variable depends on a set - ancestors_bitset: dict[Apply, int] = {} - for node, index in toposort_index.items(): - node_ancestor_bitset = 1 << index - for inp in node.inputs: - if (inp_node := inp.owner) is not None: - node_ancestor_bitset |= ancestors_bitset[inp_node] - ancestors_bitset[node] = node_ancestor_bitset - - while True: - try: - subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - ancestors_bitset=ancestors_bitset, - toposort_index=toposort_index, + # Find out inputs/outputs of subgraph_nodes + not_subgraph_bitset = ~subgraph_bitset + subgraph_inputs = list( + { + inp: None + for node in subgraph_nodes + for inp in node.inputs + if (ancestor_node := inp.owner) is None + or node_bitflags[ancestor_node] & not_subgraph_bitset + } + ) + + subgraph_outputs = [ + node.outputs[0] + for node in subgraph_nodes + if any( + node_bitflags[client] & not_subgraph_bitset + for client, _ in fg_clients[node.outputs[0]] ) - except ValueError: - return - else: - # The caller is now expected to update fg in place, - # by replacing the subgraph with a Composite Op - starting_nodes = fg.apply_nodes.copy() - - yield subgraph_inputs, subgraph_outputs - - # This is where we avoid repeated work by using a stateful - # generator. For large models (as in `TestFusion.test_big_fusion`) - # this can provide huge speedups - update_fuseable_mappings_after_fg_replace( - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - toposort_index=toposort_index, - ancestors_bitset=ancestors_bitset, - starting_nodes=starting_nodes, - updated_nodes=fg.apply_nodes, + ] + + # We use the min toposort_index for sorting the subgraphs later + min_toposort_index = min( + node_bitflags[var.owner] for var in subgraph_outputs + ) + + # print(f"Found subgraph with {len(subgraph_inputs)} inputs, {len(subgraph_outputs)} outputs, and {len(subgraph_nodes)} nodes ({min_toposort_index=})") + # FunctionGraph(list(subgraph_inputs), subgraph_outputs).dprint() + + subgraphs.append( + ( + min_toposort_index, + (subgraph_inputs, subgraph_outputs), ) + ) + + # Update fuseable clients, inputs can no longer be fused with graph variables + # and outputs can't be fused with anything else + for ancestor in subgraph_inputs: + if (ancestor_node := ancestor.owner) is not None: + if ancestor_fuseable_clients := fuseable_clients.get( + ancestor_node + ): + ancestor_fuseable_clients.difference_update(subgraph_nodes) + if not ancestor_fuseable_clients: + del fuseable_clients[ancestor_node] + + for out in subgraph_outputs: + fuseable_clients.pop(out.owner, None) + + # We need to replace in reverse topological order + yield from (io for _, io in sorted(subgraphs, reverse=True)) nb_fused = 0 nb_replacement = 0 - for inputs, outputs in find_next_fuseable_subgraph(fgraph): + for inputs, outputs in find_fuseable_subgraphs(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( "Loop fusion failed because the resulting node would exceed " "the kernel argument limit." ) - break + continue scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) composite_outputs = Elemwise( @@ -957,7 +831,7 @@ def update_fuseable_mappings_after_fg_replace( starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( - list(zip(outputs, composite_outputs, strict=True)), + tuple(zip(outputs, composite_outputs)), reason=self.__class__.__name__, ) nb_fused += 1 diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 3c549788e1..7e625043ec 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1371,7 +1371,7 @@ def test_eval_benchmark(self, benchmark): [ # ("diamond_graph", None, (1, 4)), ("deep_small_kernels", 20, (20, 60)), - ("large_fuseable_graph", 25, (103, 876)), + ("large_fuseable_graph", 25, (128, 876)), ], ) def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): From 6a99dcaca2c26c4626a255c74649fd276a08ba6f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 3 Sep 2025 13:14:32 +0200 Subject: [PATCH 11/33] Benchmark function compilation --- tests/compile/function/test_types.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 90589db337..6f122767bb 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -1357,3 +1357,17 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark): rng_val = np.random.default_rng() benchmark(f, rng_val) + + +@pytest.mark.parametrize("mode", ["FAST_COMPILE", "FAST_RUN"]) +@pytest.mark.parametrize("depth", [2, 20]) +def test_function_compilation_benchmark(mode, depth, benchmark): + def compile_function(mode=mode, depth=depth): + x = pt.matrix("x") + out = x + for _ in range(depth): + out = pt.sin(out.T) + pt.cos(out) + fn = function([x], out, mode=mode) + return fn + + benchmark.pedantic(compile_function, iterations=20, rounds=5) From 538a5e7b2f15c566081081afd01bd546201fbaf8 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 3 Sep 2025 13:14:19 +0200 Subject: [PATCH 12/33] Avoid FunctionGraph overhead when compiling single Ops to C --- pytensor/link/c/op.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 8ccfa2a9a3..17d183e52e 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -10,6 +10,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import ComputeMapType, Op, StorageMapType, ThunkType from pytensor.graph.type import HasDataType from pytensor.graph.utils import MethodNotDefined @@ -32,6 +33,30 @@ def is_cthunk_wrapper_type(thunk: Callable[[], None]) -> CThunkWrapperType: return res +class SingleOpFunctionGraph(FunctionGraph): + """A `FunctionGraph` with a single `Apply` node. + + This is used to compile a single `Apply` node with the C linker. + + """ + + def __init__(self, node: Apply, clone: bool = True): + if clone: + node = node.clone_with_new_inputs([i.clone() for i in node.inputs]) + self.node = node + self.apply_nodes = {node} + self.inputs = inputs = node.inputs + self.outputs = outputs = node.outputs + self.variables = set(inputs) | set(outputs) + self.clients = {inp: [(node, idx)] for idx, inp in enumerate(inputs)} + self.clients |= { + out: [(Output(idx).make_node(out), 0)] for idx, out in enumerate(outputs) + } + + def toposort(self): + return [self.node] + + class COp(Op, CLinkerOp): """An `Op` with a C implementation.""" @@ -51,12 +76,11 @@ def make_c_thunk( # The conclusion should be that the antire "make_c_thunk" method should be defined # in pytensor.link.c and dispatched onto the Op! import pytensor.link.c.basic - from pytensor.graph.fg import FunctionGraph node_input_storage = [storage_map[r] for r in node.inputs] node_output_storage = [storage_map[r] for r in node.outputs] - e = FunctionGraph(node.inputs, node.outputs) + e = SingleOpFunctionGraph(node) e_no_recycling = [ new_o for (new_o, old_o) in zip(e.outputs, node.outputs, strict=True) From cd213d17a2b35144b623acf73b7a6cf2db05dac4 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 3 Sep 2025 13:46:55 +0200 Subject: [PATCH 13/33] Use single tracks in WalkingGraphRewriter --- pytensor/graph/rewriting/basic.py | 76 ++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 6 deletions(-) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 750250ea0d..c373bf42a9 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -2008,25 +2008,89 @@ def __init__( if order not in valid_orders: raise ValueError(f"order must be one of {valid_orders}, got {order}") self.order = order + # Use tracks functionality to pre-filter nodes, if it's a single Op or Op type + tracks = node_rewriter.tracks() + self.tracks = tracks[0] if (tracks is not None and len(tracks) == 1) else None super().__init__(node_rewriter, ignore_newtrees, failure_callback) def apply(self, fgraph, start_from=None): if start_from is None: start_from = fgraph.outputs callback_before = fgraph.execute_callbacks_time - nb_nodes_start = len(fgraph.apply_nodes) + apply_nodes = fgraph.apply_nodes + nb_nodes_start = len(apply_nodes) t0 = time.perf_counter() - q = deque( + if (tracks := self.tracks) is not None: + # Pre-filter nodes to consider based on tracks + if isinstance(tracks, Op): + # Equality + candidate_nodes = { + node for node in fgraph.apply_nodes if node.op == tracks + } + elif isinstance(tracks, OpPattern): + candidate_nodes = { + node + for node in fgraph.apply_nodes + if tracks.match_op(node.op) is not False + } + else: + # isinstance + candidate_nodes = { + node for node in fgraph.apply_nodes if isinstance(node.op, tracks) + } + + if not candidate_nodes: + # Abort early + return ( + self, + 0, # nodes changed + nb_nodes_start, + nb_nodes_start, # nb_nodes_end + time.perf_counter() - t0, # io_t + 0, # loop_t + 0, # callback_time + self.node_rewriter, + ) + + if isinstance(tracks, Op): + + def importer(node): + if node is not current_node and node.op == tracks: + q.append(node) + + elif isinstance(tracks, OpPattern): + + def importer(node): + if ( + node is not current_node + and tracks.match_op(node.op) is not False + ): + q.append(node) + + else: + + def importer(node): + if node is not current_node and isinstance(node.op, tracks): + q.append(node) + else: + # Otherwise, we will call the node_rewriter on every node in the graph + candidate_nodes = None + + def importer(node): + if node is not current_node: + q.append(node) + + node_iterator = ( apply_ancestors(start_from) if (self.order == "dfs") else toposort(start_from) ) + if candidate_nodes: + q = deque(node for node in node_iterator if node in candidate_nodes) + else: + q = deque(node_iterator) io_t = time.perf_counter() - t0 - def importer(node): - if node is not current_node: - q.append(node) - u = self.attach_updater( fgraph, importer, None, name=getattr(self, "name", None) ) From 1db1f92ff4142d5c498938f0773be3f6d5e62667 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 3 Sep 2025 14:23:02 +0200 Subject: [PATCH 14/33] Exit from DestroyHandler orderings faster --- pytensor/graph/destroyhandler.py | 190 +++++++++++++++---------------- 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/pytensor/graph/destroyhandler.py b/pytensor/graph/destroyhandler.py index 1fe59f2c6d..ad7768c499 100644 --- a/pytensor/graph/destroyhandler.py +++ b/pytensor/graph/destroyhandler.py @@ -699,106 +699,106 @@ def orderings(self, fgraph, ordered=True): c) an Apply destroys (illegally) one of its own inputs by aliasing """ + if not self.destroyers: + return {} + set_type = OrderedSet if ordered else set rval = {} - if self.destroyers: - # BUILD DATA STRUCTURES - # CHECK for multiple destructions during construction of variables - - droot, impact, __ignore = self.refresh_droot_impact() - - # check for destruction of constants - illegal_destroy = [ - r - for r in droot - if getattr(r.tag, "indestructible", False) or isinstance(r, Constant) - ] - if illegal_destroy: - raise InconsistencyError( - f"Attempting to destroy indestructible variables: {illegal_destroy}" - ) + # BUILD DATA STRUCTURES + # CHECK for multiple destructions during construction of variables + + droot, impact, __ignore = self.refresh_droot_impact() - # add destroyed variable clients as computational dependencies - for app in self.destroyers: - # keep track of clients that should run before the current Apply - root_clients = set_type() - # for each destroyed input... - for input_idx_list in app.op.destroy_map.values(): - destroyed_idx = input_idx_list[0] - destroyed_variable = app.inputs[destroyed_idx] - root = droot[destroyed_variable] - root_impact = impact[root] - # we generally want to put all clients of things which depend on root - # as pre-requisites of app. - # But, app is itself one such client! - # App will always be a client of the node we're destroying - # (destroyed_variable, but the tricky thing is when it is also a client of - # *another variable* viewing on the root. Generally this is illegal, (e.g., - # add_inplace(x, x.T). In some special cases though, the in-place op will - # actually be able to work properly with multiple destroyed inputs (e.g, - # add_inplace(x, x). An Op that can still work in this case should declare - # so via the 'destroyhandler_tolerate_same' attribute or - # 'destroyhandler_tolerate_aliased' attribute. - # - # destroyhandler_tolerate_same should be a list of pairs of the form - # [(idx0, idx1), (idx0, idx2), ...] - # The first element of each pair is the input index of a destroyed - # variable. - # The second element of each pair is the index of a different input where - # we will permit exactly the same variable to appear. - # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed - # input is also allowed to appear as the second argument. - # - # destroyhandler_tolerate_aliased is the same sort of list of - # pairs. - # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the - # destroyhandler to IGNORE an aliasing between a destroyed - # input idx0 and another input idx1. - # This is generally a bad idea, but it is safe in some - # cases, such as - # - the op reads from the aliased idx1 before modifying idx0 - # - the idx0 and idx1 are guaranteed not to overlap (e.g. - # they are pointed at different rows of a matrix). - # - - # CHECK FOR INPUT ALIASING - # OPT: pre-compute this on import - tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) - assert isinstance(tolerate_same, list) - tolerated = { - idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx - } - tolerated.add(destroyed_idx) - tolerate_aliased = getattr( - app.op, "destroyhandler_tolerate_aliased", [] - ) - assert isinstance(tolerate_aliased, list) - ignored = { - idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx - } - for i, input in enumerate(app.inputs): - if i in ignored: - continue - if input in root_impact and ( - i not in tolerated or input is not destroyed_variable - ): - raise InconsistencyError( - f"Input aliasing: {app} ({destroyed_idx}, {i})" - ) - - # add the rule: app must be preceded by all other Apply instances that - # depend on destroyed_input - for r in root_impact: - assert not [a for a, c in self.clients[r].items() if not c] - root_clients.update( - [a for a, c in self.clients[r].items() if c] + # check for destruction of constants + illegal_destroy = [ + r + for r in droot + if getattr(r.tag, "indestructible", False) or isinstance(r, Constant) + ] + if illegal_destroy: + raise InconsistencyError( + f"Attempting to destroy indestructible variables: {illegal_destroy}" + ) + + # add destroyed variable clients as computational dependencies + for app in self.destroyers: + # keep track of clients that should run before the current Apply + root_clients = set_type() + # for each destroyed input... + for input_idx_list in app.op.destroy_map.values(): + destroyed_idx = input_idx_list[0] + destroyed_variable = app.inputs[destroyed_idx] + root = droot[destroyed_variable] + root_impact = impact[root] + # we generally want to put all clients of things which depend on root + # as pre-requisites of app. + # But, app is itself one such client! + # App will always be a client of the node we're destroying + # (destroyed_variable, but the tricky thing is when it is also a client of + # *another variable* viewing on the root. Generally this is illegal, (e.g., + # add_inplace(x, x.T). In some special cases though, the in-place op will + # actually be able to work properly with multiple destroyed inputs (e.g, + # add_inplace(x, x). An Op that can still work in this case should declare + # so via the 'destroyhandler_tolerate_same' attribute or + # 'destroyhandler_tolerate_aliased' attribute. + # + # destroyhandler_tolerate_same should be a list of pairs of the form + # [(idx0, idx1), (idx0, idx2), ...] + # The first element of each pair is the input index of a destroyed + # variable. + # The second element of each pair is the index of a different input where + # we will permit exactly the same variable to appear. + # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed + # input is also allowed to appear as the second argument. + # + # destroyhandler_tolerate_aliased is the same sort of list of + # pairs. + # op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the + # destroyhandler to IGNORE an aliasing between a destroyed + # input idx0 and another input idx1. + # This is generally a bad idea, but it is safe in some + # cases, such as + # - the op reads from the aliased idx1 before modifying idx0 + # - the idx0 and idx1 are guaranteed not to overlap (e.g. + # they are pointed at different rows of a matrix). + # + + # CHECK FOR INPUT ALIASING + # OPT: pre-compute this on import + tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) + assert isinstance(tolerate_same, list) + tolerated = { + idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx + } + tolerated.add(destroyed_idx) + tolerate_aliased = getattr( + app.op, "destroyhandler_tolerate_aliased", [] + ) + assert isinstance(tolerate_aliased, list) + ignored = { + idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx + } + for i, input in enumerate(app.inputs): + if i in ignored: + continue + if input in root_impact and ( + i not in tolerated or input is not destroyed_variable + ): + raise InconsistencyError( + f"Input aliasing: {app} ({destroyed_idx}, {i})" ) - # app itself is a client of the destroyed inputs, - # but should not run before itself - root_clients.remove(app) - if root_clients: - rval[app] = root_clients + # add the rule: app must be preceded by all other Apply instances that + # depend on destroyed_input + for r in root_impact: + assert not [a for a, c in self.clients[r].items() if not c] + root_clients.update([a for a, c in self.clients[r].items() if c]) + + # app itself is a client of the destroyed inputs, + # but should not run before itself + root_clients.remove(app) + if root_clients: + rval[app] = root_clients return rval From 54c3f02427b28174e4e7483dd753632735b4f6a0 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 09:02:33 +0200 Subject: [PATCH 15/33] Make DimShuffle a regular COp This is much faster to create/compile and produces better code for the most ubiquitous Op in PyTensor --- pytensor/link/c/op.py | 4 + pytensor/tensor/c_code/dimshuffle.c | 86 ------------- pytensor/tensor/elemwise.py | 190 ++++++++++++++++++---------- 3 files changed, 125 insertions(+), 155 deletions(-) delete mode 100644 pytensor/tensor/c_code/dimshuffle.c diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index 17d183e52e..e1e61e4874 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -331,6 +331,10 @@ def __init__( files overriding sections in previous files. """ + warnings.warn( + "ExternalCOp is deprecated and will be removed in a future release. Use regular COp instead.", + FutureWarning, + ) if not isinstance(func_files, list): self.func_files = [Path(func_files)] else: diff --git a/pytensor/tensor/c_code/dimshuffle.c b/pytensor/tensor/c_code/dimshuffle.c deleted file mode 100644 index 0bfc5df3bb..0000000000 --- a/pytensor/tensor/c_code/dimshuffle.c +++ /dev/null @@ -1,86 +0,0 @@ -#section support_code_apply - -int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PARAMS_TYPE *params) { - npy_int64* new_order; - npy_intp nd_in; - npy_intp nd_out; - npy_intp* dimensions; - npy_intp* strides; - - if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) { - PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous."); - return 1; - } - new_order = (npy_int64*) PyArray_DATA(params->_new_order); - nd_in = (npy_intp)(params->input_ndim); - nd_out = PyArray_SIZE(params->_new_order); - - if (PyArray_NDIM(input) != nd_in) { - PyErr_SetString(PyExc_ValueError, "DimShuffle: Input has less dimensions than expected."); - return 1; - } - - // Compute new dimensions and strides - dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); - strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp)); - if (dimensions == NULL || strides == NULL) { - PyErr_NoMemory(); - free(dimensions); - free(strides); - return 1; - }; - - npy_intp original_size = PyArray_SIZE(input); - npy_intp new_size = 1; - for (npy_intp i = 0; i < nd_out; ++i) { - // We set the strides of length 1 dimensions to PyArray_ITEMSIZE(input). - // The value is arbitrary, because there is never a next element. - // np.expand_dims(x, 0) and x[None] do different things here. - // I would prefer zero, but there are some poorly implemented BLAS operations - // That don't handle zero strides correctly. At least they won't fail because of DimShuffle. - if (new_order[i] != -1) { - dimensions[i] = PyArray_DIMS(input)[new_order[i]]; - strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? PyArray_ITEMSIZE(input) : PyArray_STRIDES(input)[new_order[i]]; - } else { - dimensions[i] = 1; - strides[i] = PyArray_ITEMSIZE(input); - } - new_size *= dimensions[i]; - } - - if (original_size != new_size) { - PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one."); - free(dimensions); - free(strides); - return 1; - } - - if (*res) - Py_XDECREF(*res); - - // Create the new array. - *res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions, - PyArray_TYPE(input), strides, - PyArray_DATA(input), PyArray_ITEMSIZE(input), - // borrow only the writable flag from the base - // the NPY_OWNDATA flag will default to 0. - (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)), - NULL); - - if (*res == NULL) { - free(dimensions); - free(strides); - return 1; - } - - // Declare it a view of the original input - Py_INCREF((PyObject*)input); - PyArray_SetBaseObject(*res, (PyObject*)input); - - // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED - PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL); - - free(strides); - free(dimensions); - return 0; -} \ No newline at end of file diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 4dd29dc37d..712b016a8a 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -13,14 +13,13 @@ from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.basic import failure_code -from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp -from pytensor.link.c.params_type import ParamsType +from pytensor.link.c.op import COp, OpenMPOp from pytensor.misc.frozendict import frozendict from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import identity as scalar_identity -from pytensor.scalar.basic import int64, transfer_type, upcast +from pytensor.scalar.basic import transfer_type, upcast from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable @@ -29,7 +28,6 @@ continuous_dtypes, discrete_dtypes, float_dtypes, - lvector, ) from pytensor.tensor.utils import ( broadcast_static_dim_lengths, @@ -40,7 +38,7 @@ from pytensor.utils import uniq -class DimShuffle(ExternalCOp): +class DimShuffle(COp): """ Allows to reorder the dimensions of a tensor or insert or remove broadcastable dimensions. @@ -114,20 +112,9 @@ class DimShuffle(ExternalCOp): _f16_ok = True check_input = False __props__ = ("input_ndim", "new_order") - c_func_file = "c_code/dimshuffle.c" - c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)" view_map = {0: [0]} - @property - def params_type(self): - return ParamsType( - _new_order=lvector, - input_ndim=int64, - ) - def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): - super().__init__([self.c_func_file], self.c_func_name) - if not isinstance(input_ndim, int): raise TypeError(f"input_ndim must be an integer, got {type(int)}") @@ -135,53 +122,44 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.new_order = tuple(new_order) self._new_order = [(-1 if x == "x" else x) for x in self.new_order] - for i, j in enumerate(new_order): - if j != "x": - if not isinstance(j, int | np.integer): - raise TypeError( - "DimShuffle indices must be Python ints; got " - f"{j} of type {type(j)}." - ) - if j >= input_ndim: - raise ValueError( - f"new_order[{i}] is {j}, but the input only has " - f"{input_ndim} axes." - ) - if j in new_order[(i + 1) :]: - raise ValueError( - "The same input dimension may not appear " - f"twice in the list of output dimensions: {new_order}" - ) - # List of input dimensions to drop - drop = [i for i in range(input_ndim) if i not in new_order] + self.drop = drop = [i for i in range(input_ndim) if i not in new_order] # This is the list of the original dimensions that we keep - self.shuffle = [x for x in new_order if x != "x"] - self.transposition = self.shuffle + drop - # List of dimensions of the output that are broadcastable and were not - # in the original input - self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") - self.drop = drop + self.shuffle = shuffle = [x for x in new_order if x != "x"] + + # Input validation + if not all(isinstance(x, int | np.integer) for x in shuffle): + raise TypeError( + "DimShuffle indices must be Python ints; got " + f"{shuffle} of type {[type(x) for x in shuffle]}." + ) + if len(shuffle) != len(set(shuffle)): + raise ValueError( + f"Some dimensions were duplicated in new_order: {new_order}" + ) + if max(shuffle, default=0) > input_ndim: + raise ValueError( + f"Some dimensions in new_order are too large for input_ndim {input_ndim}: {new_order}" + ) - dims_are_shuffled = sorted(self.shuffle) != self.shuffle + self.transposition = self.shuffle + drop + # List of expand_dims positions + self.augment = augment = [i for i, x in enumerate(new_order) if x == "x"] + # Properties that are useful for rewrites + self.dims_are_shuffled = dims_are_shuffled = sorted(shuffle) != shuffle self.is_transpose = dims_are_shuffled and not augment and not drop self.is_squeeze = drop and not dims_are_shuffled and not augment - self.is_expand_dims = augment and not dims_are_shuffled and not drop - self.is_left_expand_dims = self.is_expand_dims and ( + self.is_expand_dims = is_expand_dims = ( + augment and not dims_are_shuffled and not drop + ) + self.is_left_expand_dims = is_expand_dims and ( input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) ) - self.is_right_expand_dims = self.is_expand_dims and new_order[ - :input_ndim - ] == list(range(input_ndim)) - - def __setstate__(self, state): - self.__dict__.update(state) - if not hasattr(self, "func_files"): - # Perhaps we are loading an old `Op` version of DimShuffle. - # Let's just build the ExternalCOp. - super().__init__([self.c_func_file], self.c_func_name) + self.is_right_expand_dims = is_expand_dims and new_order[:input_ndim] == list( + range(input_ndim) + ) def make_node(self, inp): input = as_tensor_variable(inp) @@ -193,22 +171,18 @@ def make_node(self, inp): input_static_shape = input.type.shape - # Runtime check for invalid drop - for d in self.drop: - if input_static_shape[d] not in (1, None): - raise TypeError( - f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}" - ) - - out_static_shape = [] - for dim_idx in self.new_order: - if dim_idx == "x": - out_static_shape.append(1) - else: - out_static_shape.append(input_static_shape[dim_idx]) - - output = TensorType(dtype=input.type.dtype, shape=out_static_shape)() + # Check for invalid drop + if self.drop: + for d in self.drop: + if input_static_shape[d] not in (1, None): + raise TypeError( + f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}" + ) + output = TensorType( + dtype=input.type.dtype, + shape=[1 if d == "x" else input_static_shape[d] for d in self.new_order], + )() return Apply(self, [input], [output]) def __str__(self): @@ -273,6 +247,84 @@ def grad(self, inp, grads): else: return [gz.dimshuffle(grad_order)] + def c_code(self, node, name, input_names, output_names, sub): + [inp] = input_names + [out] = output_names + nd_in = node.inputs[0].ndim + nd_out = node.outputs[0].ndim + drop = self.drop + fail = sub["fail"] + + code = f"npy_intp dimensions[{nd_out}];\n" + code += f"npy_intp strides[{nd_out}];\n" + if drop: + code += "npy_intp new_size = 1;\n" + + code += dedent( + f""" + if (PyArray_NDIM({inp}) != {nd_in}) {{ + PyErr_SetString(PyExc_ValueError, "ExpandDims: Input dimensions do not match expected."); + {fail} + }} + """ + ) + + for i, o in enumerate(self.new_order): + if o == "x": + code += f"dimensions[{i}] = 1;\n" + code += f"strides[{i}] = PyArray_ITEMSIZE({inp});\n" + else: + code += f"dimensions[{i}] = PyArray_DIMS({inp})[{o}];\n" + code += f"strides[{i}] = PyArray_DIMS({inp})[{o}] == 1 ? PyArray_ITEMSIZE({inp}) : PyArray_STRIDES({inp})[{o}];\n" + if drop: + code += f"new_size *= dimensions[{i}];\n" + + if drop: + code += dedent( + f""" + if (PyArray_SIZE({inp}) != new_size) {{ + PyErr_SetString(PyExc_ValueError, "DimShuffle: Attempting to squeeze axes with size not equal to one."); + {fail} + }} + """ + ) + + code += dedent( + f""" + Py_XDECREF({out}); + + Py_INCREF(PyArray_DESCR({inp})); + {out} = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type, + PyArray_DESCR({inp}), + {nd_out}, dimensions, + strides, + PyArray_DATA({inp}), + (PyArray_FLAGS({inp}) & ~NPY_ARRAY_OWNDATA), + NULL); + + if ({out} == NULL) {{ + {fail} + }} + + // Declare it a view of the original input + Py_INCREF((PyObject*){inp}); + PyArray_SetBaseObject({out}, (PyObject*){inp}); + """ + ) + + if self.dims_are_shuffled: + code += dedent( + f""" + // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED + PyArray_UpdateFlags({out}, NPY_ARRAY_UPDATE_ALL); + """ + ) + + return code + + def c_code_cache_version(self): + return (0,) + class DimShufflePrinter(Printer): def __p(self, new_order, pstate, r): From 5e2ceb76905b4cf944ea78b312e850cd1258dc0f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 12:14:37 +0200 Subject: [PATCH 16/33] Speedup _gemm_canonicalize --- pytensor/tensor/rewriting/blas.py | 61 ++++++++++++++++++------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 03a1e8b0ab..2f2bb6b67f 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -60,6 +60,7 @@ import numpy as np from pytensor.graph.traversal import toposort +from pytensor.scalar import Add, Mul, Neg, Sub from pytensor.tensor.rewriting.basic import register_specialize @@ -100,10 +101,7 @@ from pytensor.tensor.math import ( Dot, _matmul, - add, mul, - neg, - sub, variadic_add, ) from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift @@ -237,22 +235,27 @@ def scaled(thing): rval.append(scaled(r)) return rval - if maxclients and len(fgraph.clients[r]) > maxclients: + if ( + (r.owner is None) + or (not isinstance(r.owner.op, Elemwise)) + or (maxclients and len(fgraph.clients[r]) > maxclients) + ): rval.append((scale, r)) return rval - if r.owner and r.owner.op == sub: + scalar_op = r.owner.op.scalar_op + if isinstance(scalar_op, Sub): _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1) _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1) - elif r.owner and r.owner.op == add: + elif isinstance(scalar_op, Add): for i in r.owner.inputs: _gemm_canonicalize(fgraph, i, scale, rval, 1) - elif r.owner and r.owner.op == neg: + elif isinstance(scalar_op, Neg): _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1) - elif r.owner and r.owner.op == mul: + elif isinstance(scalar_op, Mul): scalars = [] vectors = [] matrices = [] @@ -460,35 +463,45 @@ def apply(self, fgraph): callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time - nodelist = list(toposort(fgraph.outputs)) + relevant_core_ops = ( + pytensor.scalar.Add + | pytensor.scalar.Sub + | pytensor.scalar.Neg + | pytensor.scalar.Mul + ) + nodelist = [ + a + for a in toposort(fgraph.outputs) + if ( + isinstance(a.op, Elemwise) + and isinstance(a.op.scalar_op, relevant_core_ops) + ) + ] + if not nodelist: + return None + nodelist.reverse() def on_import(new_node): - if new_node is not node: + if ( + new_node is not node + and isinstance(new_node.op, Elemwise) + and isinstance(new_node.op.scalar_op, relevant_core_ops) + ): nodelist.append(new_node) u = pytensor.graph.rewriting.basic.DispatchingFeature( on_import, None, None, name="GemmOptimizer" ) fgraph.attach_feature(u) + fgraph_apply_nodes = fgraph.apply_nodes while did_something: nb_iter += 1 t0 = time.perf_counter() time_toposort += time.perf_counter() - t0 did_something = False for node in nodelist: - if not ( - isinstance(node.op, Elemwise) - and isinstance( - node.op.scalar_op, - pytensor.scalar.Add - | pytensor.scalar.Sub - | pytensor.scalar.Neg - | pytensor.scalar.Mul, - ) - ): - continue - if node not in fgraph.apply_nodes: + if node not in fgraph_apply_nodes: # This mean that we already removed this node from # the graph continue @@ -502,7 +515,6 @@ def on_import(new_node): continue if new_outputs: new_outputs, old_dot22 = new_outputs - assert len(new_outputs) == len(node.outputs) new_outputs[ 0 ].tag.values_eq_approx = values_eq_approx_remove_inf_nan @@ -518,8 +530,7 @@ def on_import(new_node): did_something = True nb_replacement += 1 except InconsistencyError: - # TODO: retry other applications of gemm (see comment - # in _gemm_from_node) + # TODO: retry other applications of gemm (see comment in _gemm_from_node) nb_inconsistency_replace += 1 except ReplacementDidNotRemoveError: nb_replacement_didn_t_remove += 1 From 1719169074777aca5f4896925ffcb77ac1fedbe2 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 13:05:41 +0200 Subject: [PATCH 17/33] Simpler Elemwise.infer_shape --- pytensor/tensor/elemwise.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 712b016a8a..efa4211ae2 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -8,7 +8,7 @@ import pytensor.tensor.basic from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, Constant from pytensor.graph.null_type import NullType from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed from pytensor.graph.utils import MethodNotDefined @@ -797,10 +797,24 @@ def _check_runtime_broadcast(node, inputs): ) def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]: - from pytensor.tensor.extra_ops import broadcast_shape + out_shape = list(node.outputs[0].type.shape) + if missing_dims := [i for i, s in enumerate(out_shape) if s is None]: + for inp_idx, inp in enumerate(node.inputs): + inp_st_shape = inp.type.shape + for d in missing_dims: + if inp_st_shape[d] == 1: + continue # Nothing to learn from this input + if inp_st_shape[d] is not None: + out_shape[d] = inp_st_shape[d] + missing_dims.remove(d) + else: + out_shape[d] = new_dim = i_shapes[inp_idx][d] + if isinstance(new_dim, Constant): + missing_dims.remove(d) + if not missing_dims: + break - out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True) - return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs) + return [tuple(out_shape) for _ in node.outputs] def _c_all(self, node, nodename, inames, onames, sub): # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` From e4da2d7459193e6cab648d4ee8dbe65c09003d8e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 14:41:26 +0200 Subject: [PATCH 18/33] Use non recursive algorithm in `rebuild_collect_shared` --- pytensor/compile/function/pfunc.py | 87 ++++++++++++++++-------------- 1 file changed, 47 insertions(+), 40 deletions(-) diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 91d6e1a588..2056d077be 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -179,47 +179,54 @@ def clone_v_get_shared_updates(v, copy_inputs_over): """ # this co-recurses with clone_a - assert v is not None - if v in clone_d: - return clone_d[v] - if v.owner: - owner = v.owner - if owner not in clone_d: - for i in owner.inputs: - clone_v_get_shared_updates(i, copy_inputs_over) - clone_node_and_cache( - owner, - clone_d, - strict=rebuild_strict, - clone_inner_graphs=clone_inner_graphs, - ) - return clone_d.setdefault(v, v) - elif isinstance(v, SharedVariable): - if v not in shared_inputs: - shared_inputs.append(v) - if v.default_update is not None: - # Check that v should not be excluded from the default - # updates list - if no_default_updates is False or ( - isinstance(no_default_updates, list) and v not in no_default_updates - ): - # Do not use default_update if a "real" update was - # provided - if v not in update_d: - v_update = v.type.filter_variable( - v.default_update, allow_convert=False + stack = [v] + try: + while True: + v = stack.pop() + if v in clone_d: + continue + if (apply := v.owner) is not None: + if all(i in clone_d for i in apply.inputs): + # all inputs have been cloned, we can clone this node + clone_node_and_cache( + apply, + clone_d, + strict=rebuild_strict, + clone_inner_graphs=clone_inner_graphs, ) - if not v.type.is_super(v_update.type): - raise TypeError( - "An update must have a type compatible with " - "the original shared variable" - ) - update_d[v] = v_update - update_expr.append((v, v_update)) - if not copy_inputs_over: - return clone_d.setdefault(v, v.clone()) - else: - return clone_d.setdefault(v, v) + else: + # expand on the inputs + stack.extend(apply.inputs) + else: + clone_d[v] = v if copy_inputs_over else v.clone() + + # Special handling of SharedVariables + if isinstance(v, SharedVariable): + if v not in shared_inputs: + shared_inputs.append(v) + if v.default_update is not None: + # Check that v should not be excluded from the default + # updates list + if no_default_updates is False or ( + isinstance(no_default_updates, list) + and v not in no_default_updates + ): + # Do not use default_update if a "real" update was + # provided + if v not in update_d: + v_update = v.type.filter_variable( + v.default_update, allow_convert=False + ) + if not v.type.is_super(v_update.type): + raise TypeError( + "An update must have a type compatible with " + "the original shared variable" + ) + update_d[v] = v_update + update_expr.append((v, v_update)) + except IndexError: + pass # stack is empty + return clone_d[v] # initialize the clone_d mapping with the replace dictionary if replace is None: From b4cc77e25422573b857d999c6f6767f1fc6e8550 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 14:42:26 +0200 Subject: [PATCH 19/33] .avoid cast in hot loop --- pytensor/graph/basic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 5d6667683b..bb6334aed8 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -915,10 +915,8 @@ def clone_node_and_cache( # Use a cached `Op` clone when available new_op: Op | None = cast(Optional["Op"], clone_d.get(node.op)) - cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs] - new_node = node.clone_with_new_inputs( - cloned_inputs, + [clone_d[i] for i in node.inputs], # Only clone inner-graph `Op`s when there isn't a cached clone (and # when `clone_inner_graphs` is enabled) clone_inner_graph=clone_inner_graphs if new_op is None else False, From fdf12fde603e8dee5f6ade20f16b9fa7703dc48a Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 14:42:50 +0200 Subject: [PATCH 20/33] .faster tensortype creation TODO: Do same for ScalarTypes --- pytensor/tensor/type.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 5ae92006e2..6f29269ac7 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -35,6 +35,9 @@ int_dtypes = list(map(str, ps.int_types)) uint_dtypes = list(map(str, ps.uint_types)) +_all_dtypes_str: dict[str, str] = {d: d for d in all_dtypes} +_str_to_numpy_dtype: dict[str, np.dtype] = {} + # TODO: add more type correspondences for e.g. int32, int64, float32, # complex64, etc. dtype_specs_map = { @@ -99,13 +102,26 @@ def __init__( ) shape = broadcastable - if str(dtype) == "floatX": - self.dtype = config.floatX + if isinstance(dtype, str): + if dtype == "floatX": + dtype = config.floatX + elif dtype not in _all_dtypes_str: + # Check if dtype is a valid numpy dtype + try: + dtype = str(np.dtype(dtype)) + except TypeError as exc: + raise TypeError( + f"Unsupported dtype for TensorType: {dtype}" + ) from exc + else: + _all_dtypes_str[dtype] = dtype else: try: - self.dtype = str(np.dtype(dtype)) - except TypeError: - raise TypeError(f"Invalid dtype: {dtype}") + dtype = str(np.dtype(dtype)) + except TypeError as exc: + raise TypeError(f"Unsupported dtype for TensorType: {dtype}") from exc + + self.dtype = dtype def parse_bcast_and_shape(s): if isinstance(s, bool | np.bool_): @@ -121,7 +137,10 @@ def parse_bcast_and_shape(s): self.shape = tuple(parse_bcast_and_shape(s) for s in shape) self.dtype_specs() # error checking is done there self.name = name - self.numpy_dtype = np.dtype(self.dtype) + try: + self.numpy_dtype = _str_to_numpy_dtype[dtype] + except KeyError: + self.numpy_dtype = _str_to_numpy_dtype[dtype] = np.dtype(dtype) def __call__(self, *args, shape=None, **kwargs): if shape is not None: From 03083848ec6ca313d5c230a9b57861b5174ac759 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 15:15:56 +0200 Subject: [PATCH 21/33] upcast not needed all the time --- pytensor/tensor/rewriting/blas.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 2f2bb6b67f..34273992c3 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -60,7 +60,7 @@ import numpy as np from pytensor.graph.traversal import toposort -from pytensor.scalar import Add, Mul, Neg, Sub +from pytensor.scalar import Add, Mul, Neg, Sub, upcast from pytensor.tensor.rewriting.basic import register_specialize @@ -359,8 +359,12 @@ def _gemm_from_factored_list(fgraph, lst): if isinstance(sM, tuple): sm0, sm1 = sM sm0 = ptb.as_tensor_variable(sm0) - if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype: - lst2.append((ptb.cast(sm0, sm1.dtype), sM[1])) + sm0_dtype = sm0.type.dtype + sm1_dtype = sm1.type.dtype + if sm0_dtype == sm1_dtype: + lst2.append((sm0, sm1)) + elif upcast(sm0_dtype, sm1_dtype) == sm1_dtype: + lst2.append((ptb.cast(sm0, sm1_dtype), sm1)) lst = lst2 @@ -385,20 +389,15 @@ def item_to_var(t): if not M_j.type.in_same_class(M_i.type): continue - # print 'TRYING', (s_i, M_i, s_j, M_j) - gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M( fgraph, s_i, M_i, s_j, M_j ) - # print 'GOT IT', gemm_of_sM_list if gemm_of_sM_list: - assert len(gemm_of_sM_list) == 1 + [new_add_inp] = gemm_of_sM_list add_inputs = [ item_to_var(input) for k, input in enumerate(lst) if k not in (i, j) ] - add_inputs.extend(gemm_of_sM_list) - rval = [variadic_add(*add_inputs)] - # print "RETURNING GEMM THING", rval + rval = [variadic_add(*add_inputs, new_add_inp)] return rval, old_dot22 From a3bcf0c9e9da767047b5bbdcbdd5583c080d8675 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 22:31:04 +0200 Subject: [PATCH 22/33] cache _upcast_impl --- pytensor/scalar/basic.py | 60 ++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 9b27c369f3..742e35d742 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -13,6 +13,7 @@ import builtins import math from collections.abc import Callable +from functools import lru_cache from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -57,40 +58,51 @@ class IntegerDivisionError(Exception): """ -def upcast(dtype, *dtypes) -> str: +@lru_cache +def _upcast_pairwise(dtype1, dtype2=None, *, cast_policy, floatX): # This tries to keep data in floatX or lower precision, unless we # explicitly request a higher precision datatype. - keep_float32 = [ - (config.cast_policy == "numpy+floatX" and config.floatX == "float32") - ] - keep_float16 = [ - (config.cast_policy == "numpy+floatX" and config.floatX == "float16") - ] - - def make_array(dt): - if dt == "float64": - # There is an explicit float64 dtype: we cannot keep float32. - keep_float32[0] = False - keep_float16[0] = False - if dt == "float32": - keep_float16[0] = False - return np.zeros((), dtype=dt) - - z = make_array(dtype) - for dt in dtypes: - z = z + make_array(dt=dt) - rval = str(z.dtype) + if dtype1 == "float64": + keep_float32, keep_float16 = False, False + else: + keep_float32 = cast_policy == "numpy+floatX" and floatX == "float32" + keep_float16 = cast_policy == "numpy+floatX" and floatX == "float16" + + if dtype2 is not None: + if dtype2 == "float64": + keep_float32, keep_float16 = False, False + elif dtype2 == "float32": + keep_float16 = False + + if dtype2 is None: + rval = dtype1 + else: + rval = (np.zeros((), dtype=dtype1) + np.zeros((), dtype=dtype2)).dtype.name + if rval == "float64": - if keep_float16[0]: + if keep_float16: return "float16" - if keep_float32[0]: + if keep_float32: return "float32" elif rval == "float32": - if keep_float16[0]: + if keep_float16: return "float16" return rval +def upcast(dtype, *dtypes) -> str: + # This tries to keep data in floatX or lower precision, unless we + # explicitly request a higher precision datatype. + floatX = config.floatX + cast_policy = config.cast_policy + res_dtype = _upcast_pairwise(dtype, cast_policy=cast_policy, floatX=floatX) + for dt in dtypes: + res_dtype = _upcast_pairwise( + res_dtype, dt, cast_policy=cast_policy, floatX=floatX + ) + return res_dtype + + def as_common_dtype(*vars): """ For for pytensor.scalar.ScalarType and TensorVariable. From 4753a08bddde708ab15b07b5e576a528f76c1a25 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 23:13:02 +0200 Subject: [PATCH 23/33] Gemm optimizer spends too much time creating constants of the wrong type and then casting them --- pytensor/tensor/rewriting/blas.py | 39 +++++++++++++++++-------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index 34273992c3..c6eabd2585 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -83,7 +83,13 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError -from pytensor.tensor import basic as ptb +from pytensor.tensor import as_tensor_variable +from pytensor.tensor.basic import ( + AllocEmpty, + cast, + get_underlying_scalar_constant_value, + zeros, +) from pytensor.tensor.blas import ( Dot22, _batched_dot, @@ -143,7 +149,7 @@ def _as_scalar(res, dtype=None): # as the cast of the scalar can be done before or after the dot22 # and this will give the same result. if pytensor.scalar.upcast(res.dtype, dtype) == dtype: - return ptb.cast(rval, dtype) + return cast(rval, dtype) else: return None @@ -358,13 +364,13 @@ def _gemm_from_factored_list(fgraph, lst): # sM can be a tuple of 2 elements or an PyTensor variable. if isinstance(sM, tuple): sm0, sm1 = sM - sm0 = ptb.as_tensor_variable(sm0) - sm0_dtype = sm0.type.dtype sm1_dtype = sm1.type.dtype + sm0 = as_tensor_variable(sm0, dtype=sm1_dtype) + sm0_dtype = sm0.type.dtype if sm0_dtype == sm1_dtype: lst2.append((sm0, sm1)) elif upcast(sm0_dtype, sm1_dtype) == sm1_dtype: - lst2.append((ptb.cast(sm0, sm1_dtype), sm1)) + lst2.append((cast(sm0, sm1_dtype), sm1)) lst = lst2 @@ -654,7 +660,7 @@ def local_gemm_to_ger(fgraph, node): xv = x.dimshuffle(0) yv = y.dimshuffle(1) try: - bval = ptb.get_underlying_scalar_constant_value(b) + bval = get_underlying_scalar_constant_value(b) except NotScalarConstantError: # b isn't a constant, GEMM is doing useful pre-scaling return @@ -663,8 +669,7 @@ def local_gemm_to_ger(fgraph, node): rval = ger(z, a, xv, yv) new_out = [rval] elif bval == 0: # GER on zeros_like should be faster than GEMM - zeros = ptb.zeros([x.shape[0], y.shape[1]], x.dtype) - rval = ger(zeros, a, xv, yv) + rval = ger(zeros([x.shape[0], y.shape[1]], x.dtype), a, xv, yv) new_out = [rval] else: # if bval is another constant, then z is being usefully @@ -681,32 +686,32 @@ def local_dot22_to_ger_or_gemv(fgraph, node): x, y = node.inputs xb = x.broadcastable yb = y.broadcastable - one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype)) - zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype)) + one = as_tensor_variable(np.asarray(1, dtype=x.dtype)) + zero = as_tensor_variable(np.asarray(0, dtype=x.dtype)) if xb[1] and yb[0]: # x and y are both vectors so this might qualifies for a GER xv = x.dimshuffle(0) yv = y.dimshuffle(1) - zeros = ptb.zeros([x.shape[0], y.shape[1]], dtype=x.dtype) + zeros = zeros([x.shape[0], y.shape[1]], dtype=x.dtype) rval = ger(zeros, one, xv, yv) new_out = [rval] elif xb[0] and yb[1]: # x and y are both vectors so this qualifies for a sdot / ddot # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not xv = x.dimshuffle(1) - zeros = ptb.AllocEmpty(x.dtype)(1) + zeros = AllocEmpty(x.dtype)(1) rval = gemv_no_inplace(zeros, one, y.T, xv, zero) new_out = [rval.dimshuffle("x", 0)] elif xb[0] and not yb[0] and not yb[1]: # x is vector, y is matrix so try gemv xv = x.dimshuffle(1) - zeros = ptb.AllocEmpty(x.dtype)(y.shape[1]) + zeros = AllocEmpty(x.dtype)(y.shape[1]) rval = gemv_no_inplace(zeros, one, y.T, xv, zero) new_out = [rval.dimshuffle("x", 0)] elif not xb[0] and not xb[1] and yb[1]: # x is matrix, y is vector, try gemv yv = y.dimshuffle(0) - zeros = ptb.AllocEmpty(x.dtype)(x.shape[0]) + zeros = AllocEmpty(x.dtype)(x.shape[0]) rval = gemv_no_inplace(zeros, one, x, yv, zero) new_out = [rval.dimshuffle(0, "x")] else: @@ -841,9 +846,7 @@ def local_dot22_to_dot22scalar(fgraph, node): " matrix type" ) return False - a = ptb.cast( - _as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype - ) + a = cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype) assert not a.type.ndim dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) @@ -881,7 +884,7 @@ def local_dot22_to_dot22scalar(fgraph, node): o.remove(d) o.remove(s) - a = ptb.cast(i_scalar[scalar_idx], d.type.dtype) + a = cast(i_scalar[scalar_idx], d.type.dtype) assert not a.type.ndim if len(o) == 0: return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)] From 4f09fb7bf856347ff3ed172c6932ea49b335722f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 6 Sep 2025 00:11:16 +0200 Subject: [PATCH 24/33] Fail fast scan memory inplace --- pytensor/scan/rewriting.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 09793ab15a..c946039f0f 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -1048,8 +1048,13 @@ def attempt_scan_inplace( return None def apply(self, fgraph): + scan_nodes = {node for node in fgraph.apply_nodes if isinstance(node.op, Scan)} + + if not scan_nodes: + return + for scan_idx, original_node in enumerate(reversed(fgraph.toposort())): - if not isinstance(original_node.op, Scan): + if original_node not in scan_nodes: continue # First attempt to make the Scan compute inplace every recurrent From 272816e065fa3a3bd74d9cfd5438075711af03cd Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 7 Sep 2025 18:14:44 +0200 Subject: [PATCH 25/33] Fast Scan equality for same identity --- pytensor/scan/op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 05a860584e..3b43feaf4b 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1248,6 +1248,9 @@ def is_cpu_vector(s): return apply_node def __eq__(self, other): + if self is other: + return True + if type(self) is not type(other): return False From 90ab712545528b8198e1c085c2325d53771c858c Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 11:56:28 +0200 Subject: [PATCH 26/33] .speedup composite rewrites This should be cleaned up into nicer git commits --- pytensor/scalar/basic.py | 77 +++++++++++++++++---------- pytensor/tensor/rewriting/elemwise.py | 57 +++++++++++++------- 2 files changed, 87 insertions(+), 47 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 742e35d742..d4dd29f492 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1255,7 +1255,9 @@ def make_node(self, *inputs): f"Wrong number of inputs for {self}.make_node " f"(got {len(inputs)}({inputs}), expected {self.nin})" ) - inputs = [as_scalar(input) for input in inputs] + inputs = [ + inp if isinstance(inp, ScalarVariable) else as_scalar(inp) for inp in inputs + ] outputs = [t() for t in self.output_types([input.type for input in inputs])] if len(outputs) != self.nout: inputs_str = (", ".join(str(input) for input in inputs),) @@ -4294,7 +4296,13 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") def __init__( - self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + self, + inputs, + outputs, + name="Composite", + clone_graph: builtins.bool = True, + cleanup_graph: builtins.bool = True, + output_types_preference=None, ): self.name = name self._name = None @@ -4324,7 +4332,9 @@ def __init__( # 1. Create a new graph from inputs up to the # Composite res = pytensor.compile.rebuild_collect_shared( - inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False + inputs=inputs, + outputs=outputs[0].owner.inputs, + copy_inputs_over=False, ) # Clone also the inputs # 2. We continue this partial clone with the graph in # the inner Composite @@ -4338,36 +4348,42 @@ def __init__( assert res[0] != inputs inputs, outputs = res[0], res2[1] - # We already cloned the graph, or the user told us there was no need for it - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) + if cleanup_graph: + # We already cloned the graph, or the user told us there was no need for it + self.inputs, self.outputs = self._cleanup_graph( + inputs, outputs, clone=False + ) + else: + self.inputs, self.outputs = inputs, outputs self.inputs_type = tuple(input.type for input in self.inputs) self.outputs_type = tuple(output.type for output in self.outputs) self.nin = len(inputs) self.nout = len(outputs) - super().__init__() + super().__init__(output_types_preference=output_types_preference) def __str__(self): if self._name is not None: return self._name - # Rename internal variables - for i, r in enumerate(self.fgraph.inputs): - r.name = f"i{i}" - for i, r in enumerate(self.fgraph.outputs): - r.name = f"o{i}" - io = set(self.fgraph.inputs + self.fgraph.outputs) - for i, r in enumerate(self.fgraph.variables): - if ( - not isinstance(r, Constant) - and r not in io - and len(self.fgraph.clients[r]) > 1 - ): - r.name = f"t{i}" + fgraph = self.fgraph - if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10: + if len(fgraph.outputs) > 1 or len(fgraph.apply_nodes) > 10: self._name = "Composite{...}" else: - outputs_str = ", ".join(pprint(output) for output in self.fgraph.outputs) + # Rename internal variables + for i, r in enumerate(fgraph.inputs): + r.name = f"i{i}" + for i, r in enumerate(fgraph.outputs): + r.name = f"o{i}" + io = set(fgraph.inputs + fgraph.outputs) + for i, r in enumerate(fgraph.variables): + if ( + not isinstance(r, Constant) + and r not in io + and len(fgraph.clients[r]) > 1 + ): + r.name = f"t{i}" + outputs_str = ", ".join(pprint(output) for output in fgraph.outputs) self._name = f"Composite{{{outputs_str}}}" return self._name @@ -4380,12 +4396,16 @@ def make_new_inplace(self, output_types_preference=None, name=None): """ d = {k: getattr(self, k) for k in self.init_param} - out = self.__class__(**d) - if name: - out.name = name - else: - name = out.name - super(Composite, out).__init__(output_types_preference, name) + out = type(self)( + **d, + cleanup_graph=False, + clone_graph=False, + output_types_preference=output_types_preference, + name=name or self.name, + ) + # No need to recompute the _cocde and nodenames if they were already computed (which is true if the hash of the Op was requested) + out._c_code = self._c_code + out.nodenames = self.nodenames return out @property @@ -4452,9 +4472,10 @@ def c_code_template(self): fg = self.fgraph subd = {e: f"%(i{i})s" for i, e in enumerate(fg.inputs)} + inputs_set = frozenset(fg.inputs) for var in fg.variables: if var.owner is None: - if var not in fg.inputs: + if var not in inputs_set: # This is an orphan if isinstance(var, Constant) and isinstance(var.type, CLinkerType): subd[var] = f"({var.type.c_literal(var.data)})" diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index a862711eab..cee496c384 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -28,8 +28,9 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import toposort +from pytensor.graph.traversal import graph_inputs, toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined +from pytensor.scalar import ScalarConstant from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( MakeVector, @@ -885,26 +886,44 @@ def print_profile(stream, prof, level=0): def local_useless_composite_outputs(fgraph, node): """Remove inputs and outputs of Composite Ops that are not used anywhere.""" comp = node.op.scalar_op - used_outputs_idxs = [ - i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern] - ] - used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs] - comp_fgraph = FunctionGraph( - inputs=comp.inputs, outputs=used_inner_outputs, clone=False - ) + + clients = fgraph.clients + outer_inputs, outer_outputs = node.inputs, node.outputs + inner_inputs, inner_outputs = comp.inputs, comp.outputs + + used_inner_outputs = { + inner_out + for inner_out, outer_out in zip(inner_outputs, outer_outputs) + if clients[outer_out] + } + used_inner_inputs = { + inner_inp + for inner_inp in graph_inputs(used_inner_outputs) + if not isinstance(inner_inp, ScalarConstant) + } + + if len(used_inner_inputs) == len(outer_inputs) or len(used_inner_outputs) == len( + outer_outputs + ): + return None + used_inputs_idxs = [ - i - for i, i_intern in enumerate(comp_fgraph.inputs) - if comp_fgraph.clients[i_intern] + i for i, inp in enumerate(inner_inputs) if inp in used_inner_inputs ] - used_inner_inputs = [comp.inputs[i] for i in used_inputs_idxs] - if len(used_inner_inputs) < len(node.inputs) or len(used_inner_outputs) < len( - node.outputs - ): - used_inputs = [node.inputs[i] for i in used_inputs_idxs] - c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) - e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) - return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True)) + used_inner_inputs = [inner_inputs[i] for i in used_inputs_idxs] + used_outer_inputs = [outer_inputs[i] for i in used_inputs_idxs] + + new_comp = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) + new_outer_outputs = Elemwise(scalar_op=new_comp)( + *used_outer_inputs, return_list=True + ) + + used_outer_outputs = ( + outer_outputs[i] + for i, out in enumerate(inner_outputs) + if out in used_inner_outputs + ) + return dict(zip(used_outer_outputs, new_outer_outputs, strict=True)) @node_rewriter([CAReduce]) From 330a1e03273ad19cae6c22baf9b0a97b23397b55 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 20 Sep 2025 16:02:28 +0200 Subject: [PATCH 27/33] Revert "Use non recursive algorithm in `rebuild_collect_shared`" This reverts commit ae43c14685387fe1a763426342bb957db1c2afb6. Breaks the @pytensor_jit example --- pytensor/compile/function/pfunc.py | 87 ++++++++++++++---------------- 1 file changed, 40 insertions(+), 47 deletions(-) diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 2056d077be..91d6e1a588 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -179,54 +179,47 @@ def clone_v_get_shared_updates(v, copy_inputs_over): """ # this co-recurses with clone_a - stack = [v] - try: - while True: - v = stack.pop() - if v in clone_d: - continue - if (apply := v.owner) is not None: - if all(i in clone_d for i in apply.inputs): - # all inputs have been cloned, we can clone this node - clone_node_and_cache( - apply, - clone_d, - strict=rebuild_strict, - clone_inner_graphs=clone_inner_graphs, + assert v is not None + if v in clone_d: + return clone_d[v] + if v.owner: + owner = v.owner + if owner not in clone_d: + for i in owner.inputs: + clone_v_get_shared_updates(i, copy_inputs_over) + clone_node_and_cache( + owner, + clone_d, + strict=rebuild_strict, + clone_inner_graphs=clone_inner_graphs, + ) + return clone_d.setdefault(v, v) + elif isinstance(v, SharedVariable): + if v not in shared_inputs: + shared_inputs.append(v) + if v.default_update is not None: + # Check that v should not be excluded from the default + # updates list + if no_default_updates is False or ( + isinstance(no_default_updates, list) and v not in no_default_updates + ): + # Do not use default_update if a "real" update was + # provided + if v not in update_d: + v_update = v.type.filter_variable( + v.default_update, allow_convert=False ) - else: - # expand on the inputs - stack.extend(apply.inputs) - else: - clone_d[v] = v if copy_inputs_over else v.clone() - - # Special handling of SharedVariables - if isinstance(v, SharedVariable): - if v not in shared_inputs: - shared_inputs.append(v) - if v.default_update is not None: - # Check that v should not be excluded from the default - # updates list - if no_default_updates is False or ( - isinstance(no_default_updates, list) - and v not in no_default_updates - ): - # Do not use default_update if a "real" update was - # provided - if v not in update_d: - v_update = v.type.filter_variable( - v.default_update, allow_convert=False - ) - if not v.type.is_super(v_update.type): - raise TypeError( - "An update must have a type compatible with " - "the original shared variable" - ) - update_d[v] = v_update - update_expr.append((v, v_update)) - except IndexError: - pass # stack is empty - return clone_d[v] + if not v.type.is_super(v_update.type): + raise TypeError( + "An update must have a type compatible with " + "the original shared variable" + ) + update_d[v] = v_update + update_expr.append((v, v_update)) + if not copy_inputs_over: + return clone_d.setdefault(v, v.clone()) + else: + return clone_d.setdefault(v, v) # initialize the clone_d mapping with the replace dictionary if replace is None: From 9ffd97f8ec905718d6da558c740f7ab50dc6b4ec Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 23 Sep 2025 15:25:56 +0200 Subject: [PATCH 28/33] Speedup gradient --- pytensor/gradient.py | 563 +++++++++++++++++---------------------- pytensor/scalar/basic.py | 1 + 2 files changed, 248 insertions(+), 316 deletions(-) diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 5924fd7fcb..0944e848e0 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -4,6 +4,8 @@ import warnings from collections.abc import Callable, Mapping, MutableSequence, Sequence from functools import partial, reduce +from itertools import chain +from operator import add as operator_add from typing import TYPE_CHECKING, Literal, TypeVar, Union, overload import numpy as np @@ -12,9 +14,8 @@ from pytensor.compile.ops import ViewOp from pytensor.configdefaults import config from pytensor.graph import utils, vectorize_graph -from pytensor.graph.basic import Apply, NominalVariable, Variable +from pytensor.graph.basic import Apply, Variable from pytensor.graph.null_type import NullType, null_type -from pytensor.graph.op import get_test_values from pytensor.graph.type import Type @@ -1162,381 +1163,311 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): # its inputs' gradients term_dict = {} - def access_term_cache(node): - """Populates term_dict[node] and returns it""" + from pytensor.scalar import discrete_dtypes, float_dtypes - if node not in term_dict: - inputs = node.inputs + discrete_dtypes_set = frozenset(discrete_dtypes) + float_dtypes_set = frozenset(float_dtypes) - output_grads = [access_grad_cache(var) for var in node.outputs] + def access_term_cache(node): + """Populates term_dict[node] and returns it""" + try: + return term_dict[node] + except KeyError: + pass - # list of bools indicating if each output is connected to the cost - outputs_connected = [ - not isinstance(g.type, DisconnectedType) for g in output_grads - ] + inputs = node.inputs + output_grads = [access_grad_cache(var) for var in node.outputs] + output_grads_connected = [ + not isinstance(g.type, DisconnectedType) for g in output_grads + ] - connection_pattern = _node_to_pattern(node) + connection_pattern = ( + None + if not hasattr(node.op, "connection_pattern") + else _node_to_pattern(node) + ) - # list of bools indicating if each input is connected to the cost + # list of bools indicating if each input is connected to the cost + if connection_pattern is None: + inputs_connected = [any(output_grads_connected)] * len(inputs) + else: inputs_connected = [ ( any( - input_to_output and output_to_cost - for input_to_output, output_to_cost in zip( - input_to_outputs, outputs_connected, strict=True + in_to_out and out_to_cost + for in_to_out, out_to_cost in zip( + input_to_outputs, output_grads_connected ) ) ) for input_to_outputs in connection_pattern ] - # List of bools indicating if each output is an integer dtype - output_is_int = [ - hasattr(output.type, "dtype") - and output.type.dtype in pytensor.tensor.type.discrete_dtypes - for output in node.outputs - ] - - # List of bools indicating if each output is NullType - ograd_is_nan = [ - isinstance(output.type, NullType) for output in output_grads - ] + if not any(inputs_connected): + # All outputs of this op are disconnected so we can skip + # Calling the op's grad method and report that the inputs are disconnected + # (The op's grad method could do this too, but this saves the + # implementer the trouble of worrying about this case) + term_dict[node] = input_grads = [disconnected_type() for _ in inputs] + return input_grads + + # List of bools indicating if each output is NullType + output_grads_null = [ + isinstance(output.type, NullType) for output in output_grads + ] - # List of bools indicating if each input only has NullType outputs + # List of bools indicating if each input only has NullType outputs + if connection_pattern is None: + only_connected_to_nan = [all(output_grads_null)] * len(inputs) + else: only_connected_to_nan = [ ( not any( in_to_out and out_to_cost and not out_nan for in_to_out, out_to_cost, out_nan in zip( - in_to_outs, outputs_connected, ograd_is_nan, strict=True + in_to_outs, output_grads_connected, output_grads_null ) ) ) for in_to_outs in connection_pattern ] - if not any(inputs_connected): - # All outputs of this op are disconnected so we can skip - # Calling the op's grad method and report that the inputs - # are disconnected - # (The op's grad method could do this too, but this saves the - # implementer the trouble of worrying about this case) - input_grads = [disconnected_type() for ipt in inputs] - elif all(only_connected_to_nan): - # All inputs are only connected to nan gradients, so we don't - # need to bother calling the grad method. We know the gradient - # with respect to all connected inputs is nan. - input_grads = [] - for connected in inputs_connected: - if connected: - input_grads.append(null_type()) - else: - input_grads.append(disconnected_type()) - else: - # At least one input of this op is connected to the cost so and - # not all output gradients are undefined so we must - # call the op's grad method - - # Each Op's grad function requires inputs and output_grads - # If the Op destroys any input, but the grad expression uses - # it, then chances are the resulting graph will have a - # dependency cycle. We avoid this cycle by passing (symbolic) - # copies of each destroyed input. + if all(only_connected_to_nan): + # All inputs are only connected to nan gradients, so we don't + # need to bother calling the grad method. We know the gradient + # with respect to all connected inputs is nan. + term_dict[node] = input_grads = [ + null_type() if connected else disconnected_type() + for connected in inputs_connected + ] + return input_grads + + # At least one input of this op is connected to the cost so and + # not all output gradients are undefined so we must + # call the op's grad method + + # Each Op's grad function requires inputs and output_grads + # If the Op destroys any input, but the grad expression uses + # it, then chances are the resulting graph will have a + # dependency cycle. We avoid this cycle by passing (symbolic) + # copies of each destroyed input. + if (destroy_map := getattr(node.op, "destroy_map", None)) is not None: + # Is this just an index? + dinputs = frozenset(chain.from_iterable(destroy_map.values())) + + def try_to_copy(var): try: - dinputs = [node.inputs[x[0]] for x in node.op.destroy_map.values()] + return var.copy() except AttributeError: - dinputs = [] - - def try_to_copy_if_needed(var): - if var in dinputs and hasattr(var, "copy"): - return var.copy() return var - inputs = [try_to_copy_if_needed(ipt) for ipt in inputs] - - # Build a list of output gradients with the same dtype as - # the corresponding output variable. - # If an output is of a float dtype, we want to cast the - # output gradient into the same dtype, to avoid having a - # gradient graph with double precision (taking more memory, - # and more computation). - # If an output is of an integer dtype, then we just leave it - # alone. - # DO NOT force integer variables to have zero grad. This causes - # bugs where we fail to detect disconnected or undefined - # gradients. - # DO NOT force integer variables to have integer dtype. - # This is a violation of the op contract. - new_output_grads = [] - for o, og in zip(node.outputs, output_grads, strict=True): - o_dt = getattr(o.type, "dtype", None) - og_dt = getattr(og.type, "dtype", None) - if ( - o_dt not in pytensor.tensor.type.discrete_dtypes - and og_dt - and o_dt != og_dt - ): - new_output_grads.append(og.astype(o_dt)) - else: - new_output_grads.append(og) - - # Make sure that, if new_output_grads[i] has a floating point - # dtype, it is the same dtype as outputs[i] - for o, ng in zip(node.outputs, new_output_grads, strict=True): - o_dt = getattr(o.type, "dtype", None) - ng_dt = getattr(ng.type, "dtype", None) - if ( - ng_dt is not None - and o_dt not in pytensor.tensor.type.discrete_dtypes - ): - assert ng_dt == o_dt + inputs = [ + try_to_copy(ipt) if ipt_idx in dinputs else ipt + for ipt_idx, ipt in enumerate(inputs) + ] - assert all( - getattr(ng.type, "dtype", None) - not in pytensor.tensor.type.discrete_dtypes - for ng in new_output_grads + # Build a list of output gradients with the same dtype as + # the corresponding output variable. + # If an output is of a float dtype, we want to cast the + # output gradient into the same dtype, to avoid having a + # gradient graph with double precision (taking more memory, + # and more computation). + # If an output is of an integer dtype, then we just leave it alone. + # DO NOT force integer variables to have zero grad. This causes + # bugs where we fail to detect disconnected or undefined gradients. + # DO NOT force integer variables to have integer dtype. + # This is a violation of the op contract. + + for o_idx, (o, og) in enumerate(zip(node.outputs, output_grads)): + try: + og_dt = og.type.dtype + o_dt = o.type.dtype + except AttributeError: + continue + if o_dt not in discrete_dtypes and o_dt != og_dt: + output_grads[o_idx] = og.astype(o_dt) + + input_grads = node.op.L_op(inputs, node.outputs, output_grads) + + if input_grads is None: + raise TypeError(f"{node.op}.grad returned None, expected iterable.") + + if len(input_grads) != len(inputs): + raise ValueError(f"{node.op} returned the wrong number of gradient terms.") + + # Need to propagate the NullType gradients; if an input grad is + # not disconnected and the corresponding input is connected + # to at least one output whose gradient is NullType then the input + # grad should be NullType. + # TODO: Why are we the ones enforcing this? + if any(output_grads_null): + if connection_pattern is None: + nan_ograd = next( + ograd for ograd in output_grads if isinstance(ograd.type, NullType) ) - - # If config.compute_test_value is turned on, check that the - # gradients on the outputs of this node have the right shape. - # We also check the gradient on the inputs later--both checks - # are needed, because some gradients are only ever specified - # by the user, not computed by Op.grad, and some gradients are - # only computed and returned, but never passed as another - # node's output grads. - for idx, packed in enumerate( - zip(node.outputs, new_output_grads, strict=True) + input_grads = [ + inp_grad + if isinstance(inp_grad.type, DisconnectedType) + else nan_ograd + for inp_grad in input_grads + ] + else: + # convert to list so we can modify it inplace + input_grads = list(input_grads) + for inp_idx, (inp_grad, in_to_outs) in enumerate( + zip(input_grads, connection_pattern) ): - orig_output, new_output_grad = packed - if not hasattr(orig_output, "shape"): - continue - if isinstance(new_output_grad.type, DisconnectedType): + if isinstance(inp_grad.type, DisconnectedType): continue - for orig_output_v, new_output_grad_v in get_test_values(*packed): - o_shape = orig_output_v.shape - g_shape = new_output_grad_v.shape - if o_shape != g_shape: - raise ValueError( - "Got a gradient of shape " - + str(o_shape) - + " on an output of shape " - + str(g_shape) - ) - - input_grads = node.op.L_op(inputs, node.outputs, new_output_grads) - - if input_grads is None: - raise TypeError( - f"{node.op}.grad returned NoneType, expected iterable." - ) - - if len(input_grads) != len(inputs): - raise ValueError( - f"{node.op} returned the wrong number of gradient terms." - ) - # We can not enforce this, as AdvancedSubtensor1 has an option to - # return the sparse grad for optimization reason. - - # for ig, i in zip(input_grads, inputs): - # if (not isinstance(ig.type, (DisconnectedType, NullType)) and - # type(ig.type) != type(i.type)): - # raise ValueError( - # "%s returned the wrong type for gradient terms." - # " Sparse inputs must have sparse grads and dense" - # " inputs must have dense grad. Got %s, expected %s" %( - # str(node.op), ig.type, i.type)) - - # must convert to list in case the op returns a tuple - # we won't be able to post-process out the Nones if it does that - input_grads = list(input_grads) - - # Need to propagate the NullType gradients; if an input grad is - # not disconnected and the corresponding input is connected - # to at least one output whose gradient is NullType then the input - # grad should be NullType. - for inp_idx in range(len(input_grads)): - for out_idx in range(len(ograd_is_nan)): - if ( - ograd_is_nan[out_idx] - and connection_pattern[inp_idx][out_idx] - and not isinstance(input_grads[inp_idx].type, DisconnectedType) + for in_to_out, ograd_null, ograd in zip( + in_to_outs, output_grads_null, output_grads ): - input_grads[inp_idx] = output_grads[out_idx] - - # Do type checking on the result + if in_to_out and ograd_null: + input_grads[inp_idx] = ograd + break + + # List of bools indicating if each output is an integer dtype + output_is_int = [ + getattr(output.type, "dtype", "float64") in discrete_dtypes_set + for output in node.outputs + ] - # List of bools indicating if each input only has integer outputs + # List of bools indicating if each input only has integer outputs + if connection_pattern is None: + only_connected_to_int = [all(output_is_int)] * len(inputs) + else: only_connected_to_int = [ ( - True - not in [ + not any( in_to_out and out_to_cost and not out_int for in_to_out, out_to_cost, out_int in zip( - in_to_outs, outputs_connected, output_is_int, strict=True + in_to_outs, output_grads_connected, output_is_int ) - ] + ) ) for in_to_outs in connection_pattern ] - for i, term in enumerate(input_grads): - # Disallow Nones - if term is None: - # We don't know what None means. in the past it has been - # used to mean undefined, zero, or disconnected. - # We therefore don't allow it because its usage has become - # so muddied. + # Do type checking on the result + for i, ( + inp, + term, + term_connected, + term_only_connected_to_nan, + term_only_connected_to_int, + ) in enumerate( + zip( + inputs, + input_grads, + inputs_connected, + only_connected_to_nan, + only_connected_to_int, + ) + ): + if term is None: + # We don't know what None means. in the past it has been used to mean undefined, zero, or disconnected. + raise TypeError( + f"{node.op} grad returned None for a gradient term.\n" + "Instead, it should return zeros_like(input), disconnected_type(), or a NullType variable " + "such as those created by the grad_undefined or grad_unimplemented helpers." + ) + + term_type = term.type + + if term_only_connected_to_nan: + assert isinstance(term_type, NullType | DisconnectedType) + + elif not isinstance(term_type, NullType | DisconnectedType): + if term_type.dtype not in float_dtypes_set: raise TypeError( - f"{node.op}.grad returned None for a gradient term, " - "this is prohibited. Instead of None," - "return zeros_like(input), disconnected_type()," - " or a NullType variable such as those made with " - "the grad_undefined or grad_unimplemented helper " - "functions." + f"{node.op} grad illegally returned an integer-valued variable for input index {i} " + f"with dtype {term_type.dtype}." ) - # Check that the gradient term for this input - # has the right shape - if hasattr(term, "shape"): - orig_ipt = inputs[i] - if not isinstance(orig_ipt, NominalVariable): - for orig_ipt_v, term_v in get_test_values(orig_ipt, term): - i_shape = orig_ipt_v.shape - t_shape = term_v.shape - if i_shape != t_shape: - raise ValueError( - f"{node.op}.grad returned object of " - f"shape {t_shape} as gradient term on input {int(i)} " - f"of shape {i_shape}" - ) - - if not isinstance(term.type, NullType | DisconnectedType): - if term.type.dtype not in pytensor.tensor.type.float_dtypes: - raise TypeError( - str(node.op) + ".grad illegally " - " returned an integer-valued variable." - f" (Input index {int(i)}, dtype {term.type.dtype})" - ) + if ( + inp_ndim := getattr(inp.type, "ndim", None) + ) is not None and inp_ndim != term_type.ndim: + raise ValueError( + f"{node.op}.grad returned a term with {term_type.ndim} " + f"dimensions for input {i}, but {inp.type.ndim} are required." + ) - if only_connected_to_nan[i]: - assert isinstance(term.type, NullType) - - if only_connected_to_int[i]: - # This term has only integer outputs and we know - # it's not undefined or disconnected - # The only other valid thing it can be is 0 - - is_zero = _is_zero(term) - assert is_zero in ("yes", "no", "maybe") - if is_zero == "maybe": - msg = ( - f"{node.op}.grad returned {term} of type {type(term)} for input" - f" {i}. This input's only connections to " - "the cost through this op are via " - "integer-valued outputs so it should be " - "NullType, DisconnectedType, or some form " - "of zeros. It is not NullType or " - "DisconnectedType and pytensor can't " - "simplify it to a constant, so it's not " - "verifiably zeros." - ) - elif is_zero == "no": - msg = ( - f"{node.op}.grad returned {term} of type {type(term)} for input" - f" {i}. Since this input is only connected " - "to integer-valued outputs, it should " - "evaluate to zeros, but it evaluates to" - f"{pytensor.get_underlying_scalar_constant_value(term)}." - ) - raise ValueError(msg) - - # Check that op.connection_pattern matches the connectivity - # logic driving the op.grad method - for i, (ipt, ig, connected) in enumerate( - zip(inputs, input_grads, inputs_connected, strict=True) - ): - actually_connected = not isinstance(ig.type, DisconnectedType) - - if actually_connected and not connected: - msg = ( - f"{node.op}.grad returned {ig} of type {ig.type} for input {i}." - " Expected DisconnectedType instance based on " - " the output of the op's connection_pattern " - "method." + if term_only_connected_to_int and _is_zero(term) == "no": + raise ValueError( + f"{node.op}.grad returned {term} of type {type(term)} for input {i}" + f"Since this input is only connected to integer-valued outputs, it should evaluate to zeros, " + f" but it evaluates to {pytensor.get_underlying_scalar_constant_value(term)}." ) - raise TypeError(msg) - elif connected and not actually_connected: - msg = f"{node.op}.grad returned DisconnectedType for input {i}." + # Check that op.connection_pattern matches the connectivity logic driving the op.grad method + if term_connected != ( + actually_connected := not isinstance(term_type, DisconnectedType) + ): + if actually_connected: + raise TypeError( + f"{node.op}.grad returned an input gradient of type {term_type} for input {i}.\n" + "Expected DisconnectedType instance based on the output of the op's connection_pattern method." + ) + else: + msg = f"{node.op}.grad returned DisconnectedType for input {i}. " if hasattr(node.op, "connection_pattern"): - msg += " Its connection_pattern method does not allow this." - raise TypeError(msg) + raise TypeError( + msg + "Its connection_pattern method does not allow this." + ) else: - msg += ( - " You may want to implement a " - "connection_pattern method for it." + warnings.warn( + msg + + "You may want to implement a connection_pattern method for it." ) - warnings.warn(msg) - - # cache the result - term_dict[node] = input_grads - return term_dict[node] + term_dict[node] = input_grads + return input_grads # populate grad_dict[var] and return it def access_grad_cache(var): - if var not in grad_dict: - # If var is not in grad_dict already, we must compute it - if var in var_to_app_to_idx: - null_terms = [] - terms = [] - node_to_idx = var_to_app_to_idx[var] - for node in node_to_idx: - for idx in node_to_idx[node]: - term = access_term_cache(node)[idx] - - if not isinstance(term, Variable): - raise TypeError( - f"{node.op}.grad returned {type(term)}, expected" - " Variable instance." - ) - - if isinstance(term.type, NullType): - null_terms.append(term) - continue - - # Don't try to sum up DisconnectedType placeholders - if isinstance(term.type, DisconnectedType): - continue - - if hasattr(var, "ndim") and term.ndim != var.ndim: - raise ValueError( - f"{node.op}.grad returned a term with" - f" {int(term.ndim)} dimensions, but {int(var.ndim)} are required." - ) - - terms.append(term) + try: + return grad_dict[var] + except KeyError: + pass - # Add up the terms to get the total gradient on this variable - if len(null_terms) > 0: - # At least one term is a NullType : the total gradient - # will also be a NullType - grad_dict[var] = null_terms[0] - elif len(terms) > 0: - # the next line is like sum(terms) but doesn't add an - # extraneous TensorConstant(0) - grad_dict[var] = reduce(lambda x, y: x + y, terms) - else: - grad_dict[var] = disconnected_type() + # If var is not in grad_dict already, we must compute it + if (node_to_idx := var_to_app_to_idx.get(var, None)) is None: + # this variable isn't connected to the cost in the computational graph + grad_var = disconnected_type() + + else: + terms = [] + for node, indices in node_to_idx.items(): + node_terms = access_term_cache(node) + terms.extend( + term + for idx in indices + # Don't include disconnected terms in the sum + if not isinstance((term := node_terms[idx]).type, DisconnectedType) + ) - if cost_name is not None and var.name is not None: - grad_dict[var].name = f"(d{cost_name}/d{var.name})" + if terms: + # Add up the terms to get the total gradient on this variable + try: + grad_var = reduce(operator_add, terms) + except TypeError: + # This should only happen when there's a NullType term + try: + grad_var = next( + t for t in terms if isinstance(t.type, NullType) + ) + except StopIteration: + raise TypeError( + f"The gradient terms of variable {var} could not be added together: {terms}." + ) else: - # this variable isn't connected to the cost in the - # computational graph - grad_dict[var] = disconnected_type() - # end if cache miss - return grad_dict[var] + grad_var = disconnected_type() + + if cost_name is not None and var.name is not None: + grad_var.name = f"(d{cost_name}/d{var.name})" + + grad_dict[var] = grad_var + return grad_var rval = [access_grad_cache(elem) for elem in wrt] diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d4dd29f492..52f8bb05dd 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -833,6 +833,7 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: continuous_types: _ScalarTypes = float_types + complex_types all_types: _ScalarTypes = discrete_types + continuous_types +float_dtypes = tuple(t.dtype for t in float_types) discrete_dtypes = tuple(t.dtype for t in discrete_types) From 2349641409d09b95aa1b1ceaf85f2d1f0bf11c5b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Oct 2025 17:15:57 +0200 Subject: [PATCH 29/33] .speedup stuff --- pytensor/graph/rewriting/basic.py | 18 ++-- pytensor/tensor/rewriting/math.py | 113 +++++++++------------ pytensor/tensor/variable.py | 41 ++------ tests/compile/function/test_types.py | 140 ++++++++++++++++++++++++++- 4 files changed, 206 insertions(+), 106 deletions(-) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index c373bf42a9..0f7a1ed53e 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -1000,7 +1000,11 @@ def tracks(self): return self._tracks def __str__(self): - return getattr(self, "__name__", repr(self)) + try: + return self.__name__ + except AttributeError: + self.__name__ = name = self.__repr__() + return name def __repr__(self): return f"FromFunctionNodeRewriter({self.fn!r}, {self._tracks!r}, {self.requirements!r})" @@ -1215,11 +1219,13 @@ def __init__( self.tracker.add_tracker(o) def __str__(self): - return getattr( - self, - "__name__", - f"{type(self).__name__}({','.join(str(o) for o in self.rewrites)})", - ) + try: + return self.__name__ + except AttributeError: + self.__name__ = name = ( + f"{type(self).__name__}({','.join(str(o) for o in self.rewrites)})" + ) + return name def tracks(self): t = [] diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e313163b6e..a63b58261c 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1,9 +1,9 @@ r"""Rewrites for the `Op`\s in :mod:`pytensor.tensor.math`.""" import itertools -import operator from collections import defaultdict from functools import partial, reduce +from operator import itemgetter import numpy as np @@ -1014,6 +1014,7 @@ def __init__(self, main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=Tr self.main = main self.inverse = inverse_fn self.reciprocal = reciprocal_fn + self.ops = (self.main, self.inverse, self.reciprocal) self.calculate = calculate self.use_reciprocal = use_reciprocal @@ -1053,68 +1054,38 @@ def get_num_denum(self, inp): # internal data nodes all have the dtype of the 'input' # argument. The leaf-Variables of the graph covered by the # recursion may be of any Variable type. - - if inp.owner is None or inp.owner.op not in [ - self.main, - self.inverse, - self.reciprocal, - ]: - if inp.owner and isinstance(inp.owner.op, DimShuffle): - # If input is a DimShuffle of some input which does - # something like this: - - # * change a vector of length N into a 1xN row matrix - # * change a scalar into a 1x1x1 tensor - # * in general, complete the shape of a tensor - # with broadcastable 1s to the *left* - # Then we will simply discard the DimShuffle and return - # the num/denum of its input - dsn = inp.owner # dimshuffle node - dsop = dsn.op # dimshuffle op - - # the first input of the dimshuffle i.e. the ndarray to redim - dsi0 = dsn.inputs[0] - - # The compatible order is a DimShuffle "new_order" of the form: - # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim) - - # That kind of DimShuffle only adds broadcastable - # dimensions on the left, without discarding any - # existing broadcastable dimension and is inserted - # automatically by Elemwise when the inputs have - # different numbers of dimensions (hence why we can - # discard its information - we know we can retrieve it - # later on). - compatible_order = ("x",) * (inp.type.ndim - dsi0.type.ndim) + tuple( - range(dsi0.type.ndim) - ) - if dsop.new_order == compatible_order: - # If the "new_order" is the one we recognize, - # we return the num_denum of the dimshuffled input. - return self.get_num_denum(inp.owner.inputs[0]) - else: - # This is when the input isn't produced by main, - # inverse or reciprocal. - return [inp], [] + parent = inp.owner + if parent is None or parent.op not in self.ops: + if ( + parent is not None + and isinstance(ds_op := parent.op, DimShuffle) + and ds_op.is_left_expand_dims + ): + # If input is a left_expand_dims DimShuffle, + # the kind of which is inserted automatically by Elemwise + # we return the num_denum of the dimshuffled input. + return self.get_num_denum(parent.inputs[0]) else: return [inp], [] + num = [] denum = [] - parent = inp.owner # We get the (num, denum) pairs for each input # pairs = [self.get_num_denum(input2) if input2.type.dtype == # input.type.dtype else ([input2], []) for input2 in # parent.inputs] - pairs = [self.get_num_denum(input2) for input2 in parent.inputs] + get_num_denum = self.get_num_denum + pairs = [get_num_denum(input2) for input2 in parent.inputs] if parent.op == self.main: # If we have main(x, y, ...), numx, denumx, numy, denumy, ... # then num is concat(numx, numy, num...) and denum is # concat(denumx, denumy, denum...) note that main() can have any # number of arguments >= 0 concat is list concatenation - num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs)) - denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs)) + list_concat = list.__iadd__ + num = reduce(list_concat, map(itemgetter(0), pairs)) + denum = reduce(list_concat, map(itemgetter(1), pairs)) elif parent.op == self.inverse: # If we have inverse(x, y), numx, denumx, numy and denumy # then num is concat(numx, denumy) and denum is @@ -1125,8 +1096,7 @@ def get_num_denum(self, inp): # If we have reciprocal(x), numx, denumx # then num is denumx and denum is numx # note that reciprocal() is unary - num = pairs[0][1] - denum = pairs[0][0] + denum, num = pairs[0] return num, denum def merge_num_denum(self, num, denum): @@ -1207,6 +1177,8 @@ def simplify_factors(self, num, denum): """ ln = len(num) ld = len(denum) + if ln == 0 or ld == 0: + return num, denum if ld > 2 and ln > 2: # Faster version for "big" inputs. while True: @@ -1252,15 +1224,21 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): numct, denumct = [], [] for v in orig_num: - if isinstance(v, TensorConstant) and v.unique_value is not None: + if ( + isinstance(v, TensorConstant) + and (unique_val := v.unique_value) is not None + ): # We found a constant in the numerator! # We add it to numct - numct.append(v.unique_value) + numct.append(unique_val) else: num.append(v) for v in orig_denum: - if isinstance(v, TensorConstant) and v.unique_value is not None: - denumct.append(v.unique_value) + if ( + isinstance(v, TensorConstant) + and (unique_val := v.unique_value) is not None + ): + denumct.append(unique_val) else: denum.append(v) @@ -1315,13 +1293,16 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): def transform(self, fgraph, node, enforce_tracks=True): op = node.op - if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}): + if enforce_tracks and (op not in self.ops): return False - assert len(node.outputs) == 1 - out = node.outputs[0] + [out] = node.outputs + clients = fgraph.clients - out_clients = fgraph.clients.get(out) + try: + out_clients = clients[out] + except Exception: + return False if not out_clients: return False @@ -1330,22 +1311,18 @@ def transform(self, fgraph, node, enforce_tracks=True): # this canonized graph... if so, we do nothing and wait for # them to be transformed. for c, c_idx in out_clients: - while ( - isinstance(c.op, DimShuffle) and len(fgraph.clients[c.outputs[0]]) <= 1 - ): - c = fgraph.clients[c.outputs[0]][0][0] - if c.op in [self.main, self.inverse, self.reciprocal]: + while isinstance(c.op, DimShuffle) and len(clients[c.outputs[0]]) <= 1: + [(c, _)] = clients[c.outputs[0]] + if c.op in self.ops: return False # Here we make the canonical version of the graph around this node # See the documentation of get_num_denum and simplify - orig_num, orig_denum = self.get_num_denum(node.outputs[0]) + orig_num, orig_denum = self.get_num_denum(out) num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) def same(x, y): - return len(x) == len(y) and all( - np.all(xe == ye) for xe, ye in zip(x, y, strict=True) - ) + return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y)) if ( same(orig_num, num) @@ -2645,7 +2622,7 @@ def add_calculate(num, denum, aslist=False, out_type=None): else: v = reduce(np.add, num, zero) - reduce(np.add, denum, zero) if aslist: - if np.all(v == 0): + if (v == 0).all(): return [] else: return [v] diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 474d08c49d..0518aeda40 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -985,7 +985,7 @@ def __eq__(self, other): # (note that if there are NaN values in d1, this will return # False, which is why we do not bother with testing `other.has_nan` # here). - return (self.sum == other.sum) and np.all(d0 == d1) + return (self.sum == other.sum) and (d0 == d1).all() def __ne__(self, other): return not self == other @@ -1007,36 +1007,15 @@ def sum(self): # Prevent warnings when there are `inf`s and `-inf`s present with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) - self._sum = self.no_nan.sum() - - # The following 2 lines are needed as in Python 3.3 with NumPy - # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. - if isinstance(self._sum, np.memmap): - self._sum = np.asarray(self._sum).item() - - if self.has_nan and self.no_nan.mask.all(): - # In this case the sum is not properly computed by numpy. - self._sum = 0 - - if np.isinf(self._sum) or np.isnan(self._sum): - # NaN may happen when there are both -inf and +inf values. - if self.has_nan: - # Filter both NaN and Inf values. - mask = self.no_nan.mask + np.isinf(self[1]) - else: - # Filter only Inf values. - mask = np.isinf(self[1]) - if mask.all(): - self._sum = 0 - else: - self._sum = np.ma.masked_array(self[1], mask).sum() - # At this point there should be no more NaN. - assert not np.isnan(self._sum) + self._sum = _sum = self.no_nan.sum() - if isinstance(self._sum, np.ma.core.MaskedConstant): - self._sum = 0 + if not np.isfinite(_sum): + self._sum = _sum = np.nan_to_num( + self[1], nan=0, posinf=0, neginf=0 + ).sum() + assert not np.isnan(_sum) - return self._sum + return _sum @property def no_nan(self): @@ -1044,8 +1023,8 @@ def no_nan(self): return self._no_nan except AttributeError: nans = np.isnan(self[1]) - self._no_nan = np.ma.masked_array(self[1], nans) - self.has_nan = np.any(nans) + self.has_nan = has_nans = nans.any() + self._no_nan = np.ma.masked_array(self[1], nans) if has_nans else self[1] return self._no_nan diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 6f122767bb..766b9a7803 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -12,7 +12,9 @@ from pytensor.compile.io import In, Out from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config -from pytensor.graph.basic import Constant +from pytensor.graph.basic import Constant, explicit_graph_inputs +from pytensor.graph.replace import graph_replace +from pytensor.graph.rewriting import rewrite_graph from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter from pytensor.graph.utils import MissingInputError from pytensor.link.vm import VMLinker @@ -1359,6 +1361,142 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark): benchmark(f, rng_val) +@pytest.fixture(scope="module") +def radon_model(): + def halfnormal(name, *, sigma=1.0, model_logp): + log_value = pt.scalar(f"{name}_log") + value = pt.exp(log_value) + + logp = ( + -0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma) + ) + logp = pt.switch(value >= 0, logp, -np.inf) + model_logp.append(logp + value) + return value + + def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None): + value = pt.scalar(name) if observed is None else pt.as_tensor(observed) + + logp = ( + -0.5 * (((value - mu) / sigma) ** 2) + - pt.log(pt.sqrt(2.0 * np.pi)) + - pt.log(sigma) + ) + model_logp.append(logp) + return value + + def zerosumnormal(name, *, sigma=1.0, size, model_logp): + raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,)) + n = raw_value.shape[0] + 1 + sum_vals = raw_value.sum(0, keepdims=True) + norm = sum_vals / (pt.sqrt(n) + n) + fill_value = norm - sum_vals / pt.sqrt(n) + value = pt.concatenate([raw_value, fill_value]) - norm + + shape = value.shape + _full_size = pt.prod(shape) + _degrees_of_freedom = pt.prod(shape[-1:].inc(-1)) + logp = pt.sum( + -0.5 * ((value / sigma) ** 2) + - (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma)) + * (_degrees_of_freedom / _full_size) + ) + model_logp.append(logp) + return value + + rng = np.random.default_rng(1) + n_counties = 85 + county_idx = rng.integers(n_counties, size=919) + county_idx.sort() + floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64) + log_radon = rng.normal(size=919) + + model_logp = [] + intercept = normal("intercept", sigma=10, model_logp=model_logp) + + # County effects + county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp) + county_sd = halfnormal("county_sd", model_logp=model_logp) + county_effect = county_raw * county_sd + + # Global floor effect + floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp) + + county_floor_raw = zerosumnormal( + "county_floor_raw", size=n_counties, model_logp=model_logp + ) + county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp) + county_floor_effect = county_floor_raw * county_floor_sd + + mu = ( + intercept + + county_effect[county_idx] + + floor_effect * floor + + county_floor_effect[county_idx] * floor + ) + + sigma = halfnormal("sigma", model_logp=model_logp) + _ = normal( + "log_radon", + mu=mu, + sigma=sigma, + observed=log_radon, + model_logp=model_logp, + ) + + model_logp = pt.sum([logp.sum() for logp in model_logp]) + model_logp = rewrite_graph( + model_logp, include=("canonicalize", "stabilize"), clone=False + ) + params = list(explicit_graph_inputs(model_logp)) + model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)]) + + size = sum(int(np.prod(p.type.shape)) for p in params) + joined_inputs = pt.vector("joined_inputs", shape=(size,)) + idx = 0 + replacement = {} + for param in params: + param_shape = param.type.shape + param_size = int(np.prod(param_shape)) + replacement[param] = joined_inputs[idx : idx + param_size].reshape(param_shape) + idx += param_size + assert idx == joined_inputs.type.shape[0] + + model_logp, model_dlogp = graph_replace([model_logp, model_dlogp], replacement) + return joined_inputs, [model_logp, model_dlogp] + + +@pytest.mark.parametrize("mode", ["FAST_COMPILE", "FAST_RUN", "NUMBA"]) +def test_radon_model_compile_benchmark(mode, radon_model, benchmark): + joined_inputs, [model_logp, model_dlogp] = radon_model + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + + def compile_and_call_once(): + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True + ) + fn(x) + + benchmark(compile_and_call_once) + + +@pytest.mark.parametrize("mode", ["C", "C_VM", "C_VM_NOGC", "NUMBA", "NUMBA_VM"]) +def test_radon_model_call_benchmark(mode, radon_model, benchmark): + joined_inputs, [model_logp, model_dlogp] = radon_model + + real_mode = "C_VM" if mode == "C_VM_NOGC" else mode + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=real_mode, trust_input=True + ) + if mode == "C_VM_NOGC": + fn.vm.allow_gc = False + + rng = np.random.default_rng(1) + x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) + benchmark(fn, x) + + @pytest.mark.parametrize("mode", ["FAST_COMPILE", "FAST_RUN"]) @pytest.mark.parametrize("depth", [2, 20]) def test_function_compilation_benchmark(mode, depth, benchmark): From 2c6bff5959c7ce73225bb05bf6fd939c37867d25 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Oct 2025 19:02:11 +0200 Subject: [PATCH 30/33] .speedup stuff --- pytensor/graph/utils.py | 16 ++++------- pytensor/tensor/basic.py | 4 +-- pytensor/tensor/rewriting/elemwise.py | 12 ++++----- pytensor/tensor/rewriting/math.py | 38 +++++++++++++-------------- pytensor/tensor/type.py | 9 ++++--- pytensor/tensor/variable.py | 5 ++-- 6 files changed, 40 insertions(+), 44 deletions(-) diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index 42ebbcd216..9f18fbd220 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -4,7 +4,7 @@ from abc import ABCMeta from collections.abc import Sequence from io import StringIO -from typing import TYPE_CHECKING, Any, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar, Union if TYPE_CHECKING: @@ -227,9 +227,10 @@ def __hash__(self): if "__eq__" not in dct: def __eq__(self, other): - return type(self) is type(other) and tuple( - getattr(self, a) for a in props - ) == tuple(getattr(other, a) for a in props) + return self is other or ( + type(self) is type(other) + and all(getattr(self, a) == getattr(other, a) for a in props) + ) dct["__eq__"] = __eq__ @@ -278,13 +279,6 @@ def info(self): for k, v in self.__dict__.items(): print(f" {k}: {v}") # noqa: T201 - # These two methods have been added to help Mypy - def __getattribute__(self, name): - return super().__getattribute__(name) - - def __setattr__(self, name: str, value: Any) -> None: - self.__dict__[name] = value - class ValidatingScratchpad(Scratchpad): """This `Scratchpad` validates attribute values.""" diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index bf9638c473..e669c15108 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -35,7 +35,7 @@ from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise from pytensor.scalar import int32 -from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable +from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable, convert from pytensor.tensor import ( _as_tensor_variable, _get_vector_length, @@ -219,7 +219,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: else: x = x.data - x_ = ps.convert(x, dtype=dtype) + x_ = convert(x, dtype=dtype) if ndim is not None: if x_.ndim < ndim: diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index cee496c384..871eccc2c7 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -312,11 +312,11 @@ def apply_local_dimshuffle_lift(fgraph, var): """ lift recursively """ - if var.owner is None: - return var - new = local_dimshuffle_lift.transform(fgraph, var.owner) - if new: - return new[0] + if (node := var.owner) is not None and isinstance(node.op, DimShuffle): + # Sidestep indirection in local_dimshuffle_lift.apply + new = local_dimshuffle_lift.fn(fgraph, node) + if new: + return new[0] return var @@ -1041,7 +1041,7 @@ def local_inline_composite_constants(fgraph, node): new_inner_outs = clone_replace( composite_op.fgraph.outputs, replace=inner_replacements ) - new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs) + new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs, clone_graph=False) new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs # Some of the inlined constants were broadcasting the output shape diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a63b58261c..e47a43c9e7 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1068,15 +1068,11 @@ def get_num_denum(self, inp): else: return [inp], [] - num = [] - denum = [] - # We get the (num, denum) pairs for each input # pairs = [self.get_num_denum(input2) if input2.type.dtype == # input.type.dtype else ([input2], []) for input2 in # parent.inputs] - get_num_denum = self.get_num_denum - pairs = [get_num_denum(input2) for input2 in parent.inputs] + pairs = [self.get_num_denum(input2) for input2 in parent.inputs] if parent.op == self.main: # If we have main(x, y, ...), numx, denumx, numy, denumy, ... @@ -1092,7 +1088,7 @@ def get_num_denum(self, inp): # concat(denumx, numy) note that inverse() is binary num = pairs[0][0] + pairs[1][1] denum = pairs[0][1] + pairs[1][0] - elif parent.op == self.reciprocal: + else: # parent.op == self.reciprocal: # If we have reciprocal(x), numx, denumx # then num is denumx and denum is numx # note that reciprocal() is unary @@ -2408,20 +2404,22 @@ def check_for_x_over_absX(numerators, denominators): """Convert x/abs(x) into sign(x).""" # TODO: this function should dig/search through dimshuffles # This won't catch a dimshuffled absolute value - for den in list(denominators): - if den.owner and den.owner.op == pt_abs and den.owner.inputs[0] in numerators: - if den.owner.inputs[0].type.dtype.startswith("complex"): - # TODO: Make an Op that projects a complex number to - # have unit length but projects 0 to 0. That - # would be a weird Op, but consistent with the - # special case below. I heard there's some - # convention in Matlab that is similar to - # this... but not sure. - pass - else: - denominators.remove(den) - numerators.remove(den.owner.inputs[0]) - numerators.append(sign(den.owner.inputs[0])) + if not numerators or not denominators: + return numerators, denominators + + original_denominators = denominators + for den in original_denominators: + if ( + (den_node := den.owner) is not None + and den_node.op == pt_abs + and (num_index := numerators.index(num := den_node.inputs[0])) >= 0 + and not num.type.dtype.startswith("complex") + ): + if denominators is original_denominators: + denominators = denominators.copy() + denominators.remove(den) + numerators.pop(num_index) + numerators.append(sign(num)) return numerators, denominators diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 6f29269ac7..d918e9417a 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -5,6 +5,7 @@ import numpy as np import numpy.typing as npt +from numpy import dtype as np_dtype import pytensor from pytensor import scalar as ps @@ -36,7 +37,7 @@ uint_dtypes = list(map(str, ps.uint_types)) _all_dtypes_str: dict[str, str] = {d: d for d in all_dtypes} -_str_to_numpy_dtype: dict[str, np.dtype] = {} +_str_to_numpy_dtype: dict[str, np_dtype] = {} # TODO: add more type correspondences for e.g. int32, int64, float32, # complex64, etc. @@ -108,16 +109,18 @@ def __init__( elif dtype not in _all_dtypes_str: # Check if dtype is a valid numpy dtype try: - dtype = str(np.dtype(dtype)) + dtype = np_dtype(dtype).name except TypeError as exc: raise TypeError( f"Unsupported dtype for TensorType: {dtype}" ) from exc else: _all_dtypes_str[dtype] = dtype + elif isinstance(dtype, np_dtype): + dtype = dtype.name else: try: - dtype = str(np.dtype(dtype)) + dtype = np.dtype(dtype).name except TypeError as exc: raise TypeError(f"Unsupported dtype for TensorType: {dtype}") from exc diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 0518aeda40..2ca02ef7be 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -348,10 +348,11 @@ def dimshuffle(self, *pattern): """ if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)): pattern = pattern[0] - ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern) - if ds_op.new_order == tuple(range(self.type.ndim)): + pattern = tuple(pattern) + if pattern == tuple(range(self.type.ndim)): # No-op return self + ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern) return ds_op(self) def flatten(self, ndim=1): From ed5cec271973216037a61504f5b79bb436342bf9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 4 Oct 2025 11:12:49 +0200 Subject: [PATCH 31/33] Reapply "Use non recursive algorithm in `rebuild_collect_shared`" This reverts commit 330a1e03273ad19cae6c22baf9b0a97b23397b55. --- pytensor/compile/function/pfunc.py | 87 ++++++++++++++++-------------- 1 file changed, 47 insertions(+), 40 deletions(-) diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 91d6e1a588..2056d077be 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -179,47 +179,54 @@ def clone_v_get_shared_updates(v, copy_inputs_over): """ # this co-recurses with clone_a - assert v is not None - if v in clone_d: - return clone_d[v] - if v.owner: - owner = v.owner - if owner not in clone_d: - for i in owner.inputs: - clone_v_get_shared_updates(i, copy_inputs_over) - clone_node_and_cache( - owner, - clone_d, - strict=rebuild_strict, - clone_inner_graphs=clone_inner_graphs, - ) - return clone_d.setdefault(v, v) - elif isinstance(v, SharedVariable): - if v not in shared_inputs: - shared_inputs.append(v) - if v.default_update is not None: - # Check that v should not be excluded from the default - # updates list - if no_default_updates is False or ( - isinstance(no_default_updates, list) and v not in no_default_updates - ): - # Do not use default_update if a "real" update was - # provided - if v not in update_d: - v_update = v.type.filter_variable( - v.default_update, allow_convert=False + stack = [v] + try: + while True: + v = stack.pop() + if v in clone_d: + continue + if (apply := v.owner) is not None: + if all(i in clone_d for i in apply.inputs): + # all inputs have been cloned, we can clone this node + clone_node_and_cache( + apply, + clone_d, + strict=rebuild_strict, + clone_inner_graphs=clone_inner_graphs, ) - if not v.type.is_super(v_update.type): - raise TypeError( - "An update must have a type compatible with " - "the original shared variable" - ) - update_d[v] = v_update - update_expr.append((v, v_update)) - if not copy_inputs_over: - return clone_d.setdefault(v, v.clone()) - else: - return clone_d.setdefault(v, v) + else: + # expand on the inputs + stack.extend(apply.inputs) + else: + clone_d[v] = v if copy_inputs_over else v.clone() + + # Special handling of SharedVariables + if isinstance(v, SharedVariable): + if v not in shared_inputs: + shared_inputs.append(v) + if v.default_update is not None: + # Check that v should not be excluded from the default + # updates list + if no_default_updates is False or ( + isinstance(no_default_updates, list) + and v not in no_default_updates + ): + # Do not use default_update if a "real" update was + # provided + if v not in update_d: + v_update = v.type.filter_variable( + v.default_update, allow_convert=False + ) + if not v.type.is_super(v_update.type): + raise TypeError( + "An update must have a type compatible with " + "the original shared variable" + ) + update_d[v] = v_update + update_expr.append((v, v_update)) + except IndexError: + pass # stack is empty + return clone_d[v] # initialize the clone_d mapping with the replace dictionary if replace is None: From 2ea91fdc094aeb86fe6327a7acc39d2f721b4953 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 4 Oct 2025 13:49:30 +0200 Subject: [PATCH 32/33] .speedup stuff --- pytensor/compile/function/pfunc.py | 5 +- pytensor/graph/basic.py | 4 +- pytensor/scalar/basic.py | 59 +++++-- pytensor/tensor/rewriting/math.py | 228 +++++++++++++-------------- pytensor/tensor/rewriting/shape.py | 112 +++++++------ pytensor/tensor/rewriting/special.py | 8 +- tests/compile/function/test_types.py | 2 +- 7 files changed, 238 insertions(+), 180 deletions(-) diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 2056d077be..574a86b1b4 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -182,12 +182,14 @@ def clone_v_get_shared_updates(v, copy_inputs_over): stack = [v] try: while True: - v = stack.pop() + v = stack[-1] if v in clone_d: + stack.pop() continue if (apply := v.owner) is not None: if all(i in clone_d for i in apply.inputs): # all inputs have been cloned, we can clone this node + stack.pop() clone_node_and_cache( apply, clone_d, @@ -198,6 +200,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over): # expand on the inputs stack.extend(apply.inputs) else: + stack.pop() clone_d[v] = v if copy_inputs_over else v.clone() # Special handling of SharedVariables diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index bb6334aed8..30ee1f97ce 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -701,7 +701,9 @@ def equals(self, other): This does what `__eq__` would normally do, but `Variable` and `Apply` should always be hashable by `id`. """ - return isinstance(other, type(self)) and self.signature() == other.signature() + return self is other or ( + isinstance(other, type(self)) and self.signature() == other.signature() + ) @property def owner(self): diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 52f8bb05dd..a29eb73f9e 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -282,7 +282,7 @@ def convert(x, dtype=None): x_ = np.asarray(x) if x_.size == 0 and not hasattr(x, "dtype"): x_ = np.asarray(x, dtype=config.floatX) - assert issubclass(type(x_), np.ndarray | np.memmap) + # assert issubclass(type(x_), np.ndarray | np.memmap) return x_ @@ -834,6 +834,7 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: all_types: _ScalarTypes = discrete_types + continuous_types float_dtypes = tuple(t.dtype for t in float_types) +complex_dtypes = tuple(t.dtype for t in complex_types) discrete_dtypes = tuple(t.dtype for t in discrete_types) @@ -1729,6 +1730,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {cond} ? {ift} : {iff};" + def supports_c_code(self, inputs, outputs): + return True + def L_op(self, inputs, outputs, gout): (cond, ift, iff) = inputs (gz,) = gout @@ -1969,13 +1973,14 @@ def impl(self, *inputs): def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs - op = " + " - if node.outputs[0].type == bool: - op = " || " if not inputs: - return z + " = 0;" + return f"{z} = 0;" else: - return z + " = " + op.join(inputs) + ";" + op = " || " if (node.outputs[0].type == bool) else " + " + return f"{z} = {op.join(inputs)};" + + def supports_c_code(self, inputs, outputs): + return True # Always supports c code def L_op(self, inputs, outputs, gout): (gz,) = gout @@ -2011,13 +2016,14 @@ def impl(self, *inputs): def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs - op = " * " - if node.outputs[0].type == bool: - op = " && " if not inputs: - return z + " = 1;" + return f"{z} = 1;" else: - return z + " = " + op.join(inputs) + ";" + op = " && " if (node.outputs[0].type == bool) else " * " + return f"{z} = {op.join(inputs)};" + + def supports_c_code(self, inputs, outputs): + return True # Always supports c code def grad(self, inputs, gout): (gz,) = gout @@ -2072,6 +2078,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {x} - {y};" + def supports_c_code(self, inputs, outputs): + return True # Always supports c code + def L_op(self, inputs, outputs, gout): (x, y) = inputs (gz,) = gout @@ -2122,6 +2131,10 @@ def c_code(self, node, name, inputs, outputs, sub): return f"{z} = ((double){x}) / {y};" return f"{z} = {x} / {y};" + def supports_c_code(self, inputs, outputs): + [x, y] = inputs + return x.type.dtype not in complex_dtypes and y.type.dtype not in complex_dtypes + def grad(self, inputs, gout): (x, y) = inputs (gz,) = gout @@ -2387,6 +2400,10 @@ def c_code(self, node, name, inputs, outputs, sub): raise NotImplementedError("type not supported", type) return f"{z} = pow({x}, {y});" + def supports_c_code(self, inputs, outputs): + [x, y] = inputs + return x.type.dtype not in complex_dtypes and y.type.dtype not in complex_dtypes + def L_op(self, inputs, outputs, gout): (x, y) = inputs (gz,) = gout @@ -2507,6 +2524,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {y};" + def supports_c_code(self, inputs, outputs): + return True # Always supports c code + def connection_pattern(self, node): # x is never connected because its elements are never used # y is connected because its elements are copied over @@ -2975,6 +2995,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = -{x};" + def supports_c_code(self, inputs, outputs): + return True # Always supports c code + neg = Neg(same_out_nobool, name="neg") @@ -3062,6 +3085,9 @@ def c_code(self, node, name, inputs, outputs, sub): cast = node.outputs[0].type.dtype_specs()[1] return f"{z} = log(({cast}){x});" + def supports_c_code(self, inputs, outputs): + return inputs[0].type.dtype not in complex_dtypes + log = Log(upgrade_to_float, name="log") @@ -3227,6 +3253,9 @@ def c_code(self, node, name, inputs, outputs, sub): cast = node.outputs[0].type.dtype_specs()[1] return f"{z} = exp(({cast}){x});" + def supports_c_code(self, inputs, outputs): + return inputs[0].type.dtype not in complex_dtypes + exp = Exp(upgrade_to_float, name="exp") @@ -3330,6 +3359,9 @@ def c_code(self, node, name, inputs, outputs, sub): (z,) = outputs return f"{z} = {x} * {x};" + def supports_c_code(self, inputs, outputs): + return True # Always supports c code + sqr = Sqr(same_out, name="sqr") @@ -3366,6 +3398,9 @@ def c_code(self, node, name, inputs, outputs, sub): cast = node.outputs[0].type.dtype_specs()[1] return f"{z} = sqrt(({cast}){x});" + def supports_c_code(self, inputs, outputs): + return inputs[0].type.dtype not in complex_types + sqrt = Sqrt(upgrade_to_float, name="sqrt") @@ -4415,7 +4450,7 @@ def fgraph(self): return self._fgraph # fgraph cannot be a property of the base class because it messes up with C caching. # We also need a `FunctionGraph(clone=True)` (default) according to an old comment - fgraph = FunctionGraph(self.inputs, self.outputs) + fgraph = FunctionGraph(self.inputs, self.outputs, clone=False) self._fgraph = fgraph return self._fgraph diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e47a43c9e7..bb5c9d0c03 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -19,6 +19,7 @@ node_rewriter, ) from pytensor.graph.rewriting.utils import get_clients_at_depth +from pytensor.scalar import Abs from pytensor.tensor.basic import ( Alloc, Join, @@ -1014,10 +1015,13 @@ def __init__(self, main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=Tr self.main = main self.inverse = inverse_fn self.reciprocal = reciprocal_fn - self.ops = (self.main, self.inverse, self.reciprocal) + self.scalar_ops = ( + main.scalar_op, + inverse_fn.scalar_op, + reciprocal_fn.scalar_op, + ) self.calculate = calculate self.use_reciprocal = use_reciprocal - self.external_simplifiers = [] def add_simplifier(self, simplifier, reason): @@ -1054,46 +1058,47 @@ def get_num_denum(self, inp): # internal data nodes all have the dtype of the 'input' # argument. The leaf-Variables of the graph covered by the # recursion may be of any Variable type. - parent = inp.owner - if parent is None or parent.op not in self.ops: - if ( - parent is not None - and isinstance(ds_op := parent.op, DimShuffle) - and ds_op.is_left_expand_dims - ): + if (parent := inp.owner) is not None: + if isinstance(ds_op := parent.op, DimShuffle) and ds_op.is_left_expand_dims: # If input is a left_expand_dims DimShuffle, # the kind of which is inserted automatically by Elemwise # we return the num_denum of the dimshuffled input. return self.get_num_denum(parent.inputs[0]) - else: - return [inp], [] - - # We get the (num, denum) pairs for each input - # pairs = [self.get_num_denum(input2) if input2.type.dtype == - # input.type.dtype else ([input2], []) for input2 in - # parent.inputs] - pairs = [self.get_num_denum(input2) for input2 in parent.inputs] - - if parent.op == self.main: - # If we have main(x, y, ...), numx, denumx, numy, denumy, ... - # then num is concat(numx, numy, num...) and denum is - # concat(denumx, denumy, denum...) note that main() can have any - # number of arguments >= 0 concat is list concatenation - list_concat = list.__iadd__ - num = reduce(list_concat, map(itemgetter(0), pairs)) - denum = reduce(list_concat, map(itemgetter(1), pairs)) - elif parent.op == self.inverse: - # If we have inverse(x, y), numx, denumx, numy and denumy - # then num is concat(numx, denumy) and denum is - # concat(denumx, numy) note that inverse() is binary - num = pairs[0][0] + pairs[1][1] - denum = pairs[0][1] + pairs[1][0] - else: # parent.op == self.reciprocal: - # If we have reciprocal(x), numx, denumx - # then num is denumx and denum is numx - # note that reciprocal() is unary - denum, num = pairs[0] - return num, denum + + if isinstance(parent.op, Elemwise): + try: + kind = self.scalar_ops.index(parent.op.scalar_op) + except ValueError: + pass + else: + # We get the (num, denum) pairs for each input + # pairs = [self.get_num_denum(input2) if input2.type.dtype == + # input.type.dtype else ([input2], []) for input2 in + # parent.inputs] + pairs = [self.get_num_denum(input2) for input2 in parent.inputs] + + if kind == 0: + # If we have main(x, y, ...), numx, denumx, numy, denumy, ... + # then num is concat(numx, numy, num...) and denum is + # concat(denumx, denumy, denum...) note that main() can have any + # number of arguments >= 0 concat is list concatenation + list_concat = list.__iadd__ + num = reduce(list_concat, map(itemgetter(0), pairs)) + denum = reduce(list_concat, map(itemgetter(1), pairs)) + elif kind == 1: + # If we have inverse(x, y), numx, denumx, numy and denumy + # then num is concat(numx, denumy) and denum is + # concat(denumx, numy) note that inverse() is binary + num = pairs[0][0] + pairs[1][1] + denum = pairs[0][1] + pairs[1][0] + else: # parent.op == self.reciprocal: + # If we have reciprocal(x), numx, denumx + # then num is denumx and denum is numx + # note that reciprocal() is unary + denum, num = pairs[0] + return num, denum + + return [inp], [] # fall back case def merge_num_denum(self, num, denum): r""" @@ -1253,43 +1258,45 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): # does it for us? ct = [self.calculate(numct, denumct, aslist=False, out_type=out_type)] - # Wrapping ct in a Constant with the right dtype - ct = [constant(c, dtype=out_type.dtype) for c in ct] + if not ct: + # ct is empty if the constant is the neutral element + # (e.g. 1 for multiplication, 0 for addition) + # In that case we just return num and denum + return num, denum - if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: - # In that case we should only have one constant in `ct`. - [var_ct] = ct + [c] = ct + if orig_num and (not denumct) and len(numct) == 1: + # Check for useless simplification + # If it so happens that: + # * there's exactly one constant on the numerator and none on the denominator + # * it's not the neutral element (ct is an empty list in that case) + # * the constant is the same as the first argument in the + # numerator (we only check the first argument because the + # canonizer puts the computed constants first) + # -> then we return the original num/denum. + # If we don't do that the rewrite will just loop + # infinitely replacing something by the same thing... + # Note that it is important to use `values_eq` instead of + # the == operator, to handle NaN values correctly. first_num_var = orig_num[0] first_num_ct = ( first_num_var.unique_value if isinstance(first_num_var, TensorConstant) else None ) - if first_num_ct is not None and var_ct.type.values_eq( - var_ct.data, first_num_ct + if first_num_ct is not None and first_num_var.type.values_eq( + c, first_num_ct ): - # This is an important trick :( if it so happens that: - # * there's exactly one constant on the numerator and none on - # the denominator - # * it's not the neutral element (ct is an empty list in that - # case) - # * the constant is the same as the first argument in the - # numerator (we only check the first argument because the - # canonizer puts the computed constants first) - # -> then we return very exactly the original num/denum. - # If we don't do that the rewrite will just loop - # infinitely because it will not catch on that there are - # no changes to be made and every time it will want to - # replace something by the same thing... - # Note that it is important to use `values_eq` instead of - # the == operator, to handle NaN values correctly. return orig_num, orig_denum - return ct + num, denum + # Convert c back to a Constant with the right dtype and append remaining numerator + return [constant(c, dtype=out_type.dtype), *num], denum def transform(self, fgraph, node, enforce_tracks=True): op = node.op - if enforce_tracks and (op not in self.ops): + if enforce_tracks and not ( + isinstance(op, Elemwise) and op.scalar_op in self.scalar_ops + ): return False [out] = node.outputs @@ -1309,7 +1316,7 @@ def transform(self, fgraph, node, enforce_tracks=True): for c, c_idx in out_clients: while isinstance(c.op, DimShuffle) and len(clients[c.outputs[0]]) <= 1: [(c, _)] = clients[c.outputs[0]] - if c.op in self.ops: + if isinstance(c.op, Elemwise) and c.op.scalar_op in self.scalar_ops: return False # Here we make the canonical version of the graph around this node @@ -1317,25 +1324,25 @@ def transform(self, fgraph, node, enforce_tracks=True): orig_num, orig_denum = self.get_num_denum(out) num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) - def same(x, y): - return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y)) - if ( - same(orig_num, num) - and same(orig_denum, denum) - and + ( + len(orig_num) == len(num) + and all(xe is ye for xe, ye in zip(orig_num, num)) + ) + and ( + len(orig_denum) == len(denum) + and all(xe is ye for xe, ye in zip(orig_denum, denum)) + ) # Check to see if we've collapsed some nested ops. - not ( + and not ( len(orig_denum) == 0 - and # Make sure this change would increase the number of vector # arguments--decreasing the number of unnecessary `self.main` # nodes. - len(node.inputs) < len(orig_num) + and len(node.inputs) < len(orig_num) ) - and # Do a similar check for the reciprocal op. - not ( + and not ( self.use_reciprocal and node.op == self.reciprocal and len(orig_num) == 0 @@ -1372,10 +1379,7 @@ def __str__(self): def mul_calculate(num, denum, aslist=False, out_type=None): if not num and not denum: # Smallest 1 possible. - if aslist: - return [] - else: - return np.int8(1) + return [] if aslist else np.int8(1) # Make sure we do not accidentally upcast data types. if out_type is None: @@ -1384,12 +1388,11 @@ def mul_calculate(num, denum, aslist=False, out_type=None): out_dtype = out_type.dtype one = np.asarray(1, dtype=out_dtype) - v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one) + v = reduce(np.multiply, num, one) if num else one + if denum: + v /= reduce(np.multiply, denum, one) if aslist: - if np.all(v == 1): - return [] - else: - return [v] + return [] if (v == 1).all() else [v] return v @@ -2060,8 +2063,6 @@ def local_mul_zero(fgraph, node): with zero. """ - otype = node.outputs[0].type - for i in node.inputs: try: value = get_underlying_scalar_constant_value(i) @@ -2070,6 +2071,7 @@ def local_mul_zero(fgraph, node): # print 'MUL by value', value, node.inputs if value == 0: # print '... returning zeros' + otype = node.outputs[0].type return [broadcast_arrays(np.asarray(0, dtype=otype.dtype), *node.inputs)[0]] @@ -2358,38 +2360,26 @@ def local_mul_specialize(fgraph, node): @register_specialize @node_rewriter([add]) def local_add_remove_zeros(fgraph, node): - new_inputs = [] - for inp in node.inputs: - try: - y = get_underlying_scalar_constant_value(inp) - except NotScalarConstantError: - y = inp - if y == 0.0: - continue - new_inputs.append(inp) + is_zeros = [ + get_underlying_scalar_constant_value(inp, raise_not_constant=False) == 0.0 + for inp in node.inputs + ] - if len(new_inputs) == len(node.inputs): - return False + if not any(is_zeros): + return None + new_inputs = [inp for inp, is_zero in zip(node.inputs, is_zeros) if not is_zero] node_output = node.outputs[0] dtype = node_output.type.dtype - if len(new_inputs) == 0: - # we got rid of the entire expression! - ndim = node_output.type.ndim - # Reuse call to constant for cache() - cst = constant(np.zeros((1,) * ndim, dtype=dtype)) - assert cst.type.broadcastable == (True,) * ndim - return [alloc_like(cst, node_output, fgraph)] - - ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)] + ret = alloc_like(variadic_add(*new_inputs), node_output, fgraph) # The dtype should not be changed. It can happen if the input # that was forcing upcasting was equal to 0. - if ret[0].dtype != dtype: - ret = [cast(ret[0], dtype)] + if ret.type.dtype != dtype: + ret = cast(ret, dtype) - return ret + return [ret] mul_canonizer = in2out( @@ -2411,7 +2401,8 @@ def check_for_x_over_absX(numerators, denominators): for den in original_denominators: if ( (den_node := den.owner) is not None - and den_node.op == pt_abs + and isinstance(den_node.op, Elemwise) + and isinstance(den_node.op.scalar_op, Abs) and (num_index := numerators.index(num := den_node.inputs[0])) >= 0 and not num.type.dtype.startswith("complex") ): @@ -2602,13 +2593,21 @@ def local_log_sum_exp(fgraph, node): def add_calculate(num, denum, aslist=False, out_type=None): # TODO: make sure that this function and mul_calculate are similar + if not num and not denum: + return ( + [] + if aslist + else (0.0 if out_type is None else np.asarray(0, dtype=out_type.dtype)) + ) + if out_type is None: zero = 0.0 else: zero = np.asarray(0, dtype=out_type.dtype) + # zero = 0.0 if out_type is None else np.asarray(0, # dtype=out_type.dtype) - if out_type and out_type.dtype == "bool": + if out_type is not None and out_type.dtype == "bool": if len(denum) == 0: # NumPy 1.14 do not accept to do "bool - bool" v = reduce(np.add, num, zero) @@ -2618,12 +2617,11 @@ def add_calculate(num, denum, aslist=False, out_type=None): " an earlier error should have been raised" ) else: - v = reduce(np.add, num, zero) - reduce(np.add, denum, zero) + v = reduce(np.add, num, zero) if num else zero + if denum: + v -= reduce(np.add, denum, zero) if aslist: - if (v == 0).all(): - return [] - else: - return [v] + return [] if (v == 0).all() else [v] return v diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index f784954dc9..dd65eb8b2f 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -51,6 +51,17 @@ from pytensor.tensor.variable import TensorVariable +int_constant_cache = {} + + +def int_constant(i): + try: + return int_constant_cache[i] + except KeyError: + int_constant_cache[i] = c = constant(i, dtype="int64") + return c + + class ShapeFeature(Feature): r"""A `Feature` that tracks shape information in a graph. @@ -209,7 +220,7 @@ def get_shape(self, var, idx): def shape_ir(self, i, r): """Return symbolic r.shape[i] for tensor variable r, int i.""" if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") + return int_constant(r.type.shape[i]) else: # Do not call make_node for test_value s = Shape_i(i)(r) @@ -271,7 +282,7 @@ def unpack(self, s_i, var): # choose that options as it would give better error # message. raise AssertionError(msg) - return constant(s_i, dtype="int64") + return int_constant(s_i) if isinstance(s_i, tuple | list): # this dimension is the same as many of the inputs # which tells us that if one of the inputs is known, @@ -336,7 +347,10 @@ def set_shape(self, r, s, override=False): if not isinstance(s, tuple | list): raise TypeError("shapes must be tuple/list", (r, s)) - if r.type.ndim != len(s): + r_type = r.type + r_ndim = r_type.ndim + + if r_ndim != len(s): sio = StringIO() pytensor.printing.debugprint(r, file=sio, print_type=True) raise AssertionError( @@ -345,24 +359,22 @@ def set_shape(self, r, s, override=False): f" for the variable:\n{sio.getvalue()}" ) + shape_of_reverse_index_setdefault = self.shape_of_reverse_index.setdefault shape_vars = [] - for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) - else: - shape_vars.append(self.unpack(s[i], r)) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals( - get_scalar_constant_value(shape_vars[i], raise_not_constant=False) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(shape_vars) - for sv in shape_vars: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) + r_static_shape = getattr(r_type, "shape", None) + if r_static_shape is None: + self.shape_of[r] = svs = tuple(self.unpack(s_i, r) for s_i in s) + for sv in svs: + shape_of_reverse_index_setdefault(sv, set()).add(r) + else: + for i, (r_st, s_i) in enumerate(zip(r_static_shape, s)): + if r_st is not None: + shape_var = int_constant(r_st) + else: + shape_var = self.unpack(s_i, r) + shape_of_reverse_index_setdefault(shape_var, set()).add(r) + shape_vars.append(shape_var) + self.shape_of[r] = tuple(shape_vars) def update_shape(self, r, other_r): """Replace shape of r by shape of other_r. @@ -372,19 +384,20 @@ def update_shape(self, r, other_r): """ # other_r should already have a shape - assert other_r in self.shape_of, ("other_r not in shape_of", other_r) - other_shape = self.shape_of[other_r] + shape_of = self.shape_of + other_shape = shape_of[other_r] # If other_shape has no information, call is pointless. if other_shape is None: return - if r in self.shape_of: - r_shape = self.shape_of[r] - else: + try: + r_shape = shape_of[r] + except KeyError: # If no info is known on r's shape, use other_shape self.set_shape(r, other_shape) return + if ( other_r.owner and r.owner @@ -426,11 +439,11 @@ def update_shape(self, r, other_r): merged_shape.append(r_shape[i]) elif any( ( - r_shape[i] == anc + r_shape[i] is anc or ( anc.owner and isinstance(anc.owner.op, Shape) - and anc.owner.inputs[0] == r + and anc.owner.inputs[0] is r ) ) for anc in ancestors([other_shape[i]]) @@ -445,25 +458,26 @@ def update_shape(self, r, other_r): merged_shape.append(r_shape[i]) else: merged_shape.append(other_shape[i]) - assert all( - ( - not hasattr(r.type, "shape") - or r.type.shape[i] != 1 - and other_r.type.shape[i] != 1 - ) - or self.lscalar_one.equals(merged_shape[i]) - or self.lscalar_one.equals( - get_scalar_constant_value( - merged_shape[i], - only_process_constants=True, - raise_not_constant=False, - ) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(merged_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) + # assert all( + # ( + # not hasattr(r.type, "shape") + # or r.type.shape[i] != 1 + # and other_r.type.shape[i] != 1 + # ) + # or self.lscalar_one.equals(merged_shape[i]) + # or self.lscalar_one.equals( + # get_scalar_constant_value( + # merged_shape[i], + # only_process_constants=True, + # raise_not_constant=False, + # ) + # ) + # for i in range(r.type.ndim) + # ) + shape_of[r] = tuple(merged_shape) + shape_of_reverse_index_setdefault = self.shape_of_reverse_index.setdefault + for sv in shape_of[r]: + shape_of_reverse_index_setdefault(sv, set()).add(r) def set_shape_i(self, r, i, s_i): """Replace element i of shape_of[r] by s_i""" @@ -510,7 +524,7 @@ def on_attach(self, fgraph): fgraph.shape_feature = self # Must be local to the object as otherwise we reuse the same # variable for multiple fgraph! - self.lscalar_one = constant(1, dtype="int64") + self.lscalar_one = int_constant(1) assert self.lscalar_one.type.dtype == "int64" self.fgraph = fgraph @@ -576,7 +590,7 @@ def on_import(self, fgraph, node, reason): assert str(d.dtype) != "uint64", node new_shape += sh[len(new_shape) : i + 1] if isinstance(d, Constant): - casted_d = constant(d.data, dtype="int64") + casted_d = int_constant(d.data) else: casted_d = cast(d, "int64") new_shape[i] = casted_d @@ -1159,7 +1173,7 @@ def local_shape_ground(fgraph, node): if len(static_shape) == 0: return [_empty_shape] if not any(dim is None for dim in static_shape): - return [stack([constant(dim, dtype="int64") for dim in static_shape])] + return [stack([int_constant(dim) for dim in static_shape])] @register_infer_shape diff --git a/pytensor/tensor/rewriting/special.py b/pytensor/tensor/rewriting/special.py index 59569ea886..c20523797c 100644 --- a/pytensor/tensor/rewriting/special.py +++ b/pytensor/tensor/rewriting/special.py @@ -108,7 +108,11 @@ def local_logsoftmax_grad(fgraph, node): def softmax_simplifier(numerators, denominators): - for numerator in list(numerators): + if not numerators or not denominators: + return numerators, denominators + + orig_numerators = numerators + for numerator in orig_numerators: if not numerator.type.dtype.startswith("float"): continue @@ -165,6 +169,8 @@ def softmax_simplifier(numerators, denominators): if matching_denom is not None: softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0]) copy_stack_trace(numerator, softmax) + if numerators is orig_numerators: + numerators = numerators.copy() numerators.remove(numerator) denominators.remove(matching_denom) numerators.append(softmax) diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 766b9a7803..d23897a8de 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -1466,7 +1466,7 @@ def zerosumnormal(name, *, sigma=1.0, size, model_logp): return joined_inputs, [model_logp, model_dlogp] -@pytest.mark.parametrize("mode", ["FAST_COMPILE", "FAST_RUN", "NUMBA"]) +@pytest.mark.parametrize("mode", ["FAST_COMPILE", "FAST_RUN"]) def test_radon_model_compile_benchmark(mode, radon_model, benchmark): joined_inputs, [model_logp, model_dlogp] = radon_model rng = np.random.default_rng(1) From 9ba09238b8707b9a8aaf6828727b2b516b08f58d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 4 Oct 2025 15:21:07 +0200 Subject: [PATCH 33/33] .speedup stuff --- pytensor/tensor/rewriting/basic.py | 153 +++++++++++------------------ pytensor/tensor/rewriting/math.py | 69 ++++++------- 2 files changed, 88 insertions(+), 134 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 514296e76b..a66050b9fa 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -45,18 +45,14 @@ from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.scalar import ( AND, - EQ, LE, - NEQ, - OR, - XOR, Add, BinaryScalarOp, Cast, - Identity, Mul, Second, Switch, + or_, ) from pytensor.tensor.basic import ( Alloc, @@ -73,6 +69,7 @@ fill, get_scalar_constant_value, join, + ones, ones_like, register_infer_shape, switch, @@ -83,7 +80,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays -from pytensor.tensor.math import Sum, add, eq, variadic_add +from pytensor.tensor.math import Sum, add, and_, eq, mul, neq, variadic_add, xor from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -548,104 +545,70 @@ def local_alloc_empty_to_zeros(fgraph, node): @register_infer_shape @register_useless -@register_canonicalize("fast_compile") +@register_canonicalize @register_specialize -@node_rewriter([Elemwise]) -def local_useless_elemwise(fgraph, node): - """ - eq(x, x) -> 1 - neq(x, x) -> 0 - mul(x) -> x - add(x) -> x - identity(x) -> x - and(x, 1) -> x (if x.dtype == 'bool') - and(x, 0) -> zeros_like(x) - or(x, 0) -> x - or(x, 1) -> ones_like(x) (if x.dtype == 'bool') - xor(x, x) -> zeros_like(x) - - TODO: This implementation is painfully redundant. - TODO: Allow rewrite when useless input broadcasts output - - """ - out_bcast = node.outputs[0].type.broadcastable - dtype = node.outputs[0].type.dtype - scalar_op = node.op.scalar_op +@node_rewriter([eq]) +def local_useless_eq(fgraph, node): + """rewrite eq(x, x) -> 1""" + [x, y] = node.inputs + if x is y: + ret = ones_like(x, dtype=node.outputs[0].dtype, opt=True) + copy_stack_trace(node.outputs[0], ret) + return [ret] - if isinstance(scalar_op, EQ) and len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - # it is the same var in the graph. That will always be true - ret = ones_like(node.inputs[0], dtype=dtype, opt=True) - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif isinstance(scalar_op, NEQ | XOR) and len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - # it is the same var in the graph. That will always be false - ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) +@register_infer_shape +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([neq, xor]) +def local_useless_neq(fgraph, node): + """rewrite neq(x, x) -> 0""" + [x, y] = node.inputs + if x is y: + ret = zeros_like(x, dtype=node.outputs[0].dtype, opt=True) + copy_stack_trace(node.outputs[0], ret) + return [ret] - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif isinstance(node.op.scalar_op, Mul | Add | Identity) and len(node.inputs) == 1: +@register_infer_shape +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([add, mul]) +def local_useless_add_mul_identity(fgraph, node): + if len(node.inputs) == 1: # No need to copy over any stack trace return [node.inputs[0]] - elif isinstance(node.op.scalar_op, AND) and len(node.inputs) == 2: - if ( - isinstance(node.inputs[0], TensorConstant) - and node.inputs[1].type.broadcastable == out_bcast - ): - const_val = node.inputs[0].unique_value - if const_val is not None: - if const_val == 0: - return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[1].astype(node.outputs[0].dtype)] - - if ( - isinstance(node.inputs[1], TensorConstant) - and node.inputs[0].type.broadcastable == out_bcast - ): - const_val = node.inputs[1].unique_value - if const_val is not None: - if const_val == 0: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[0].astype(node.outputs[0].dtype)] - - elif isinstance(node.op.scalar_op, OR) and len(node.inputs) == 2: - if ( - isinstance(node.inputs[0], TensorConstant) - and node.inputs[1].type.broadcastable == out_bcast - ): - const_val = node.inputs[0].unique_value - if const_val is not None: - if const_val == 0: - return [node.inputs[1].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[1], dtype=dtype, opt=True)] - if ( - isinstance(node.inputs[1], TensorConstant) - and node.inputs[0].type.broadcastable == out_bcast - ): - const_val = node.inputs[1].unique_value - if const_val is not None: - if const_val == 0: - return [node.inputs[0].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[0], dtype=dtype, opt=True)] +@register_infer_shape +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([and_, or_]) +def local_useless_and_or(fgraph, node): + inputs = node.inputs + for x, y in [inputs, reversed(inputs)]: + if isinstance(x, TensorConstant) and (val := x.unique_value) is not None: + out = node.outputs[0] + out_type = out.type + if isinstance(node.op.scalar_op, AND): + if val: + res = y.astype(out_type.dtype) + else: + res = zeros((), dtype=out_type.dtype) + else: # OR + if val: + res = ones((), dtype=out_type.dtype) + else: + res = y.astype(out_type.dtype) + + if res.type.broadcastable != out_type.broadcastable: + res = fill(x, res) + + copy_stack_trace(node.outputs[0], res) + return [res] @register_specialize diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index bb5c9d0c03..cab07f6179 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -19,7 +19,7 @@ node_rewriter, ) from pytensor.graph.rewriting.utils import get_clients_at_depth -from pytensor.scalar import Abs +from pytensor.scalar import Abs, gt from pytensor.tensor.basic import ( Alloc, Join, @@ -29,8 +29,10 @@ cast, constant, expand_dims, + fill, get_underlying_scalar_constant_value, moveaxis, + ones, ones_like, register_infer_shape, split, @@ -56,6 +58,7 @@ deg2rad, digamma, dot, + eq, erf, erfc, exp, @@ -69,6 +72,7 @@ log1mexp, log1p, log1pexp, + lt, makeKeepDims, maximum, mul, @@ -121,6 +125,7 @@ TensorConstant, TensorVariable, ) +from pytensor.xtensor.math import minimum def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): @@ -1529,15 +1534,32 @@ def local_elemwise_sub_zeros(fgraph, node): @register_specialize @register_stabilize @register_canonicalize -@node_rewriter([Elemwise]) -def local_useless_elemwise_comparison(fgraph, node): - """... +@node_rewriter([lt, le, gt, ge, eq, minimum, maximum]) +def local_useless_comparison_same_input(fgraph, node): + [x, y] = node.inputs + if x is y: + out = node.outputs[0] + dtype = node.outputs[0].type.dtype + if isinstance(node.op.scalar_op, ps.LT | ps.GT): + res = zeros((), dtype=dtype) + elif isinstance(node.op.scalar_op, ps.LE | ps.GE | ps.EQ): + res = ones((), dtype=dtype) + else: # isinstance(node.op, (minimum, maximum)) + res = x + if res.type.broadcastable != out.type.broadcastable: + res = fill(y, res) + # Copy over stacktrace from previous output. + copy_stack_trace(out, res) + return [res] - # Comparing to itself is constant - Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) - Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) - Elemwise[{minimum,maximum}](X, X) -> X +@register_useless +@register_specialize +@register_stabilize +@register_canonicalize +@node_rewriter([Elemwise]) +def local_useless_elemwise_shape_comparison(fgraph, node): + """ # Comparing shape to 0 can be constant Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) @@ -1568,37 +1590,6 @@ def local_useless_elemwise_comparison(fgraph, node): dtype = node.outputs[0].type.dtype out_bcast = node.outputs[0].type.broadcastable - # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) - if ( - isinstance(node.op.scalar_op, ps.LT | ps.GT) - and node.inputs[0] is node.inputs[1] - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - - # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) - if ( - isinstance(node.op.scalar_op, ps.LE | ps.GE) - and node.inputs[0] is node.inputs[1] - ): - res = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - - # Elemwise[{minimum,maximum}](X, X) -> X - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum) - and node.inputs[0] is node.inputs[1] - ): - res = node.inputs[0] - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) if ( isinstance(node.op.scalar_op, ps.LT)