Skip to content

Commit 7a7d26f

Browse files
committed
Benchmark InplaceOptimizer
1 parent 2e1758b commit 7a7d26f

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/compile/function/test_types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.compile.io import In, Out
1313
from pytensor.compile.mode import Mode, get_default_mode
1414
from pytensor.configdefaults import config
15+
from pytensor.graph import FunctionGraph
1516
from pytensor.graph.basic import Constant
1617
from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
1718
from pytensor.graph.utils import MissingInputError
@@ -21,6 +22,7 @@
2122
from pytensor.tensor.math import sum as pt_sum
2223
from pytensor.tensor.random import normal
2324
from pytensor.tensor.random.type import random_generator_type
25+
from pytensor.tensor.rewriting.elemwise import FusionOptimizer
2426
from pytensor.tensor.type import (
2527
dmatrix,
2628
dscalar,
@@ -1371,3 +1373,18 @@ def compile_function(mode=mode, depth=depth):
13711373
return fn
13721374

13731375
benchmark.pedantic(compile_function, iterations=20, rounds=5)
1376+
1377+
1378+
def test_benchmark_fusion_optimizer(benchmark):
1379+
x = pt.matrix("x")
1380+
out = x
1381+
for _ in range(20):
1382+
out = pt.sin(out.T) + pt.cos(out)
1383+
1384+
optimizer = FusionOptimizer()
1385+
1386+
def foo():
1387+
fg = FunctionGraph(outputs=[out])
1388+
optimizer.apply(fg)
1389+
1390+
benchmark(foo)

0 commit comments

Comments
 (0)