@@ -948,6 +948,8 @@ def update_fuseable_mappings_after_fg_replace(
948948 starting_nodes = starting_nodes ,
949949 )
950950
951+ checkpoint = fgraph .checkpoint ()
952+ nb_inconsintency_replace = 0
951953 for inputs , outputs in find_next_fuseable_subgraph (fgraph ):
952954 if (len (inputs ) + len (outputs )) > max_operands :
953955 warn (
@@ -966,12 +968,22 @@ def update_fuseable_mappings_after_fg_replace(
966968 if old_out .name :
967969 composite_out .name = old_out .name
968970
969- fgraph .replace_all_validate (
971+ fgraph .replace_all (
970972 list (zip (outputs , composite_outputs , strict = True )),
971973 reason = self .__class__ .__name__ ,
972974 )
973975 nb_replacement += 1
974976
977+ try :
978+ fgraph .validate ()
979+ except InconsistencyError :
980+ warn (
981+ f"{ self .__class__ .__name__ } produced an inconsistent graph. Reverting to the previous state."
982+ )
983+ fgraph .revert (checkpoint )
984+ nb_inconsintency_replace = nb_replacement
985+ nb_replacement = 0
986+
975987 if fgraph .profile :
976988 validate_time = fgraph .profile .validate_time - validate_before
977989 callback_time = fgraph .execute_callbacks_time - callback_before
@@ -990,7 +1002,7 @@ def update_fuseable_mappings_after_fg_replace(
9901002 self ,
9911003 1 , # nb_iter
9921004 nb_replacement ,
993- 0 , # nb_inconsintency_replace
1005+ nb_inconsintency_replace ,
9941006 validate_time ,
9951007 callback_time ,
9961008 callbacks_time ,
0 commit comments