Skip to content

Commit b4c7afa

Browse files
committed
.don't validate in every iteration of fusion rewriter
1 parent 95064eb commit b4c7afa

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,8 @@ def update_fuseable_mappings_after_fg_replace(
948948
starting_nodes=starting_nodes,
949949
)
950950

951+
checkpoint = fgraph.checkpoint()
952+
nb_inconsintency_replace = 0
951953
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
952954
if (len(inputs) + len(outputs)) > max_operands:
953955
warn(
@@ -966,12 +968,22 @@ def update_fuseable_mappings_after_fg_replace(
966968
if old_out.name:
967969
composite_out.name = old_out.name
968970

969-
fgraph.replace_all_validate(
971+
fgraph.replace_all(
970972
list(zip(outputs, composite_outputs, strict=True)),
971973
reason=self.__class__.__name__,
972974
)
973975
nb_replacement += 1
974976

977+
try:
978+
fgraph.validate()
979+
except InconsistencyError:
980+
warn(
981+
f"{self.__class__.__name__} produced an inconsistent graph. Reverting to the previous state."
982+
)
983+
fgraph.revert(checkpoint)
984+
nb_inconsintency_replace = nb_replacement
985+
nb_replacement = 0
986+
975987
if fgraph.profile:
976988
validate_time = fgraph.profile.validate_time - validate_before
977989
callback_time = fgraph.execute_callbacks_time - callback_before
@@ -990,7 +1002,7 @@ def update_fuseable_mappings_after_fg_replace(
9901002
self,
9911003
1, # nb_iter
9921004
nb_replacement,
993-
0, # nb_inconsintency_replace
1005+
nb_inconsintency_replace,
9941006
validate_time,
9951007
callback_time,
9961008
callbacks_time,

0 commit comments

Comments
 (0)