@@ -90,8 +90,6 @@ def process_initial_eval(
     use_weights: bool = False,
 ) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
 
-    initial_eval = _parse_forward_out(initial_eval)
-
     # number of elements in the output of forward_func
     n_outputs = initial_eval.numel()
 
@@ -153,6 +151,74 @@ def format_result(
     return _format_output(is_inputs_tuple, attrib)
 
 
+def check_output_shape_valid(
+    inputs: TensorOrTupleOfTensorsGeneric,
+    num_examples: int,
+    initial_eval: Tensor,
+    modified_eval: Tensor,
+    perturbations_per_eval: int,
+) -> None:
+    """
+    Validates that the forward function's output shape scales correctly with
+    the input batch size when perturbations_per_eval > 1.
+
+    When multiple perturbations are evaluated simultaneously
+    (perturbations_per_eval > 1), the forward function must return outputs
+    whose first dimension grows proportionally with the input batch size.
+    This ensures the forward function is not aggregating results across the
+    batch, which would prevent correct attribution calculation.
+
+    Args:
+        inputs (Tensor or tuple[Tensor, ...]): Input tensors used for evaluation.
+            The first dimension of inputs[0] is used to determine the current
+            batch size.
+        num_examples (int): The original number of examples (batch size) before
+            expansion for perturbations.
+        initial_eval (Tensor): Output from the forward function with the original
+            batch size (perturbations_per_eval = 1). Used as the baseline for
+            shape comparison.
+        modified_eval (Tensor): Output from the forward function with the expanded
+            batch size (batch_size = num_examples * n_perturb).
+        perturbations_per_eval (int): Number of perturbations processed
+            simultaneously. Validation only occurs when this value is greater
+            than 1.
+
+    Raises:
+        AssertionError: If perturbations_per_eval > 1 and the output shape does
+            not scale correctly. Specifically, if the current batch size is not
+            a multiple of num_examples, or modified_eval.shape[0] is not equal
+            to n_perturb * initial_eval.shape[0], where n_perturb is the ratio
+            of the current batch size to the original batch size.
+    """
+
+    if perturbations_per_eval > 1:
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # the input and not be aggregated
+        current_batch_size = inputs[0].shape[0]
+
+        # number of perturbations, which is not the same as
+        # perturbations_per_eval when there are not enough features to perturb
+        n_perturb: int = current_batch_size // num_examples
+        mod_perturb: int = current_batch_size % num_examples
+        current_output_shape = modified_eval.shape
+
+        # use initial_eval as the forward output for perturbations_per_eval = 1
+        initial_output_shape = initial_eval.shape
+
+        assert (
+            # check that the output is not a scalar
+            current_output_shape
+            and initial_output_shape
+            and mod_perturb == 0
+            # check that the output grows in the same ratio, i.e., is not aggregated
+            and current_output_shape[0] == n_perturb * initial_output_shape[0]
+        ), (
+            "When perturbations_per_eval > 1, forward_func's output "
+            "should be a tensor whose 1st dim grows with the input "
+            f"batch size: when input batch size is {num_examples}, "
+            f"the output shape is {initial_output_shape}; "
+            f"when input batch size is {current_batch_size}, "
+            f"the output shape is {current_output_shape}"
+        )
+
+
 class FeatureAblation(PerturbationAttribution):
     """
     A perturbation based approach to computing attribution, involving
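
A quick illustration of what the new helper enforces may help when reviewing the hunk above. The sketch below is not part of the diff: it assumes check_output_shape_valid is in scope (it is the function added in this commit) and uses made-up toy tensors and forward functions. A per-example forward output passes; an output aggregated across the batch trips the assertion.

import torch

# Original batch of 4 examples, then 2 perturbations evaluated in one call (batch of 8).
batch = torch.rand(4, 3)
expanded = batch.repeat(2, 1)

def per_example_forward(x):
    # one score per row: the first dim tracks the batch size
    return x.sum(dim=1)

def aggregated_forward(x):
    # collapses the whole batch to a single value
    return x.sum().unsqueeze(0)

# Passes: modified_eval's first dim is 8 == 2 * 4 == n_perturb * initial_eval.shape[0].
check_output_shape_valid(
    inputs=(expanded,),
    num_examples=4,
    initial_eval=per_example_forward(batch),
    modified_eval=per_example_forward(expanded),
    perturbations_per_eval=2,
)

# Raises AssertionError: the output stays shape (1,) instead of growing to (2,).
check_output_shape_valid(
    inputs=(expanded,),
    num_examples=4,
    initial_eval=aggregated_forward(batch),
    modified_eval=aggregated_forward(expanded),
    perturbations_per_eval=2,
)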
@@ -395,7 +461,7 @@ def attribute(
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple. We return the attribution as tuple in the
-        # end if the inputs where tuple.
+        # end if the inputs were a tuple.
         is_inputs_tuple = _is_tuple(inputs)
 
         formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
@@ -443,7 +509,7 @@ def attribute(
             "when using the attribute function, initial_eval should have "
             f"non-Future type rather than {type(initial_eval)}"
         )
-
+        initial_eval = _parse_forward_out(initial_eval)
         (
             total_attrib,
             weights,
@@ -581,26 +647,33 @@ def _attribute_with_cross_tensor_feature_masks(
                 current_target,
                 current_additional_args,
             )
+            modified_eval = _parse_forward_out(modified_eval)
 
             attr_progress.update()
 
             assert not isinstance(modified_eval, torch.Future), (
                 "when use_futures is True, modified_eval should have "
                 f"non-Future type rather than {type(modified_eval)}"
            )
-
+            # Just do the check once.
+            if not self._is_output_shape_valid:
+                check_output_shape_valid(
+                    inputs=current_inputs,
+                    num_examples=num_examples,
+                    initial_eval=initial_eval,
+                    modified_eval=modified_eval,
+                    perturbations_per_eval=perturbations_per_eval,
+                )
+                self._is_output_shape_valid = True
             total_attrib, weights = self._process_ablated_out_full(
-                modified_eval,
-                current_masks,
-                flattened_initial_eval,
-                initial_eval,
-                current_inputs,
-                n_outputs,
-                num_examples,
-                total_attrib,
-                weights,
-                attrib_type,
-                perturbations_per_eval,
+                modified_eval=modified_eval,
+                current_mask=current_masks,
+                flattened_initial_eval=flattened_initial_eval,
+                inputs=current_inputs,
+                n_outputs=n_outputs,
+                total_attrib=total_attrib,
+                weights=weights,
+                attrib_type=attrib_type,
            )
        return total_attrib, weights
@@ -705,6 +778,7 @@ def _initial_eval_to_processed_initial_eval_fut(
                 "initial_eval_to_processed_initial_eval_fut: "
                 "initial_eval should be a Tensor"
             )
+        initial_eval_processed = _parse_forward_out(initial_eval_processed)
         result = process_initial_eval(
             initial_eval_processed, formatted_inputs, use_weights=self.use_weights
         )
@@ -1039,6 +1113,7 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
                 "total_attrib, weights, initial_eval, "
                 "flattened_initial_eval, n_outputs, attrib_type "
             )
+        modified_eval = _parse_forward_out(modified_eval)
         if not isinstance(modified_eval, Tensor):
             raise AssertionError(
                 "_eval_fut_to_ablated_out_fut_cross_tensor: "
@@ -1052,13 +1127,21 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
             n_outputs,
             attrib_type,
         ) = initial_eval_tuple
+        # Just do the check once.
+        if not self._is_output_shape_valid:
+            check_output_shape_valid(
+                inputs=current_inputs,
+                num_examples=num_examples,
+                initial_eval=initial_eval,
+                modified_eval=modified_eval,
+                perturbations_per_eval=perturbations_per_eval,
+            )
+            self._is_output_shape_valid = True
+
         total_attrib, weights = self._process_ablated_out_full(
             modified_eval=modified_eval,
             inputs=current_inputs,
             current_mask=current_mask,
-            perturbations_per_eval=perturbations_per_eval,
-            num_examples=num_examples,
-            initial_eval=initial_eval,
             flattened_initial_eval=flattened_initial_eval,
             n_outputs=n_outputs,
             total_attrib=total_attrib,
@@ -1076,47 +1159,12 @@ def _process_ablated_out_full(
         modified_eval: Tensor,
         current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
-        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
-        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
-        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
-        modified_eval = _parse_forward_out(modified_eval)
-        # if perturbations_per_eval > 1, the output shape must grow with
-        # input and not be aggregated
-        current_batch_size = inputs[0].shape[0]
-
-        # number of perturbation, which is not the same as
-        # perturbations_per_eval when not enough features to perturb
-        n_perturb = current_batch_size / num_examples
-        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
-
-            current_output_shape = modified_eval.shape
-
-            # use initial_eval as the forward of perturbations_per_eval = 1
-            initial_output_shape = initial_eval.shape
-
-            assert (
-                # check if the output is not a scalar
-                current_output_shape
-                and initial_output_shape
-                # check if the output grow in same ratio, i.e., not agg
-                and current_output_shape[0] == n_perturb * initial_output_shape[0]
-            ), (
-                "When perturbations_per_eval > 1, forward_func's output "
-                "should be a tensor whose 1st dim grow with the input "
-                f"batch size: when input batch size is {num_examples}, "
-                f"the output shape is {initial_output_shape}; "
-                f"when input batch size is {current_batch_size}, "
-                f"the output shape is {current_output_shape}"
-            )
-
-            self._is_output_shape_valid = True
-
         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
         modified_eval = modified_eval.reshape(-1, n_outputs)
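
The final context lines above reshape the flat forward output back into per-perturbation rows. A minimal sketch of that step, with made-up shapes (not part of the diff):

import torch

n_outputs = 4                # e.g., 4 examples, one output value each
n_feature_perturbed = 3      # perturbations evaluated in this forward call
modified_eval = torch.arange(n_feature_perturbed * n_outputs, dtype=torch.float)

# one row per perturbation, each holding the n_outputs values for that perturbation
per_perturbation = modified_eval.reshape(-1, n_outputs)
assert per_perturbation.shape == (3, 4)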