diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 339da84cd1..769a5dfeeb 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -13,7 +13,6 @@ import builtins import math from collections.abc import Callable -from copy import copy from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -779,9 +778,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: This caches objects to save allocation and run time. """ - if dtype not in cache: - cache[dtype] = ScalarType(dtype=dtype) - return cache[dtype] + try: + return cache[dtype] + except KeyError: + cache[dtype] = res = ScalarType(dtype=dtype) + return res # Register C code for ViewOp on Scalars. @@ -987,25 +988,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: - from pytensor.tensor.basic import scalar_from_tensor - from pytensor.tensor.type import TensorType + if isinstance(x, ScalarVariable): + return x + + if isinstance(x, Variable): + from pytensor.tensor.basic import scalar_from_tensor + from pytensor.tensor.type import TensorType + + if isinstance(x.type, TensorType) and x.type.ndim == 0: + return scalar_from_tensor(x) + else: + raise TypeError(f"Cannot convert {x} to a scalar type") if isinstance(x, Apply): + # FIXME: Why do we support calling this with Apply? + # Also, if we do, why can't we support multiple outputs? if len(x.outputs) != 1: raise ValueError( "It is ambiguous which output of a multi-output" " Op has to be fetched.", x, ) - else: - x = x.outputs[0] - if isinstance(x, Variable): - if isinstance(x, ScalarVariable): - return x - elif isinstance(x.type, TensorType) and x.type.ndim == 0: - return scalar_from_tensor(x) - else: - raise TypeError(f"Cannot convert {x} to a scalar type") + return as_scalar(x.outputs[0]) return constant(x) @@ -1329,32 +1333,26 @@ def supports_c_code(self, inputs, outputs): the given Elemwise inputs, outputs. """ - try: - tmp_s_input = [] - # To keep the same aliasing between inputs - mapping = dict() - for ii in inputs: - if ii in mapping: - tmp_s_input.append(mapping[ii]) - else: - tmp = get_scalar_type(ii.dtype).make_variable() - tmp_s_input.append(tmp) - mapping[ii] = tmp_s_input[-1] - - with config.change_flags(compute_test_value="ignore"): - s_op = self(*tmp_s_input, return_list=True) + tmp_s_input = [] + # To keep the same aliasing between inputs + mapping = {} + for ii in inputs: + if ii in mapping: + tmp_s_input.append(mapping[ii]) + else: + tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable() + tmp_s_input.append(tmp) - # if the scalar_op don't have a c implementation, - # we skip its fusion to allow the fusion of the - # other ops. + try: self.c_code( - s_op[0].owner, + self.make_node(*tmp_s_input), "test_presence_of_c_code", + # FIXME: Shouldn't this be a unique name per unique variable? ["x" for x in inputs], ["z" for z in outputs], {"fail": "%(fail)s"}, ) - except (MethodNotDefined, NotImplementedError): + except (NotImplementedError, MethodNotDefined): return False return True @@ -4094,12 +4092,12 @@ def __init__(self, *args, **kwargs): self.prepare_node_called = set() super().__init__(*args, **kwargs) - def _cleanup_graph(self, inputs, outputs): + def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): # TODO: We could convert to TensorVariable, optimize graph, # and then convert back to ScalarVariable. # This would introduce rewrites like `log(1 + x) -> log1p`. 
- fgraph = FunctionGraph(copy(inputs), copy(outputs)) + fgraph = FunctionGraph(inputs, outputs, clone=clone) # Validate node types for node in fgraph.apply_nodes: @@ -4282,7 +4280,9 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs, name="Composite"): + def __init__( + self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + ): self.name = name self._name = None # We need to clone the graph as sometimes its nodes already @@ -4300,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"): if len(outputs) > 1 or not any( isinstance(var.owner.op, Composite) for var in outputs ): - # No inner Composite - inputs, outputs = clone(inputs, outputs) + if clone_graph: + inputs, outputs = clone(inputs, outputs) + else: # Inner Composite that we need to flatten + # FIXME: There could be a composite in the middle of the graph, why is this here? + # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway. assert len(outputs) == 1 # 1. Create a new graph from inputs up to the # Composite @@ -4322,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"): assert res[0] != inputs inputs, outputs = res[0], res2[1] - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) + # We already cloned the graph, or the user told us there was no need for it + self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) self.inputs_type = tuple(input.type for input in self.inputs) self.outputs_type = tuple(output.type for output in self.outputs) self.nin = len(inputs) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e2d420f361..192d289af5 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,10 +2,10 @@ import itertools import operator import sys -from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce -from typing import TypeVar +from heapq import heapify, heappop, heappush +from operator import or_ from warnings import warn import pytensor.scalar.basic as ps @@ -28,7 +28,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors +from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -530,47 +530,26 @@ def add_requirements(self, fgraph): @staticmethod def elemwise_to_scalar(inputs, outputs): - replace_inputs = [(inp, inp.clone()) for inp in inputs] - outputs = clone_replace(outputs, replace=replace_inputs) - - inputs = [inp for _, inp in replace_inputs] - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) - middle_inputs = [] - - scalar_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs - ] - middle_scalar_inputs = [] - - for node in fg.toposort(): - node_scalar_inputs = [] - for inp in node.inputs: - if inp in inputs: - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) - elif inp in middle_inputs: - node_scalar_inputs.append( - middle_scalar_inputs[middle_inputs.index(inp)] + replacement = { + inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + } + for node in toposort(outputs, blockers=inputs): + scalar_inputs = [replacement[inp] 
for inp in node.inputs] + replacement.update( + dict( + zip( + node.outputs, + node.op.scalar_op.make_node(*scalar_inputs).outputs, ) - else: - new_scalar_input = ps.get_scalar_type( - inp.type.dtype - ).make_variable() - node_scalar_inputs.append(new_scalar_input) - middle_scalar_inputs.append(new_scalar_input) - middle_inputs.append(inp) - - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) - middle_scalar_inputs.append(new_scalar_node.outputs[0]) - middle_inputs.append(node.outputs[0]) - - scalar_outputs = [ - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs - ] - return scalar_inputs, scalar_outputs + ) + ) - def apply(self, fgraph): - nb_replacement = 0 + return ( + [replacement[inp] for inp in inputs], + [replacement[out] for out in outputs], + ) + def apply(self, fgraph): if fgraph.profile: validate_before = fgraph.profile.validate_time callbacks_before = fgraph.execute_callbacks_times.copy() @@ -578,376 +557,380 @@ def apply(self, fgraph): max_operands = elemwise_max_operands_fct(None) - def find_next_fuseable_subgraph( + def find_fuseable_subgraphs( fg: FunctionGraph, - ) -> Generator[tuple[list[Variable], list[Variable]], None, None]: - """Find all subgraphs in a FunctionGraph that can be fused together - - Yields - ------- - List of inputs and outputs that determine subgraphs which can be fused. - This generator assumes that such subgraph is replaced by a single - Elemwise Composite before being accessed again in the next iteration. + ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]: + """Find subgraphs of Elemwise nodes that can be fused together. + + In general there is no single solution; we try to find large subgraphs eagerly. + + Any two consecutive Elemwise nodes that have the same broadcasting pattern, + and a C-implementation (historical accident that should be revisited), are potentially fuseable. + + However, we need to be careful about keeping the fused subgraph "convex", meaning that no two + nodes in the same subgraph are connected via a path that goes outside the subgraph, either through + unfuseable nodes or through nodes that have been claimed by another subgraph. + + For example, the graph add(sin(exp(x)), sum(exp(x))) cannot be fused into a single Elemwise, because the sum + node breaks the convexity of the subgraph {exp, sin, add}. However, we can fuse {exp, sin}, + and perhaps fuse add with something else. """ - FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] - UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - - def initialize_fuseable_mappings( - *, fg: FunctionGraph - ) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: - @cache - def elemwise_scalar_op_has_c_code(node: Apply) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): - return True - else: - if config.optimizer_verbose: + class FuseableClients: + __slots__ = ("fuseable_clients", "candidate_nodes") + + def __init__(self, fgraph): + @cache + def elemwise_scalar_op_has_c_code( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: warn( f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." ) return False - # Fuseable nodes have to be accessed in a deterministic manner - # to ensure the rewrite remains deterministic. 
- # This is not a problem from unfuseable ones, as they can never - # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) - unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) - for out, clients in fg.clients.items(): - # Old FunctionGraph nodes remain in the clients dictionary - # even after they are removed by rewrites - if not clients: - continue - - out_maybe_fuseable = ( - out.owner - and isinstance(out.owner.op, Elemwise) - # and not isinstance(out.owner.op.scalar_op, ps.Composite) - and len(out.owner.outputs) == 1 - and elemwise_scalar_op_has_c_code(out.owner) - ) - for client, _ in clients: - if ( - out_maybe_fuseable - and isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out.type.broadcastable - == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) + self.fuseable_clients = fuseable_clients = {} + self.candidate_nodes = candidate_nodes = set() + fg_clients = fg.clients + for out, clients_and_indices in fg_clients.items(): + out_node = out.owner + + if not ( + out_node is not None + and len(out_node.outputs) == 1 + and isinstance(out_node.op, Elemwise) + and elemwise_scalar_op_has_c_code(out_node) ): - if client not in fuseable_clients[out]: - fuseable_clients[out].append(client) - else: - unfuseable_clients[out].add(client) - - return fuseable_clients, unfuseable_clients - - def find_fuseable_subgraph( - *, - fg: FunctionGraph, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - ) -> tuple[list[Variable], list[Variable]]: - KT = TypeVar("KT") - VT = TypeVar("VT", list, set) - - def shallow_clone_defaultdict( - d: defaultdict[KT, VT], - ) -> defaultdict[KT, VT]: - new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory) - new_dict.update({k: v.copy() for k, v in d.items()}) - return new_dict - - def variables_depend_on( - variables, depend_on, stop_search_at=None - ) -> bool: - return any( - a in depend_on - for a in ancestors(variables, blockers=stop_search_at) - ) + continue - toposort = fg.toposort() - for starting_node in toposort: - if starting_node in visited_nodes: - continue + candidate_nodes.add(out_node) + out_bcast = out.type.broadcastable + out_fuseable_clients = { + client + for client, _ in clients_and_indices + if ( + len(client.outputs) == 1 + and isinstance(client.op, Elemwise) + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ) + } + if out_fuseable_clients: + fuseable_clients[out_node] = out_fuseable_clients - starting_out = starting_node.outputs[0] - if not fuseable_clients.get(starting_out): - visited_nodes.add(starting_node) - continue + def __bool__(self): + return bool(self.fuseable_clients) - subgraph_inputs: list[Variable] = [] - subgraph_outputs: list[Variable] = [] - unfuseable_clients_subgraph: set[Variable] = set() + def __getitem__(self, node: Apply): + return self.fuseable_clients.get(node, ()) + + def is_sink_node(self, node: Apply) -> bool: + # A sink node is a candidate node that has no fuseable clients + return ( + node in self.candidate_nodes + and node not in self.fuseable_clients + ) - # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) - unfuseable_clients_clone = shallow_clone_defaultdict( - unfuseable_clients + def remove_subgraph_connections(self, subgraph: "ConvexSubgraph"): + # Update fuseable 
clients: inputs can no longer be fused with graph variables + # and outputs can't be fused with anything else + subgraph_inputs, subgraph_outputs = ( + subgraph.get_inputs_and_outputs() ) + fuseable_clients = self.fuseable_clients + for ancestor in subgraph_inputs: + if (ancestor_node := ancestor.owner) is not None: + if ancestor_fuseable_clients := fuseable_clients.get( + ancestor_node + ): + ancestor_fuseable_clients.difference_update( + subgraph.nodes + ) + if not ancestor_fuseable_clients: + del fuseable_clients[ancestor_node] + + for out in subgraph_outputs: + fuseable_clients.pop(out.owner, None) + + class SortedFuseableNodesQueue: + __slots__ = ("queue",) + + def __init__(self): + # We keep an ordered queue for expanding the subgraph + # We always want to visit ancestors before clients + # For ancestors, we want to visit the later nodes first (those that have more dependencies) + # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) + # We negate the bitflag for ancestors to achieve this ordering + self.queue = queue = [] + heapify(queue) + + def push(self, node: Apply, node_bitflag: int, is_ancestor: bool): + if is_ancestor: + node_bitflag = -node_bitflag + heappush(self.queue, (node_bitflag, node)) + + def pop(self) -> tuple[Apply, bool]: + node_bitflag, node = heappop(self.queue) + return node, node_bitflag < 0 + + def __bool__(self): + return bool(self.queue) + + class NonConvexError(Exception): + pass + + class ConvexSubgraph: + __slots__ = ( + "nodes_bitflags", + "ancestors_bitset", + "nodes", + "nodes_bitset", + "unfuseable_ancestors_bitset", + "unfuseable_clients_bitset", + "inputs_and_outputs", + ) - fuseable_nodes_to_visit = deque([starting_node]) - - # We now try to expand as much as possible towards the potentially - # fuseable clients and ancestors to detect the largest possible - # subgraph that can be Composed together into a single `Op`. The - # largest issue to watch out is for cyclical dependencies, where - # some inputs or clients may depend on other nodes of the same - # subgraph via a path that cannot be included in the Composite - # (unfuseable) - while fuseable_nodes_to_visit: - next_node = fuseable_nodes_to_visit.popleft() - visited_nodes.add(next_node) - next_out = next_node.outputs[0] - - # If the output variable of next_node has no fuseable clients - # or has unfuseable clients, then next_node must become an output - # if it is to be fused. 
- must_become_output = ( - next_out not in fuseable_clients_temp - or next_out in unfuseable_clients_clone + def __init__(self, nodes_bitflags, ancestors_bitset): + self.nodes_bitflags = nodes_bitflags + self.ancestors_bitset = ancestors_bitset + self.nodes = {} + self.nodes_bitset = 0 + self.unfuseable_ancestors_bitset = 0 + self.unfuseable_clients_bitset = 0 + self.inputs_and_outputs = None + + def __len__(self): + return len(self.nodes) + + def __contains__(self, node: Apply): + return bool(self.nodes_bitset & self.nodes_bitflags[node]) + + def add_node(self, node: Apply, is_ancestor: bool): + node_bitflag = self.nodes_bitflags[node] + if is_ancestor: + if node_bitflag & self.unfuseable_ancestors_bitset: + raise NonConvexError + elif self.ancestors_bitset[node] & self.unfuseable_clients_bitset: + raise NonConvexError + self.nodes_bitset |= node_bitflag + self.nodes[node] = None + self.inputs_and_outputs = None # clear cache + + def add_unfuseable_ancestor(self, ancestor: Apply): + # If an ancestor is unfuseable, so are all its ancestors + self.unfuseable_ancestors_bitset |= self.ancestors_bitset[ancestor] + + def add_unfuseable_client(self, client: Apply): + # If a client is unfuseable, so are all its clients, but we don't need to keep track of those + # Any downstream client will also depend on this unfuseable client and will be rejected when visited + self.unfuseable_clients_bitset |= self.nodes_bitflags[client] + + def get_inputs_and_outputs(self): + if self.inputs_and_outputs is not None: + return self.inputs_and_outputs + + nodes = self.nodes + # Use a dict to deduplicate while preserving order + subgraph_inputs = tuple( + dict.fromkeys( + inp + for node in nodes + for inp in node.inputs + if (ancestor_node := inp.owner) is None + or ancestor_node not in nodes ) + ) - # We have backtracked to this node, and it may no longer be a viable output, - # so we remove it and check again as if we had never seen this node - if must_become_output and next_out in subgraph_outputs: - subgraph_outputs.remove(next_out) + subgraph_outputs = tuple( + node.outputs[0] + for node in nodes + if any( + client not in nodes + for client, _ in fg_clients[node.outputs[0]] + ) + ) + self.inputs_and_outputs = subgraph_inputs, subgraph_outputs + return subgraph_inputs, subgraph_outputs - required_unfuseable_inputs = [ - inp - for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp, ()) - ] - new_required_unfuseable_inputs = [ - inp - for inp in required_unfuseable_inputs - if inp not in subgraph_inputs - ] - - must_backtrack = False - if new_required_unfuseable_inputs and subgraph_outputs: - # We need to check that any new inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - if variables_depend_on( - [next_out], - depend_on=unfuseable_clients_subgraph, - stop_search_at=subgraph_outputs, - ): - must_backtrack = True - - if not must_backtrack: - implied_unfuseable_clients = { - c - for client in unfuseable_clients_clone.get(next_out, ()) - if not isinstance(client.op, Output) - for c in client.outputs - } - - new_implied_unfuseable_clients = ( - implied_unfuseable_clients - unfuseable_clients_subgraph - ) + class SortedSubgraphCollection: + __slots__ = ("subgraphs", "nodes_bitset") - if new_implied_unfuseable_clients and subgraph_inputs: - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. 
- if variables_depend_on( - subgraph_inputs, - depend_on=new_implied_unfuseable_clients, - ): - must_backtrack = True - - if must_backtrack: - for inp in next_node.inputs: - if ( - inp.owner in visited_nodes - # next_node could have the same input repeated - and next_node in fuseable_clients_temp[inp] - ): - fuseable_clients_temp[inp].remove(next_node) - unfuseable_clients_clone[inp].add(next_node) - # This input must become an output of the subgraph, - # because it can't be merged with next_node. - # We will revisit it to make sure this is safe. - fuseable_nodes_to_visit.appendleft(inp.owner) - - for client in fuseable_clients_temp[next_out]: - if client in visited_nodes: - fuseable_clients_temp[next_out].remove(client) - unfuseable_clients_clone[next_out].add(client) - # next_out must become an input of the subgraph. - # We will revisit any of its clients currently - # in the subgraph to make sure this is safe. - fuseable_nodes_to_visit.appendleft(client) - - # Revisit node at a later time - visited_nodes.remove(next_node) - continue + def __init__(self): + self.subgraphs: list[ConvexSubgraph] = [] + self.nodes_bitset = 0 - # Adding next_node to subgraph does not result in any - # immediate dependency problems. Update subgraph - # mappings as if it next_node was part of it. - # Useless inputs will be removed by the useless Composite rewrite - for inp in new_required_unfuseable_inputs: - if inp not in subgraph_inputs: - subgraph_inputs.append(inp) - - if must_become_output: - subgraph_outputs.append(next_out) - unfuseable_clients_subgraph.update( - new_implied_unfuseable_clients ) + def __contains__(self, node_bitflag: int): + return bool(node_bitflag & self.nodes_bitset) - # Expand through unvisited fuseable ancestors - for inp in sorted( - ( - inp - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=lambda inp: toposort.index(inp.owner), - reverse=True, - ): - fuseable_nodes_to_visit.appendleft(inp.owner) - - # Expand through unvisited fuseable clients - for next_node in sorted( - ( - node - for node in fuseable_clients_temp.get(next_out, ()) - if node not in visited_nodes - ), - key=lambda node: toposort.index(node), + def insert_subgraph(self, subgraph: ConvexSubgraph): + # Usually new subgraphs don't depend on previous subgraphs, so we can just append them at the end + # But in some cases they can, so we need to insert at the right position. + subgraph_unfuseable_ancestors_bitset = ( + subgraph.unfuseable_ancestors_bitset + ) + if not (subgraph_unfuseable_ancestors_bitset & self.nodes_bitset): + self.subgraphs.append(subgraph) + else: + # Iterate from the end, removing the bitsets of each previous subgraph until our current subgraph + # no longer depends on what's left. This tells us where to insert the current subgraph. 
+ remaining_nodes_bitset = self.nodes_bitset + for index, other_subgraph in enumerate( + reversed(self.subgraphs) ): - fuseable_nodes_to_visit.append(next_node) - - # Don't return if final subgraph is just the original Elemwise - if len(subgraph_outputs) == 1 and set( - subgraph_outputs[0].owner.inputs - ) == set(subgraph_inputs): - # Update global fuseable mappings - # No input was actually fuseable - for inp in starting_node.inputs: - if starting_node in fuseable_clients.get(inp, ()): - fuseable_clients[inp].remove(starting_node) - unfuseable_clients[inp].add(starting_node) - # No client was actually fuseable - unfuseable_clients[starting_out].update( - fuseable_clients.pop(starting_out, ()) - ) + remaining_nodes_bitset &= ~other_subgraph.nodes_bitset + if not ( + subgraph_unfuseable_ancestors_bitset + & remaining_nodes_bitset + ): + break + self.subgraphs.insert(-(index + 1), subgraph) + self.nodes_bitset |= subgraph.nodes_bitset + + def __iter__(self): + yield from self.subgraphs + + fuseable_clients = FuseableClients(fgraph) + + if not fuseable_clients: + return None + + # Create a bitset of ancestors for each node. + # Each node is represented by a bit flag of its position in the toposort + # With three variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, + # the node bit flags would be {A: 0b001, B: 0b010, C: 0b100} + # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} + # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND + # For example, to ask if B is an ancestor of C we can do `(ancestors_bitset[C] & nodes_bitflags[B]) != 0` + # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do + # `(ancestors_bitset[C] & (nodes_bitflags[A] | nodes_bitflags[B])) != 0` + fg_clients = fgraph.clients + nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} + # Root variables have `None` as owner, which we can handle with a bitset of 0 for `None` + ancestors_bitset = {None: 0} + for node, node_bitflag in nodes_bitflags.items(): + # The bitset of each node is the union of the bitsets of its inputs, plus its own bit + ancestors_bitset[node] = reduce( + or_, + (ancestors_bitset[inp.owner] for inp in node.inputs), + node_bitflag, + ) + # handle root and leaf nodes gracefully + # Root variables have `None` as owner, which we can handle with a bitflag of 0 for `None` + nodes_bitflags[None] = 0 + # Nothing ever depends on output nodes, so just use a new bit for all + out_bitflag = 1 << len(nodes_bitflags) + for out in fg.outputs: + for client, _ in fg_clients[out]: + if isinstance(client.op, Output): + nodes_bitflags[client] = out_bitflag + + sorted_subgraphs = SortedSubgraphCollection() + + # Start exploring from candidate sink nodes (backwards) + # These are Elemwise nodes with a C-implementation that are not part of another subgraph + # And have no other fuseable clients (i.e., are sinks) + for starting_node, starting_bitflag in reversed(nodes_bitflags.items()): + if ( + starting_bitflag in sorted_subgraphs + or not fuseable_clients.is_sink_node(starting_node) + ): + continue + + subgraph = ConvexSubgraph(nodes_bitflags, ancestors_bitset) + + fuseable_nodes_queue = SortedFuseableNodesQueue() + fuseable_nodes_queue.push( + starting_node, starting_bitflag, is_ancestor=True + ) + while fuseable_nodes_queue: + node, is_ancestor = fuseable_nodes_queue.pop() + + if node in subgraph: continue - return subgraph_inputs, subgraph_outputs 
- raise ValueError - - def update_fuseable_mappings_after_fg_replace( - *, - fg: FunctionGraph, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - starting_nodes: set[Apply], - ) -> None: - # Find new composite node and dropped intermediate nodes - # by comparing the current fg.apply nodes with the cached - # original nodes - next_nodes = fg.apply_nodes - (new_composite_node,) = next_nodes - starting_nodes - dropped_nodes = starting_nodes - next_nodes - - # Remove intermediate Composite nodes from mappings - for dropped_node in dropped_nodes: - (dropped_out,) = dropped_node.outputs - fuseable_clients.pop(dropped_out, None) - unfuseable_clients.pop(dropped_out, None) - visited_nodes.remove(dropped_node) - - # Update fuseable information for subgraph inputs - for inp in subgraph_inputs: - if inp in fuseable_clients: - new_fuseable_clients = [ - client - for client in fuseable_clients[inp] - if client not in dropped_nodes - ] - if new_fuseable_clients: - fuseable_clients[inp] = new_fuseable_clients + try: + subgraph.add_node(node, is_ancestor=is_ancestor) + except NonConvexError: + continue + + # Expand through ancestors and client nodes + # A node can either be: + # - already part of the subgraph (skip) + # - fuseable (add to queue) + # - unfuseable (add to respective unfuseable bitset) + for ancestor in node.inputs: + ancestor_node = ancestor.owner + if ancestor_node in subgraph: + continue + if node in fuseable_clients[ancestor_node]: + fuseable_nodes_queue.push( + ancestor_node, + nodes_bitflags[ancestor_node], + is_ancestor=True, + ) else: - fuseable_clients.pop(inp) - unfuseable_clients[inp] = ( - unfuseable_clients[inp] - dropped_nodes - ) | {new_composite_node} - - # Update fuseable information for subgraph outputs - for out in new_composite_node.outputs: - unfuseable_clients[out] = {client for client, _ in fg.clients[out]} - - visited_nodes.add(new_composite_node) - return - - # We start by creating two maps, 1) from each node to each potentially - # fuseable client (both nodes must be single output Elemwise with same - # broadcast type) and 2) from each node to each certainly unfuseable - # client (those that don't fit into 1)) - fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) - visited_nodes: set[Apply] = set() - while True: - starting_nodes = fg.apply_nodes.copy() - try: - subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - fg=fg, - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - ) - except ValueError: - return - else: - # The caller is now expected to update fg in place, - # by replacing the subgraph with a Composite Op - yield subgraph_inputs, subgraph_outputs - - # This is where we avoid repeated work by using a stateful - # generator. 
For large models (as in `TestFusion.test_big_fusion`) - # this can provide huge speedups - update_fuseable_mappings_after_fg_replace( - fg=fg, - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - starting_nodes=starting_nodes, - ) + subgraph.add_unfuseable_ancestor(ancestor_node) + + next_fuseable_clients = fuseable_clients[node] + for client_node, _ in fg_clients[node.outputs[0]]: + if client_node in subgraph: + continue + if client_node in next_fuseable_clients: + fuseable_nodes_queue.push( + client_node, + nodes_bitflags[client_node], + is_ancestor=False, + ) + else: + subgraph.add_unfuseable_client(client_node) - for inputs, outputs in find_next_fuseable_subgraph(fgraph): + # Finished exploring this subgraph + if len(subgraph) == 1: + # No fusion possible, single node subgraph + continue + + sorted_subgraphs.insert_subgraph(subgraph) + # Mark the nodes of this subgraph as no longer fuseable + fuseable_clients.remove_subgraph_connections(subgraph) + + yield from ( + subgraph.get_inputs_and_outputs() for subgraph in sorted_subgraphs + ) + + nb_fused = 0 + nb_replacement = 0 + for inputs, outputs in find_fuseable_subgraphs(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( "Loop fusion failed because the resulting node would exceed " "the kernel argument limit." ) - break + continue scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) - composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))( - *inputs - ) - if not isinstance(composite_outputs, list): - composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs, strict=True): + composite_outputs = Elemwise( + # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables + ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False) + )(*inputs, return_list=True) + assert len(outputs) == len(composite_outputs) + for old_out, composite_out in zip(outputs, composite_outputs): + # Preserve any names on the original outputs if old_out.name: composite_out.name = old_out.name + starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( - list(zip(outputs, composite_outputs, strict=True)), + tuple(zip(outputs, composite_outputs)), reason=self.__class__.__name__, ) - nb_replacement += 1 + nb_fused += 1 + nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -965,7 +948,7 @@ def update_fuseable_mappings_after_fg_replace( return ( self, - 1, # nb_iter + nb_fused, nb_replacement, 0, # nb_inconsintency_replace validate_time, @@ -978,7 +961,7 @@ def update_fuseable_mappings_after_fg_replace( def print_profile(stream, prof, level=0): blanc = " " * level print(blanc, "FusionOptimizer", file=stream) - print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_fused", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " validate_time", prof[4], file=stream) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index c23d0ac23a..7e625043ec 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) - def large_fuseable_graph(self, n): + @staticmethod + def large_fuseable_graph(n): factors = [] sd = 
dscalar() means = dvector() @@ -296,6 +297,24 @@ def large_fuseable_graph(self, n): dlogp = [pytensor.grad(logp, v) for v in vars] return vars, dlogp + @staticmethod + def deep_small_kernels(n): + x = pt.matrix("x") + out = x + for _ in range(n): + out = pt.sin(out.T) + pt.cos(out) + + return [x], [out] + + @staticmethod + def diamond_graph(n): + a = pt.matrix("a") + b = pt.exp(a) + c = pt.log(b) + d = pt.sin(c) + e = c + d + return [a], [e] + @pytest.mark.parametrize( "case", [ @@ -1347,16 +1366,27 @@ def test_eval_benchmark(self, benchmark): benchmark(func) @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_rewrite_benchmark(self, benchmark): - inps, outs = self.large_fuseable_graph(n=25) + @pytest.mark.parametrize( + "graph_fn, n, expected_n_repl", + [ + # ("diamond_graph", None, (1, 4)), + ("deep_small_kernels", 20, (20, 60)), + ("large_fuseable_graph", 25, (128, 876)), + ], + ) + def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): + inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) opt = FusionOptimizer() def rewrite_func(): - nb_replacement = opt.apply(fg.clone())[2] - return nb_replacement + fg_clone = fg.clone() + _, nb_fused, nb_replacement, *_ = opt.apply(fg_clone) + # fg_clone.dprint() + return nb_fused, nb_replacement - assert benchmark(rewrite_func) == 103 + assert rewrite_func() == expected_n_repl + benchmark.pedantic(rewrite_func, rounds=7, iterations=5) def test_no_warning_from_old_client(self): # There used to be a warning issued when creating fuseable mapping diff --git a/tests/test_printing.py b/tests/test_printing.py index 95c3c938cf..dbad8c063b 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -301,7 +301,8 @@ def test_debugprint(): Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv" exp_res = dedent( r""" - Composite{(i2 + (i0 - i1))} 4 + Composite{(i0 + (i1 - i2))} 4 + ├─ A ├─ ExpandDims{axis=0} v={0: [0]} 3 """ f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2" @@ -313,17 +314,16 @@ def test_debugprint(): │ ├─ B │ ├─ │ └─ 0.0 - ├─ D - └─ A + └─ D Inner graphs: - Composite{(i2 + (i0 - i1))} + Composite{(i0 + (i1 - i2))} ← add 'o0' - ├─ i2 - └─ sub ├─ i0 - └─ i1 + └─ sub + ├─ i1 + └─ i2 """ ).lstrip()
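
The ancestor-bitset scheme described in the FusionOptimizer comments above can be illustrated outside of PyTensor with a minimal, standalone sketch. The toy `Node` class and the three-node graph below are illustrative assumptions, not code from the patch: each node gets a unique bit flag from its toposort position, a node's ancestors bitset is the OR of its inputs' bitsets plus its own flag, and "is X an ancestor of Y" then reduces to a single bitwise AND.

    from functools import reduce
    from operator import or_


    class Node:
        """Toy stand-in for an Apply node: just a name and a tuple of input nodes."""

        def __init__(self, name, inputs=()):
            self.name = name
            self.inputs = tuple(inputs)


    # Toy graph a -> b -> c, already in topological order
    a = Node("a")
    b = Node("b", inputs=[a])
    c = Node("c", inputs=[b])
    toposorted = [a, b, c]

    # Each node gets a unique bit flag from its toposort position
    node_bitflag = {node: 1 << i for i, node in enumerate(toposorted)}

    # A node's ancestors bitset is the OR of its inputs' bitsets plus its own flag
    ancestors_bitset = {}
    for node in toposorted:
        ancestors_bitset[node] = reduce(
            or_, (ancestors_bitset[inp] for inp in node.inputs), node_bitflag[node]
        )

    assert node_bitflag == {a: 0b001, b: 0b010, c: 0b100}
    assert ancestors_bitset[c] == 0b111  # c depends on a, b, and itself
    assert ancestors_bitset[c] & node_bitflag[a]  # a is an ancestor of c
    assert not (ancestors_bitset[a] & node_bitflag[b])  # b is not an ancestor of a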