Skip to content

Commit 5f8cee6

Browse files
committed
Test more FusionOptimizer graphs
1 parent 96122d1 commit 5f8cee6

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,6 @@ def elemwise_to_scalar(inputs, outputs):
569569
return scalar_inputs, scalar_outputs
570570

571571
def apply(self, fgraph):
572-
nb_replacement = 0
573-
574572
if fgraph.profile:
575573
validate_before = fgraph.profile.validate_time
576574
callbacks_before = fgraph.execute_callbacks_times.copy()
@@ -925,6 +923,8 @@ def update_fuseable_mappings_after_fg_replace(
925923
starting_nodes=starting_nodes,
926924
)
927925

926+
nb_fused = 0
927+
nb_replacement = 0
928928
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
929929
if (len(inputs) + len(outputs)) > max_operands:
930930
warn(
@@ -943,11 +943,13 @@ def update_fuseable_mappings_after_fg_replace(
943943
if old_out.name:
944944
composite_out.name = old_out.name
945945

946+
starting_nodes = len(fgraph.apply_nodes)
946947
fgraph.replace_all_validate(
947948
list(zip(outputs, composite_outputs, strict=True)),
948949
reason=self.__class__.__name__,
949950
)
950-
nb_replacement += 1
951+
nb_fused += 1
952+
nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1
951953

952954
if fgraph.profile:
953955
validate_time = fgraph.profile.validate_time - validate_before
@@ -965,7 +967,7 @@ def update_fuseable_mappings_after_fg_replace(
965967

966968
return (
967969
self,
968-
1, # nb_iter
970+
nb_fused,
969971
nb_replacement,
970972
0, # nb_inconsintency_replace
971973
validate_time,
@@ -978,7 +980,7 @@ def update_fuseable_mappings_after_fg_replace(
978980
def print_profile(stream, prof, level=0):
979981
blanc = " " * level
980982
print(blanc, "FusionOptimizer", file=stream)
981-
print(blanc, " nb_iter", prof[1], file=stream)
983+
print(blanc, " nb_fused", prof[1], file=stream)
982984
print(blanc, " nb_replacement", prof[2], file=stream)
983985
print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
984986
print(blanc, " validate_time", prof[4], file=stream)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0):
273273
fwx = fw + fx
274274
ftanx = tan(fx)
275275

276-
def large_fuseable_graph(self, n):
276+
@staticmethod
277+
def large_fuseable_graph(n):
277278
factors = []
278279
sd = dscalar()
279280
means = dvector()
@@ -296,6 +297,28 @@ def large_fuseable_graph(self, n):
296297
dlogp = [pytensor.grad(logp, v) for v in vars]
297298
return vars, dlogp
298299

300+
@staticmethod
301+
def deep_small_kernels(n):
302+
x = pt.matrix("x")
303+
out = x
304+
for _ in range(n):
305+
out = pt.sin(out.T) + pt.cos(out)
306+
307+
return [x], [out]
308+
309+
@staticmethod
310+
def test_diamond_graph():
311+
a = pt.matrix("a")
312+
b = pt.exp(a)
313+
c = pt.log(b)
314+
d = pt.sin(c)
315+
e = c + d
316+
317+
fg = FunctionGraph([a], [e], clone=False)
318+
_, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg)
319+
assert nb_fused == 1
320+
assert nb_replacement == 4
321+
299322
@pytest.mark.parametrize(
300323
"case",
301324
[
@@ -1347,16 +1370,26 @@ def test_eval_benchmark(self, benchmark):
13471370
benchmark(func)
13481371

13491372
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
1350-
def test_rewrite_benchmark(self, benchmark):
1351-
inps, outs = self.large_fuseable_graph(n=25)
1373+
@pytest.mark.parametrize(
1374+
"graph_fn, n, expected_n_repl",
1375+
[
1376+
("deep_small_kernels", 20, (20, 60)),
1377+
("large_fuseable_graph", 25, (103, 876)),
1378+
],
1379+
)
1380+
def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark):
1381+
inps, outs = getattr(self, graph_fn)(n)
13521382
fg = FunctionGraph(inps, outs)
13531383
opt = FusionOptimizer()
13541384

13551385
def rewrite_func():
1356-
nb_replacement = opt.apply(fg.clone())[2]
1357-
return nb_replacement
1386+
fg_clone = fg.clone()
1387+
_, nb_fused, nb_replacement, *_ = opt.apply(fg_clone)
1388+
# fg_clone.dprint()
1389+
return nb_fused, nb_replacement
13581390

1359-
assert benchmark(rewrite_func) == 103
1391+
assert rewrite_func() == expected_n_repl
1392+
benchmark.pedantic(rewrite_func, rounds=7, iterations=5)
13601393

13611394
def test_no_warning_from_old_client(self):
13621395
# There used to be a warning issued when creating fuseable mapping

0 commit comments

Comments
 (0)