Skip to content

Commit 3804e56

Browse files
cyrjano authored and facebook-github-bot committed
Add Method to check if output has valid shape for perturbations per eval. (meta-pytorch#1666)
Summary: This diff adds a module-level function, `check_output_shape_valid`, to the `captum.attr._core.feature_ablation` module. The function takes the inputs, the number of examples, the initial evaluation, the modified evaluation, and the number of perturbations per evaluation as arguments. When more than one perturbation is evaluated per forward call, it asserts that the first dimension of the forward function's output scales with the input batch size, i.e. that the forward function does not aggregate across the batch. This check previously lived inside `_process_ablated_out_full`; it is now invoked once by the callers before `_process_ablated_out_full` is called. Differential Revision: D86976520
1 parent 6fa02a2 commit 3804e56

File tree

2 files changed

+185
-64
lines changed

2 files changed

+185
-64
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 102 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ def process_initial_eval(
9090
use_weights: bool = False,
9191
) -> Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]:
9292

93-
initial_eval = _parse_forward_out(initial_eval)
94-
9593
# number of elements in the output of forward_func
9694
n_outputs = initial_eval.numel()
9795

@@ -153,6 +151,74 @@ def format_result(
153151
return _format_output(is_inputs_tuple, attrib)
154152

155153

154+
def check_output_shape_valid(
    inputs: TensorOrTupleOfTensorsGeneric,
    num_examples: int,
    initial_eval: Tensor,
    modified_eval: Tensor,
    perturbations_per_eval: int,
) -> None:
    """
    Validates that the forward function's output shape scales correctly with
    input batch size when perturbations_per_eval > 1.

    When multiple perturbations are evaluated simultaneously
    (perturbations_per_eval > 1), the forward function must return outputs
    whose first dimension grows proportionally with the input batch size.
    This ensures the forward function is not aggregating results across the
    batch, which would prevent correct attribution calculation.

    Args:
        inputs (Tensor or tuple[Tensor, ...]): Input tensors used for
            evaluation. The first dimension of the first input tensor (or of
            ``inputs`` itself when a single Tensor is passed) is used as the
            current batch size.
        num_examples (int): The original number of examples (batch size)
            before expansion for perturbations.
        initial_eval (Tensor): Output from forward function with original
            batch size (perturbations_per_eval = 1). Used as baseline for
            shape comparison.
        modified_eval (Tensor): Output from forward function with expanded
            batch size (batch_size = num_examples * n_perturb).
        perturbations_per_eval (int): Number of perturbations processed
            simultaneously. Validation only occurs when this value is
            greater than 1.

    Raises:
        AssertionError: If perturbations_per_eval > 1 and any of the
            following holds: either output is a scalar (0-dim) tensor, the
            current batch size is not a multiple of num_examples, or
            modified_eval.shape[0] is not equal to
            n_perturb * initial_eval.shape[0], where n_perturb is the ratio
            of current batch size to original batch size.
    """
    # Nothing to validate when perturbations are evaluated one at a time.
    if perturbations_per_eval <= 1:
        return

    # The annotation allows a bare Tensor as well as a tuple. Indexing a bare
    # Tensor with [0] would select its first row and report the wrong
    # dimension as the batch size, so handle that case explicitly.
    first_input = inputs if isinstance(inputs, Tensor) else inputs[0]
    current_batch_size = first_input.shape[0]

    # number of perturbations, which is not the same as
    # perturbations_per_eval when there are not enough features to perturb
    n_perturb, mod_perturb = divmod(current_batch_size, num_examples)

    current_output_shape = modified_eval.shape
    # use initial_eval as the forward of perturbations_per_eval = 1
    initial_output_shape = initial_eval.shape

    is_valid = (
        # neither output may be a scalar (empty shape)
        len(current_output_shape) > 0
        and len(initial_output_shape) > 0
        # expanded batch must be a whole multiple of the original batch
        and mod_perturb == 0
        # output must grow in the same ratio as the input, i.e., not agg
        and current_output_shape[0] == n_perturb * initial_output_shape[0]
    )

    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped when Python runs with -O, which would silently disable this
    # check. The exception type is unchanged for callers.
    if not is_valid:
        raise AssertionError(
            "When perturbations_per_eval > 1, forward_func's output "
            "should be a tensor whose 1st dim grow with the input "
            f"batch size: when input batch size is {num_examples}, "
            f"the output shape is {initial_output_shape}; "
            f"when input batch size is {current_batch_size}, "
            f"the output shape is {current_output_shape}"
        )
220+
221+
156222
class FeatureAblation(PerturbationAttribution):
157223
"""
158224
A perturbation based approach to computing attribution, involving
@@ -395,7 +461,7 @@ def attribute(
395461
"""
396462
# Keeps track whether original input is a tuple or not before
397463
# converting it into a tuple. We return the attribution as tuple in the
398-
# end if the inputs where tuple.
464+
# end if the inputs were a tuple.
399465
is_inputs_tuple = _is_tuple(inputs)
400466

401467
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
@@ -443,7 +509,7 @@ def attribute(
443509
"when using the attribute function, initial_eval should have "
444510
f"non-Future type rather than {type(initial_eval)}"
445511
)
446-
512+
initial_eval = _parse_forward_out(initial_eval)
447513
(
448514
total_attrib,
449515
weights,
@@ -581,26 +647,33 @@ def _attribute_with_cross_tensor_feature_masks(
581647
current_target,
582648
current_additional_args,
583649
)
650+
modified_eval = _parse_forward_out(modified_eval)
584651

585652
attr_progress.update()
586653

587654
assert not isinstance(modified_eval, torch.Future), (
588655
"when use_futures is True, modified_eval should have "
589656
f"non-Future type rather than {type(modified_eval)}"
590657
)
591-
658+
# Just do the check once.
659+
if not self._is_output_shape_valid:
660+
check_output_shape_valid(
661+
inputs=current_inputs,
662+
num_examples=num_examples,
663+
initial_eval=initial_eval,
664+
modified_eval=modified_eval,
665+
perturbations_per_eval=perturbations_per_eval,
666+
)
667+
self._is_output_shape_valid = True
592668
total_attrib, weights = self._process_ablated_out_full(
593-
modified_eval,
594-
current_masks,
595-
flattened_initial_eval,
596-
initial_eval,
597-
current_inputs,
598-
n_outputs,
599-
num_examples,
600-
total_attrib,
601-
weights,
602-
attrib_type,
603-
perturbations_per_eval,
669+
modified_eval=modified_eval,
670+
current_mask=current_masks,
671+
flattened_initial_eval=flattened_initial_eval,
672+
inputs=current_inputs,
673+
n_outputs=n_outputs,
674+
total_attrib=total_attrib,
675+
weights=weights,
676+
attrib_type=attrib_type,
604677
)
605678
return total_attrib, weights
606679

@@ -705,6 +778,7 @@ def _initial_eval_to_processed_initial_eval_fut(
705778
"initial_eval_to_processed_initial_eval_fut: "
706779
"initial_eval should be a Tensor"
707780
)
781+
initial_eval_processed = _parse_forward_out(initial_eval_processed)
708782
result = process_initial_eval(
709783
initial_eval_processed, formatted_inputs, use_weights=self.use_weights
710784
)
@@ -1039,6 +1113,7 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
10391113
"total_attrib, weights, initial_eval, "
10401114
"flattened_initial_eval, n_outputs, attrib_type "
10411115
)
1116+
modified_eval = _parse_forward_out(modified_eval)
10421117
if not isinstance(modified_eval, Tensor):
10431118
raise AssertionError(
10441119
"_eval_fut_to_ablated_out_fut_cross_tensor: "
@@ -1052,13 +1127,21 @@ def _eval_fut_to_ablated_out_fut_cross_tensor(
10521127
n_outputs,
10531128
attrib_type,
10541129
) = initial_eval_tuple
1130+
# Just do the check once.
1131+
if not self._is_output_shape_valid:
1132+
check_output_shape_valid(
1133+
inputs=current_inputs,
1134+
num_examples=num_examples,
1135+
initial_eval=initial_eval,
1136+
modified_eval=modified_eval,
1137+
perturbations_per_eval=perturbations_per_eval,
1138+
)
1139+
self._is_output_shape_valid = True
1140+
10551141
total_attrib, weights = self._process_ablated_out_full(
10561142
modified_eval=modified_eval,
10571143
inputs=current_inputs,
10581144
current_mask=current_mask,
1059-
perturbations_per_eval=perturbations_per_eval,
1060-
num_examples=num_examples,
1061-
initial_eval=initial_eval,
10621145
flattened_initial_eval=flattened_initial_eval,
10631146
n_outputs=n_outputs,
10641147
total_attrib=total_attrib,
@@ -1076,47 +1159,12 @@ def _process_ablated_out_full(
10761159
modified_eval: Tensor,
10771160
current_mask: Tuple[Optional[Tensor], ...],
10781161
flattened_initial_eval: Tensor,
1079-
initial_eval: Tensor,
10801162
inputs: TensorOrTupleOfTensorsGeneric,
10811163
n_outputs: int,
1082-
num_examples: int,
10831164
total_attrib: List[Tensor],
10841165
weights: List[Tensor],
10851166
attrib_type: dtype,
1086-
perturbations_per_eval: int,
10871167
) -> Tuple[List[Tensor], List[Tensor]]:
1088-
modified_eval = _parse_forward_out(modified_eval)
1089-
# if perturbations_per_eval > 1, the output shape must grow with
1090-
# input and not be aggregated
1091-
current_batch_size = inputs[0].shape[0]
1092-
1093-
# number of perturbation, which is not the same as
1094-
# perturbations_per_eval when not enough features to perturb
1095-
n_perturb = current_batch_size / num_examples
1096-
if perturbations_per_eval > 1 and not self._is_output_shape_valid:
1097-
1098-
current_output_shape = modified_eval.shape
1099-
1100-
# use initial_eval as the forward of perturbations_per_eval = 1
1101-
initial_output_shape = initial_eval.shape
1102-
1103-
assert (
1104-
# check if the output is not a scalar
1105-
current_output_shape
1106-
and initial_output_shape
1107-
# check if the output grow in same ratio, i.e., not agg
1108-
and current_output_shape[0] == n_perturb * initial_output_shape[0]
1109-
), (
1110-
"When perturbations_per_eval > 1, forward_func's output "
1111-
"should be a tensor whose 1st dim grow with the input "
1112-
f"batch size: when input batch size is {num_examples}, "
1113-
f"the output shape is {initial_output_shape}; "
1114-
f"when input batch size is {current_batch_size}, "
1115-
f"the output shape is {current_output_shape}"
1116-
)
1117-
1118-
self._is_output_shape_valid = True
1119-
11201168
# reshape the leading dim for n_feature_perturbed
11211169
# flatten each feature's eval outputs into 1D of (n_outputs)
11221170
modified_eval = modified_eval.reshape(-1, n_outputs)

tests/attr/test_feature_ablation.py

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
1515
from captum.attr._core.feature_ablation import (
1616
_parse_forward_out,
17+
check_output_shape_valid,
1718
FeatureAblation,
1819
format_result,
1920
)
@@ -936,8 +937,10 @@ def test_parse_forward_out_invalid_none(self) -> None:
936937
class TestFormatResult(BaseTest):
937938

938939
def test_format_result_single_tensor_no_weights(self) -> None:
939-
total_attrib = [torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])]
940-
weights = []
940+
total_attrib: list[torch.Tensor] = [
941+
torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
942+
]
943+
weights: list[torch.Tensor] = []
941944
is_inputs_tuple = False
942945
use_weights = False
943946

@@ -951,11 +954,11 @@ def test_format_result_single_tensor_no_weights(self) -> None:
951954
)
952955

