Skip to content

Commit 2f499c4

Browse files
committed
Benchmark another FusionOptimizer graph
1 parent 8c3b113 commit 2f499c4

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,6 @@ def elemwise_to_scalar(inputs, outputs):
567567
return scalar_inputs, scalar_outputs
568568

569569
def apply(self, fgraph):
570-
nb_replacement = 0
571-
572570
if fgraph.profile:
573571
validate_before = fgraph.profile.validate_time
574572
callbacks_before = fgraph.execute_callbacks_times.copy()
@@ -923,6 +921,8 @@ def update_fuseable_mappings_after_fg_replace(
923921
starting_nodes=starting_nodes,
924922
)
925923

924+
nb_fused = 0
925+
nb_replacement = 0
926926
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
927927
if (len(inputs) + len(outputs)) > max_operands:
928928
warn(
@@ -941,11 +941,13 @@ def update_fuseable_mappings_after_fg_replace(
941941
if old_out.name:
942942
composite_out.name = old_out.name
943943

944+
starting_nodes = len(fgraph.apply_nodes)
944945
fgraph.replace_all_validate(
945946
list(zip(outputs, composite_outputs, strict=True)),
946947
reason=self.__class__.__name__,
947948
)
948-
nb_replacement += 1
949+
nb_fused += 1
950+
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
949951

950952
if fgraph.profile:
951953
validate_time = fgraph.profile.validate_time - validate_before
@@ -963,7 +965,7 @@ def update_fuseable_mappings_after_fg_replace(
963965

964966
return (
965967
self,
966-
1, # nb_iter
968+
nb_fused,
967969
nb_replacement,
968970
0, # nb_inconsintency_replace
969971
validate_time,
@@ -976,7 +978,7 @@ def update_fuseable_mappings_after_fg_replace(
976978
def print_profile(stream, prof, level=0):
977979
blanc = " " * level
978980
print(blanc, "FusionOptimizer", file=stream)
979-
print(blanc, " nb_iter", prof[1], file=stream)
981+
print(blanc, " nb_fused", prof[1], file=stream)
980982
print(blanc, " nb_replacement", prof[2], file=stream)
981983
print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
982984
print(blanc, " validate_time", prof[4], file=stream)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0):
273273
fwx = fw + fx
274274
ftanx = tan(fx)
275275

276-
def large_fuseable_graph(self, n):
276+
@staticmethod
277+
def large_fuseable_graph(n):
277278
factors = []
278279
sd = dscalar()
279280
means = dvector()
@@ -296,6 +297,24 @@ def large_fuseable_graph(self, n):
296297
dlogp = [pytensor.grad(logp, v) for v in vars]
297298
return vars, dlogp
298299

300+
@staticmethod
301+
def deep_small_kernels(n):
302+
x = pt.matrix("x")
303+
out = x
304+
for _ in range(n):
305+
out = pt.sin(out.T) + pt.cos(out)
306+
307+
return [x], [out]
308+
309+
@staticmethod
310+
def diamond_graph(n):
311+
a = pt.matrix("a")
312+
b = pt.exp(a)
313+
c = pt.log(b)
314+
d = pt.sin(c)
315+
e = c + d
316+
return [a], [e]
317+
299318
@pytest.mark.parametrize(
300319
"case",
301320
[
@@ -1347,16 +1366,27 @@ def test_eval_benchmark(self, benchmark):
13471366
benchmark(func)
13481367

13491368
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
1350-
def test_rewrite_benchmark(self, benchmark):
1351-
inps, outs = self.large_fuseable_graph(n=25)
1369+
@pytest.mark.parametrize(
1370+
"graph_fn, n, expected_n_repl",
1371+
[
1372+
# ("diamond_graph", None, (1, 4)),
1373+
("deep_small_kernels", 20, (20, 60)),
1374+
("large_fuseable_graph", 25, (103, 876)),
1375+
],
1376+
)
1377+
def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark):
1378+
inps, outs = getattr(self, graph_fn)(n)
13521379
fg = FunctionGraph(inps, outs)
13531380
opt = FusionOptimizer()
13541381

13551382
def rewrite_func():
1356-
nb_replacement = opt.apply(fg.clone())[2]
1357-
return nb_replacement
1383+
fg_clone = fg.clone()
1384+
_, nb_fused, nb_replacement, *_ = opt.apply(fg_clone)
1385+
# fg_clone.dprint()
1386+
return nb_fused, nb_replacement
13581387

1359-
assert benchmark(rewrite_func) == 103
1388+
assert rewrite_func() == expected_n_repl
1389+
benchmark.pedantic(rewrite_func, rounds=7, iterations=5)
13601390

13611391
def test_no_warning_from_old_client(self):
13621392
# There used to be a warning issued when creating fuseable mapping

0 commit comments

Comments
 (0)