File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change 1212from pytensor .compile .io import In , Out
1313from pytensor .compile .mode import Mode , get_default_mode
1414from pytensor .configdefaults import config
15+ from pytensor .graph import FunctionGraph
1516from pytensor .graph .basic import Constant
1617from pytensor .graph .rewriting .basic import PatternNodeRewriter , WalkingGraphRewriter
1718from pytensor .graph .utils import MissingInputError
2122from pytensor .tensor .math import sum as pt_sum
2223from pytensor .tensor .random import normal
2324from pytensor .tensor .random .type import random_generator_type
25+ from pytensor .tensor .rewriting .elemwise import FusionOptimizer
2426from pytensor .tensor .type import (
2527 dmatrix ,
2628 dscalar ,
@@ -1371,3 +1373,18 @@ def compile_function(mode=mode, depth=depth):
13711373 return fn
13721374
13731375 benchmark .pedantic (compile_function , iterations = 20 , rounds = 5 )
1376+
1377+
1378+ def test_benchmark_fusion_optimizer (benchmark ):
1379+ x = pt .matrix ("x" )
1380+ out = x
1381+ for _ in range (20 ):
1382+ out = pt .sin (out .T ) + pt .cos (out )
1383+
1384+ optimizer = FusionOptimizer ()
1385+
1386+ def foo ():
1387+ fg = FunctionGraph (outputs = [out ])
1388+ optimizer .apply (fg )
1389+
1390+ benchmark (foo )
You can’t perform that action at this time.
0 commit comments