953956
def test_format_result_tuple_output_no_weights(self) -> None:
954-
total_attrib = [
957+
total_attrib: list[torch.Tensor] = [
955958
torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
956959
torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
957960
]
958-
weights = []
961+
weights: list[torch.Tensor] = []
959962
is_inputs_tuple = True
960963
use_weights = False
961964

@@ -967,8 +970,12 @@ def test_format_result_tuple_output_no_weights(self) -> None:
967970
assertTensorAlmostEqual(self, result[1], torch.tensor([[5.0, 6.0], [7.0, 8.0]]))
968971

969972
def test_format_result_single_tensor_with_weights(self) -> None:
970-
total_attrib = [torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])]
971-
weights = [torch.tensor([[2.0, 4.0, 5.0], [8.0, 10.0, 12.0]])]
973+
total_attrib: list[torch.Tensor] = [
974+
torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
975+
]
976+
weights: list[torch.Tensor] = [
977+
torch.tensor([[2.0, 4.0, 5.0], [8.0, 10.0, 12.0]])
978+
]
972979
is_inputs_tuple = False
973980
use_weights = True
974981

@@ -979,11 +986,11 @@ def test_format_result_single_tensor_with_weights(self) -> None:
979986
assertTensorAlmostEqual(self, result, expected)
980987

