diff --git a/doc/extending/graph_rewriting.rst b/doc/extending/graph_rewriting.rst index 2112d5f276..0bb4c9fa7f 100644 --- a/doc/extending/graph_rewriting.rst +++ b/doc/extending/graph_rewriting.rst @@ -134,7 +134,7 @@ computation graph. In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`, and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with another while respecting certain validation constraints. As an -exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you +exercise, try to rewrite :class:`Simplify` using :class:`WalkingGraphRewriter`. (Hint: you want to use the method it publishes instead of the call to toposort) Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by diff --git a/doc/library/graph/features.rst b/doc/library/graph/features.rst index fa343c2813..e6fef424e4 100644 --- a/doc/library/graph/features.rst +++ b/doc/library/graph/features.rst @@ -26,7 +26,3 @@ Guide .. class:: ReplaceValidate(History, Validator) .. method:: replace_validate(fgraph, var, new_var, reason=None) - -.. class:: NodeFinder(Bookkeeper) - -.. class:: PrintListener(object) diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 7611a380bd..74744d6732 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -827,100 +827,6 @@ def validate(self, fgraph): raise InconsistencyError("Trying to reintroduce a removed node") -class NodeFinder(Bookkeeper): - def __init__(self): - self.fgraph = None - self.d = {} - - def on_attach(self, fgraph): - if hasattr(fgraph, "get_nodes"): - raise AlreadyThere("NodeFinder is already present") - - if self.fgraph is not None and self.fgraph != fgraph: - raise Exception("A NodeFinder instance can only serve one FunctionGraph.") - - self.fgraph = fgraph - fgraph.get_nodes = partial(self.query, fgraph) - Bookkeeper.on_attach(self, fgraph) - - def clone(self): - return type(self)() - - def on_detach(self, fgraph): - """ - Should remove any dynamically added functionality - that it installed into the function_graph - """ - if self.fgraph is not fgraph: - raise Exception( - "This NodeFinder instance was not attached to the provided fgraph." - ) - self.fgraph = None - del fgraph.get_nodes - Bookkeeper.on_detach(self, fgraph) - - def on_import(self, fgraph, node, reason): - try: - self.d.setdefault(node.op, []).append(node) - except TypeError: # node.op is unhashable - return - except Exception as e: - print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201 - try: - print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201 - except Exception: - print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201 - raise e - - def on_prune(self, fgraph, node, reason): - try: - nodes = self.d[node.op] - except TypeError: # node.op is unhashable - return - nodes.remove(node) - if not nodes: - del self.d[node.op] - - def query(self, fgraph, op): - try: - all = self.d.get(op, []) - except TypeError: - raise TypeError( - f"{op} in unhashable and cannot be queried by the optimizer" - ) - all = list(all) - return all - - -class PrintListener(Feature): - def __init__(self, active=True): - self.active = active - - def on_attach(self, fgraph): - if self.active: - print("-- attaching to: ", fgraph) # noqa: T201 - - def on_detach(self, fgraph): - """ - Should remove any dynamically added functionality - that it installed into the function_graph - """ - if self.active: - print("-- detaching from: ", fgraph) # noqa: T201 - - def on_import(self, fgraph, node, reason): - if self.active: - print(f"-- importing: {node}, reason: {reason}") # noqa: T201 - - def on_prune(self, fgraph, node, reason): - if self.active: - print(f"-- pruning: {node}, reason: {reason}") # noqa: T201 - - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): - if self.active: - print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201 - - class PreserveVariableAttributes(Feature): """ This preserve some variables attributes and tag during optimization. diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 5b45fa40f4..66d5f844b1 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -11,12 +11,10 @@ import warnings from collections import Counter, UserList, defaultdict, deque from collections.abc import Callable, Iterable, Sequence -from collections.abc import Iterable as IterableType -from functools import _compose_mro, partial, reduce # type: ignore +from functools import _compose_mro, partial # type: ignore from itertools import chain -from typing import TYPE_CHECKING, Literal +from typing import Literal -import pytensor from pytensor.configdefaults import config from pytensor.graph import destroyhandler as dh from pytensor.graph.basic import ( @@ -28,18 +26,15 @@ io_toposort, vars_between, ) -from pytensor.graph.features import AlreadyThere, Feature, NodeFinder +from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op +from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet from pytensor.utils import flatten -if TYPE_CHECKING: - from pytensor.graph.rewriting.unify import Var - - _logger = logging.getLogger("pytensor.graph.rewriting.basic") RemoveKeyType = Literal["remove"] @@ -60,14 +55,6 @@ ] -class MetaNodeRewriterSkip(AssertionError): - """This is an `AssertionError`, but instead of having the - `MetaNodeRewriter` print the error, it just skip that - compilation. - - """ - - class Rewriter(abc.ABC): """Abstract base class for graph/term rewriters.""" @@ -942,129 +929,6 @@ def recursive_merge(var): return [recursive_merge(v) for v in variables] -class MetaNodeRewriter(NodeRewriter): - r""" - Base class for meta-rewriters that try a set of `NodeRewriter`\s - to replace a node and choose the one that executes the fastest. - - If the error `MetaNodeRewriterSkip` is raised during - compilation, we will skip that function compilation and not print - the error. - - """ - - def __init__(self): - self.verbose = config.metaopt__verbose - self.track_dict = defaultdict(list) - self.tag_dict = defaultdict(list) - self._tracks = [] - self.rewriters = [] - - def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]): - self.rewriters.append(rewriter) - - tracks = rewriter.tracks() - if tracks: - self._tracks.extend(tracks) - for c in tracks: - self.track_dict[c].append(rewriter) - - for tag in tag_list: - self.tag_dict[tag].append(rewriter) - - def tracks(self): - return self._tracks - - def transform(self, fgraph, node, *args, **kwargs): - # safety check: depending on registration, tracks may have been ignored - if self._tracks is not None: - if not isinstance(node.op, tuple(self._tracks)): - return - # first, we need to provide dummy values for all inputs - # to the node that are not shared variables anyway - givens = {} - missing = set() - for input in node.inputs: - if isinstance(input, pytensor.compile.SharedVariable): - pass - elif hasattr(input.tag, "test_value"): - givens[input] = pytensor.shared( - input.type.filter(input.tag.test_value), - input.name, - shape=input.broadcastable, - borrow=True, - ) - else: - missing.add(input) - if missing: - givens.update(self.provide_inputs(node, missing)) - missing.difference_update(givens.keys()) - # ensure we have data for all input variables that need it - if missing: - if self.verbose > 0: - print( # noqa: T201 - f"{self.__class__.__name__} cannot meta-rewrite {node}, " - f"{len(missing)} of {int(node.nin)} input shapes unknown" - ) - return - # now we can apply the different rewrites in turn, - # compile the resulting subgraphs and time their execution - if self.verbose > 1: - print( # noqa: T201 - f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):" - ) - timings = [] - for node_rewriter in self.get_rewrites(node): - outputs = node_rewriter.transform(fgraph, node, *args, **kwargs) - if outputs: - try: - fn = pytensor.function( - [], outputs, givens=givens, on_unused_input="ignore" - ) - fn.trust_input = True - timing = min(self.time_call(fn) for _ in range(2)) - except MetaNodeRewriterSkip: - continue - except Exception as e: - if self.verbose > 0: - print(f"* {node_rewriter}: exception", e) # noqa: T201 - continue - else: - if self.verbose > 1: - print(f"* {node_rewriter}: {timing:.5g} sec") # noqa: T201 - timings.append((timing, outputs, node_rewriter)) - else: - if self.verbose > 0: - print(f"* {node_rewriter}: not applicable") # noqa: T201 - # finally, we choose the fastest one - if timings: - timings.sort() - if self.verbose > 1: - print(f"= {timings[0][2]}") # noqa: T201 - return timings[0][1] - return - - def provide_inputs(self, node, inputs): - """Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values. - - The `node` argument can be inspected to infer required input shapes. - - """ - raise NotImplementedError() - - def get_rewrites(self, node): - """Return the rewrites that apply to `node`. - - This uses ``self.track_dict[type(node.op)]`` by default. - """ - return self.track_dict[type(node.op)] - - def time_call(self, fn): - start = time.perf_counter() - fn() - return time.perf_counter() - start - - class FromFunctionNodeRewriter(NodeRewriter): """A `NodeRewriter` constructed from a function.""" @@ -1214,9 +1078,6 @@ class SequentialNodeRewriter(NodeRewriter): reentrant : bool Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to determine if they should ignore new nodes. - retains_inputs : bool - States whether or not the inputs of a transformed node are transferred - to the outputs. """ def __init__( @@ -1247,9 +1108,6 @@ def __init__( self.reentrant = any( getattr(rewrite, "reentrant", True) for rewrite in rewriters ) - self.retains_inputs = all( - getattr(rewrite, "retains_inputs", False) for rewrite in rewriters - ) self.apply_all_rewrites = apply_all_rewrites @@ -1425,17 +1283,12 @@ class SubstitutionNodeRewriter(NodeRewriter): # an SubstitutionNodeRewriter does not apply to the nodes it produces reentrant = False - # all the inputs of the original node are transferred to the outputs - retains_inputs = True def __init__(self, op1, op2, transfer_tags=True): self.op1 = op1 self.op2 = op2 self.transfer_tags = transfer_tags - def op_key(self): - return self.op1 - def tracks(self): return [self.op1] @@ -1453,39 +1306,6 @@ def __str__(self): return f"{self.op1} -> {self.op2}" -class RemovalNodeRewriter(NodeRewriter): - """ - Removes all applications of an `Op` by transferring each of its - outputs to the corresponding input. - - """ - - reentrant = False # no nodes are added at all - - def __init__(self, op): - self.op = op - - def op_key(self): - return self.op - - def tracks(self): - return [self.op] - - def transform(self, fgraph, node): - if node.op != self.op: - return False - return node.inputs - - def __str__(self): - return f"{self.op}(x) -> x" - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print( - f"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}", - file=stream, - ) - - class PatternNodeRewriter(NodeRewriter): """Replace all occurrences of an input pattern with an output pattern. @@ -1545,7 +1365,6 @@ def __init__( in_pattern, out_pattern, allow_multiple_clients: bool = False, - skip_identities_fn=None, name: str | None = None, tracks=(), get_nodes=None, @@ -1563,8 +1382,6 @@ def __init__( allow_multiple_clients If ``False``, the pattern matching will fail if one of the subpatterns has more than one client. - skip_identities_fn - TODO name Set the name of this rewriter. tracks @@ -1574,19 +1391,17 @@ def __init__( function that takes the tracked node and returns a list of nodes on which we will try this rewrite. values_eq_approx - TODO + If specified, this value will be assigned to the ``values_eq_approx`` + tag of the output variable. This is used by DebugMode to determine if rewrites are correct. allow_cast Automatically cast the output of the rewrite whenever new and old types differ Notes ----- `tracks` and `get_nodes` can be used to make this rewrite track a less - frequent `Op`, which will prevent the rewrite from being tried as - often. + frequent `Op`, which will prevent the rewrite from being tried as often. """ - from pytensor.graph.rewriting.unify import convert_strs_to_vars - var_map: dict[str, Var] = {} self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) @@ -1600,9 +1415,7 @@ def __init__( raise TypeError( "The pattern to search for must start with a specific Op instance." ) - self.__doc__ = f"{self.__class__.__doc__}\n\nThis instance does: {self}\n" self.allow_multiple_clients = allow_multiple_clients - self.skip_identities_fn = skip_identities_fn if name: self.__name__ = name self._tracks = tracks @@ -1610,9 +1423,6 @@ def __init__( if tracks != (): assert get_nodes - def op_key(self): - return self.op - def tracks(self): if self._tracks != (): return self._tracks @@ -1633,9 +1443,6 @@ def transform(self, fgraph, node, get_nodes=True): if ret is not False and ret is not None: return dict(zip(real_node.outputs, ret, strict=True)) - if node.op != self.op: - return False - if len(node.outputs) != 1: # PatternNodeRewriter doesn't support replacing multi-output nodes return False @@ -1664,11 +1471,13 @@ def transform(self, fgraph, node, get_nodes=True): [old_out] = node.outputs if not old_out.type.is_super(ret.type): + from pytensor.tensor.type import TensorType + # Type doesn't match if not ( self.allow_cast - and isinstance(old_out.type, pytensor.tensor.TensorType) - and isinstance(ret.type, pytensor.tensor.TensorType) + and isinstance(old_out.type, TensorType) + and isinstance(ret.type, TensorType) ): return False @@ -2136,7 +1945,7 @@ def walking_rewriter( else: (node_rewriters,) = node_rewriters if not name: - name = node_rewriters.__name__ + name = getattr(node_rewriters, "__name__", None) ret = WalkingGraphRewriter( node_rewriters, order=order, @@ -2152,52 +1961,6 @@ def walking_rewriter( out2in = partial(walking_rewriter, "out_to_in") -class OpKeyGraphRewriter(NodeProcessingGraphRewriter): - r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s. - - The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either - as a list of `Op`\s or a single `Op`), and discovered within a - `FunctionGraph` using the `NodeFinder` `Feature`. - - This is similar to the `Op`-based tracking feature used by other rewriters. - - """ - - def __init__(self, node_rewriter, ignore_newtrees=False, failure_callback=None): - if not hasattr(node_rewriter, "op_key"): - raise TypeError(f"{node_rewriter} must have an `op_key` method.") - super().__init__(node_rewriter, ignore_newtrees, failure_callback) - - def apply(self, fgraph): - op = self.node_rewriter.op_key() - if isinstance(op, list | tuple): - q = reduce(list.__iadd__, map(fgraph.get_nodes, op)) - else: - q = list(fgraph.get_nodes(op)) - - def importer(node): - if node is not current_node: - if node.op == op: - q.append(node) - - u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) - ) - try: - while q: - node = q.pop() - if node not in fgraph.apply_nodes: - continue - current_node = node - self.process_node(fgraph, node) - finally: - self.detach_updater(fgraph, u) - - def add_requirements(self, fgraph): - super().add_requirements(fgraph) - fgraph.attach_feature(NodeFinder()) - - class ChangeTracker(Feature): def __init__(self): self.changed = False @@ -2785,38 +2548,6 @@ def merge(rewriters, attr, idx): ) -def _check_chain(r, chain): - """ - WRITEME - - """ - chain = list(reversed(chain)) - while chain: - elem = chain.pop() - if elem is None: - if r.owner is not None: - return False - elif r.owner is None: - return False - elif isinstance(elem, Op): - if r.owner.op != elem: - return False - else: - try: - if issubclass(elem, Op) and not isinstance(r.owner.op, elem): - return False - except TypeError: - return False - if chain: - r = r.owner.inputs[chain.pop()] - # print 'check_chain', _check_chain.n_calls - # _check_chain.n_calls += 1 - - # The return value will be used as a Boolean, but some Variables cannot - # be used as Booleans (the results of comparisons, for instance) - return r is not None - - def pre_greedy_node_rewriter( fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable ) -> Variable: @@ -2998,10 +2729,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): otherwise. """ - if isinstance(f_or_fgraph, pytensor.compile.function.types.Function): - fgraph = f_or_fgraph.maker.fgraph - elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph): + from pytensor.compile.function.types import Function + + if isinstance(f_or_fgraph, FunctionGraph): fgraph = f_or_fgraph + elif isinstance(f_or_fgraph, Function): + fgraph = f_or_fgraph.maker.fgraph else: raise ValueError("The type of f_or_fgraph is not supported") diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 83ee8c2c3b..e9c2c8e47e 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -34,7 +34,6 @@ from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, NodeRewriter, - RemovalNodeRewriter, Rewriter, copy_stack_trace, in2out, @@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node): return [alloc(inputs_inner[0], *dims_outer)] -register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") +@register_canonicalize +@node_rewriter(tracks=[tensor_copy]) +def remove_tensor_copy(fgraph, node): + return node.inputs @register_specialize diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index f8156067f9..df2355ca12 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10): return np.allclose(x, ref, rtol=rtol, atol=atol) -def _skip_mul_1(r): - if r.owner and r.owner.op == mul: - not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] - if len(not_is_1) == 1: - return not_is_1[0] - - def _is_1(expr): """ @@ -3190,7 +3183,6 @@ def _is_1(expr): (neg, (softplus, (neg, "x"))), allow_multiple_clients=True, values_eq_approx=values_eq_approx_remove_inf, - skip_identities_fn=_skip_mul_1, tracks=[sigmoid], get_nodes=get_clients_at_depth1, ) @@ -3199,7 +3191,6 @@ def _is_1(expr): (neg, (softplus, "x")), allow_multiple_clients=True, values_eq_approx=values_eq_approx_remove_inf, - skip_identities_fn=_skip_mul_1, tracks=[sigmoid], get_nodes=get_clients_at_depth2, ) diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 0990dbeca0..90589db337 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -13,7 +13,7 @@ from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config from pytensor.graph.basic import Constant -from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter +from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter from pytensor.graph.utils import MissingInputError from pytensor.link.vm import VMLinker from pytensor.printing import debugprint @@ -39,7 +39,7 @@ def PatternOptimizer(p1, p2, ign=True): - return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) + return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) class TestFunction: diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index d0cb94f9fb..228c93a8c8 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -8,11 +8,9 @@ from pytensor.graph.rewriting.basic import ( EquilibriumGraphRewriter, MergeOptimizer, - OpKeyGraphRewriter, OpToRewriterTracker, PatternNodeRewriter, SequentialNodeRewriter, - SubstitutionNodeRewriter, WalkingGraphRewriter, in2out, logging, @@ -51,33 +49,29 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): raise AssertionError() -def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False): - return OpKeyGraphRewriter( +def WalkingPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False): + return WalkingGraphRewriter( PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients), ignore_newtrees=ign, ) -def WalkingPatternNodeRewriter(p1, p2, ign=True): - return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) - - class TestPatternNodeRewriter: def test_replace_output(self): # replacing the whole graph x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite( - g - ) + WalkingPatternNodeRewriter( + (op1, (op2, "1", "2"), "3"), (op4, "3", "2") + ).rewrite(g) assert str(g) == "FunctionGraph(Op4(z, y))" def test_nested_out_pattern(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(x, y) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2")) ).rewrite(g) assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))" @@ -86,7 +80,7 @@ def test_unification_1(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, x), z) # the arguments to op2 are the same g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op4, "2", "1"), ).rewrite(g) @@ -97,7 +91,7 @@ def test_unification_2(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) # the arguments to op2 are different g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op4, "2", "1"), ).rewrite(g) @@ -109,7 +103,7 @@ def test_replace_subgraph(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g) + WalkingPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))" def test_no_recurse(self): @@ -119,7 +113,9 @@ def test_no_recurse(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g) + WalkingPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite( + g + ) assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))" def test_multiple(self): @@ -127,7 +123,7 @@ def test_multiple(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), op2(x, y), op2(y, z)) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g) + WalkingPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))" def test_nested_even(self): @@ -136,21 +132,21 @@ def test_nested_even(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(x)))) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) + WalkingPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) assert str(g) == "FunctionGraph(x)" def test_nested_odd(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) + WalkingPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) assert str(g) == "FunctionGraph(Op1(x))" def test_expand(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(x))) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g) + WalkingPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g) assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))" def test_ambiguous(self): @@ -169,7 +165,7 @@ def test_constant(self): z = Constant(MyType(), 2, name="z") e = op1(op1(x, y), y) g = FunctionGraph([y], [e]) - OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) + WalkingPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))" def test_constraints(self): @@ -181,7 +177,7 @@ def constraint(r): # Only replacing if the input is an instance of Op2 return r.owner.op == op2 - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op1, {"pattern": "1", "constraint": constraint}), (op3, "1") ).rewrite(g) assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))" @@ -190,7 +186,7 @@ def test_match_same(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(x, x) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g) + WalkingPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g) assert str(g) == "FunctionGraph(Op3(x, x))" @pytest.mark.xfail( @@ -202,10 +198,10 @@ def test_match_same_illegal(self): g = FunctionGraph([x, y, z], [e]) def constraint(r): - # Only replacing if the input is an instance of Op2 + # Only replacing if the inputs are not identical return r.owner.inputs[0] is not r.owner.inputs[1] - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( {"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y") ).rewrite(g) assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))" @@ -220,7 +216,7 @@ def test_allow_multiple_clients(self): # So the replacement should fail outputs = [e] g = FunctionGraph(inputs, outputs, copy_inputs=False) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op4, (op1, "x", "y")), (op3, "x", "y"), ).rewrite(g) @@ -228,7 +224,7 @@ def test_allow_multiple_clients(self): # Now it should be fine g = FunctionGraph(inputs, outputs, copy_inputs=False) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op4, (op1, "x", "y")), (op3, "x", "y"), allow_multiple_clients=True, @@ -237,7 +233,7 @@ def test_allow_multiple_clients(self): # The fact that the inputs of the pattern have multiple clients should not matter g = FunctionGraph(inputs, outputs, copy_inputs=False) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op3, (op4, "w"), "w"), (op3, "w", "w"), allow_multiple_clients=False, @@ -252,7 +248,7 @@ def test_allow_multiple_clients(self): outputs = [e1, e2] g = FunctionGraph(inputs, outputs, copy_inputs=False) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op4, (op4, "e")), "e", allow_multiple_clients=False, @@ -261,7 +257,7 @@ def test_allow_multiple_clients(self): outputs = [e1, e3] g = FunctionGraph([x, y, z], outputs, copy_inputs=False) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op4, (op4, "e")), "e", allow_multiple_clients=False, @@ -269,7 +265,7 @@ def test_allow_multiple_clients(self): assert equal_computations(g.outputs, outputs) g = FunctionGraph(inputs, outputs, copy_inputs=False) - OpKeyPatternNodeRewriter( + WalkingPatternNodeRewriter( (op4, (op4, "e")), "e", allow_multiple_clients=True, @@ -281,33 +277,13 @@ def test_eq(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op_y(x, y), z) g = FunctionGraph([x, y, z], [e]) - OpKeyPatternNodeRewriter((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite( - g - ) + WalkingPatternNodeRewriter( + (op1, (op_z, "1", "2"), "3"), (op4, "3", "2") + ).rewrite(g) str_g = str(g) assert str_g == "FunctionGraph(Op4(z, y))" -def KeyedSubstitutionNodeRewriter(op1, op2): - return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2)) - - -class TestSubstitutionNodeRewriter: - def test_straightforward(self): - x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") - e = op1(op1(op1(op1(op1(x))))) - g = FunctionGraph([x, y, z], [e]) - KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g) - assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))" - - def test_straightforward_2(self): - x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") - e = op1(op2(x), op3(y), op4(z)) - g = FunctionGraph([x, y, z], [e]) - KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))" - - class NoInputOp(Op): __props__ = ("param",) diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 16a654da26..70333f369b 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -10,7 +10,6 @@ from pytensor.graph.op import Op from pytensor.graph.rewriting.basic import ( NodeProcessingGraphRewriter, - OpKeyGraphRewriter, PatternNodeRewriter, SubstitutionNodeRewriter, WalkingGraphRewriter, @@ -21,7 +20,7 @@ def OpKeyPatternNodeRewriter(p1, p2, ign=True): - return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) + return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) def TopoSubstitutionNodeRewriter( diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index d23caf52ee..474104269d 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -2,92 +2,12 @@ import pytensor.tensor as pt from pytensor.graph import rewrite_graph -from pytensor.graph.basic import Apply, Variable, equal_computations -from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate +from pytensor.graph.basic import equal_computations +from pytensor.graph.features import Feature, FullHistory, ReplaceValidate from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import Op -from pytensor.graph.type import Type from tests.graph.utils import MyVariable, op1 -class TestNodeFinder: - def test_straightforward(self): - class MyType(Type): - def __init__(self, name): - self.name = name - - def filter(self, *args, **kwargs): - raise NotImplementedError() - - def __str__(self): - return self.name - - def __repr__(self): - return self.name - - def __eq__(self, other): - return isinstance(other, MyType) - - class MyOp(Op): - __props__ = ("nin", "name") - - def __init__(self, nin, name): - self.nin = nin - self.name = name - - def make_node(self, *inputs): - def as_variable(x): - assert isinstance(x, Variable) - return x - - assert len(inputs) == self.nin - inputs = list(map(as_variable, inputs)) - for input in inputs: - if not isinstance(input.type, MyType): - raise Exception("Error 1") - outputs = [MyType(self.name + "_R")()] - return Apply(self, inputs, outputs) - - def __str__(self): - return self.name - - def perform(self, *args, **kwargs): - raise NotImplementedError() - - sigmoid = MyOp(1, "Sigmoid") - add = MyOp(2, "Add") - dot = MyOp(2, "Dot") - - def MyVariable(name): - return Variable(MyType(name), None, None) - - def inputs(): - x = MyVariable("x") - y = MyVariable("y") - z = MyVariable("z") - return x, y, z - - x, y, z = inputs() - e0 = dot(y, z) - e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) - g = FunctionGraph([x, y, z], [e], clone=False) - g.attach_feature(NodeFinder()) - - assert hasattr(g, "get_nodes") - for type, num in ((add, 3), (sigmoid, 3), (dot, 2)): - if len(list(g.get_nodes(type))) != num: - raise Exception(f"Expected: {num} times {type}") - new_e0 = add(y, z) - assert e0.owner in g.get_nodes(dot) - assert new_e0.owner not in g.get_nodes(add) - g.replace(e0, new_e0) - assert e0.owner not in g.get_nodes(dot) - assert new_e0.owner in g.get_nodes(add) - for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): - if len(list(g.get_nodes(type))) != num: - raise Exception(f"Expected: {num} times {type}") - - class TestReplaceValidate: def test_verbose(self, capsys): var1 = MyVariable("var1")