981988
def test_format_result_tuple_output_with_weights(self) -> None:
982-
total_attrib = [
989+
total_attrib: list[torch.Tensor] = [
983990
torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
984991
torch.tensor([[50.0, 60.0], [70.0, 80.0]]),
985992
]
986-
weights = [
993+
weights: list[torch.Tensor] = [
987994
torch.tensor([[2.0, 4.0], [5.0, 8.0]]),
988995
torch.tensor([[10.0, 12.0], [14.0, 16.0]]),
989996
]
@@ -998,8 +1005,10 @@ def test_format_result_tuple_output_with_weights(self) -> None:
9981005
assertTensorAlmostEqual(self, result[1], torch.tensor([[5.0, 5.0], [5.0, 5.0]]))
9991006

10001007
def test_format_result_integer_dtype_no_weights(self) -> None:
1001-
total_attrib = [torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)]
1002-
weights = []
1008+
total_attrib: list[torch.Tensor] = [
1009+
torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)
1010+
]
1011+
weights: list[torch.Tensor] = []
10031012
is_inputs_tuple = False
10041013
use_weights = False
10051014

@@ -1013,5 +1022,69 @@ def test_format_result_integer_dtype_no_weights(self) -> None:
10131022
)
10141023

10151024

1025+
class TestCheckOutputShapeValid(BaseTest):
    """
    Exercises ``check_output_shape_valid`` directly.

    Every case uses an original batch of 2 examples, a baseline output of
    shape (2, 5) and a single input tensor with 3 features; only the expanded
    batch size, the perturbed output's leading dimension and
    ``perturbations_per_eval`` vary between cases.
    """

    @staticmethod
    def _run_check(
        batch_rows: int, output_rows: int, perturbations_per_eval: int
    ) -> None:
        # Build the fixtures shared by all cases and invoke the validator.
        expanded_inputs = (torch.randn(batch_rows, 3),)
        baseline_output = torch.randn(2, 5)
        perturbed_output = torch.randn(output_rows, 5)

        check_output_shape_valid(
            inputs=expanded_inputs,
            num_examples=2,
            initial_eval=baseline_output,
            modified_eval=perturbed_output,
            perturbations_per_eval=perturbations_per_eval,
        )

    def test_valid_output_shape_scaling(self) -> None:
        # Batch doubles (2 -> 4) and so does the output: no error expected.
        self._run_check(batch_rows=4, output_rows=4, perturbations_per_eval=2)

    def test_invalid_output_shape_scaling(self) -> None:
        # Batch doubles but the output triples (2 -> 6): must be rejected.
        with self.assertRaises(AssertionError):
            self._run_check(batch_rows=4, output_rows=6, perturbations_per_eval=2)

    def test_skip_validation_when_perturbations_per_eval_is_one(self) -> None:
        # A mismatching output is tolerated because validation is skipped
        # when only one perturbation is evaluated at a time.
        self._run_check(batch_rows=4, output_rows=3, perturbations_per_eval=1)

    def test_invalid_batch_size_not_divisible_by_num_examples(self) -> None:
        # An expanded batch of 5 is not a multiple of 2 examples.
        with self.assertRaises(AssertionError):
            self._run_check(batch_rows=5, output_rows=5, perturbations_per_eval=2)
1087+
1088+
10161089
if __name__ == "__main__":
10171090
unittest.main()

0 commit comments

Comments
 (0)