Commit 82c1f80

cyrjano authored and meta-codesync[bot] committed
Minor refactoring on FeatureAblation methods (#1665)
Summary: Pull Request resolved: #1665

This diff contains minor refactoring changes to the FeatureAblation methods in the Captum library to improve readability.

Reviewed By: sarahtranfb

Differential Revision: D86791984

fbshipit-source-id: 24eb317596ebca9ff91c3669be6c4d1c92e43c5b
1 parent 9f48afb commit 82c1f80

1 file changed

captum/attr/_core/feature_ablation.py

Lines changed: 10 additions & 8 deletions
@@ -57,16 +57,15 @@ def _parse_forward_out(forward_output: object) -> Tensor:
     if isinstance(forward_output, Tensor):
         return forward_output
 
-    output_type = type(forward_output)
-    assert output_type is int or output_type is float, (
+    assert isinstance(forward_output, (int, float)), (
         "the return of forward_func must be a tensor, int, or float,"
         f" received: {forward_output}"
     )
 
     # using python built-in type as torch dtype
     # int -> torch.int64, float -> torch.float64
     # ref: https://github.com/pytorch/pytorch/pull/21215
-    return torch.tensor(forward_output, dtype=cast(dtype, output_type))
+    return torch.tensor(forward_output, dtype=cast(dtype, type(forward_output)))
 
 
 def process_initial_eval(
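A quick standalone sketch (not part of the commit) of why the collapsed check behaves the same for plain Python scalars, plus the dtype mapping the comment references: PyTorch accepts the built-ins `int` and `float` as `dtype` arguments. One nuance: `isinstance` also admits subclasses such as `bool`, which the old `type(...) is` comparison rejected.

```python
import torch

# Python built-ins are valid dtype arguments:
# int -> torch.int64, float -> torch.float64 (see pytorch/pytorch#21215).
print(torch.tensor(3, dtype=int).dtype)      # torch.int64
print(torch.tensor(3.0, dtype=float).dtype)  # torch.float64

# For plain scalars the two checks agree; isinstance is slightly broader
# in that it also accepts subclasses, e.g. bool (a subclass of int).
x = True
print(type(x) is int or type(x) is float)  # False
print(isinstance(x, (int, float)))         # True
```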
@@ -78,7 +77,7 @@ def process_initial_eval(
     initial_eval = _parse_forward_out(initial_eval)
 
     # number of elements in the output of forward_func
-    n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1
+    n_outputs = initial_eval.numel()
 
     # flatten eval outputs into 1D (n_outputs)
     # add the leading dim for n_feature_perturbed
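Dropping the `isinstance` guard is safe because `initial_eval` has already been through `_parse_forward_out`, which returns a `Tensor` for every accepted input. A minimal sketch of that invariant (a simplified mirror of the diff's logic, omitting the `cast` used only for typing):

```python
import torch
from torch import Tensor

def parse_forward_out(forward_output):
    # Simplified mirror of _parse_forward_out: tensors pass through,
    # bare Python scalars are wrapped, anything else is rejected.
    if isinstance(forward_output, Tensor):
        return forward_output
    assert isinstance(forward_output, (int, float))
    return torch.tensor(forward_output, dtype=type(forward_output))

# Every accepted value comes back as a Tensor, so .numel() is always valid:
for out in (torch.ones(2, 3), 5, 2.5):
    initial_eval = parse_forward_out(out)
    print(type(initial_eval).__name__, initial_eval.numel())  # Tensor 6 / 1 / 1
```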
@@ -87,10 +86,12 @@
     # Initialize attribution totals and counts
     attrib_type = flattened_initial_eval.dtype
 
+    # Shape of attribution is the outputs * inputs dimensions,
+    # where the inputs dimension should remove the batch size dimension.
     total_attrib = [
         # attribute w.r.t each output element
         torch.zeros(
-            (n_outputs,) + input.shape[1:],
+            (n_outputs, *input.shape[1:]),
             dtype=attrib_type,
             device=input.device,
         )
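The `(n_outputs, *input.shape[1:])` spelling is purely cosmetic: `torch.Size` is a tuple subclass, so tuple concatenation and star-unpacking build the same shape. A quick check with assumed toy sizes:

```python
import torch

n_outputs = 4
inp = torch.rand(8, 3, 32, 32)  # assumed toy input; dim 0 is the batch

old_shape = (n_outputs,) + inp.shape[1:]  # tuple concatenation
new_shape = (n_outputs, *inp.shape[1:])   # star-unpacking, same result
print(old_shape == new_shape)             # True
print(torch.zeros(new_shape).shape)       # torch.Size([4, 3, 32, 32])
```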
@@ -101,7 +102,7 @@
     weights = []
     if use_weights:
         weights = [
-            torch.zeros((n_outputs,) + input.shape[1:], device=input.device).float()
+            torch.zeros((n_outputs, *input.shape[1:]), device=input.device)
             for input in inputs
         ]
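Removing the trailing `.float()` is behavior-preserving under PyTorch's defaults: `torch.zeros` without an explicit `dtype` already produces `torch.float32` (assuming the default dtype has not been changed via `torch.set_default_dtype`):

```python
import torch

w = torch.zeros((4, 3, 32, 32))
print(w.dtype)                     # torch.float32
print(w.dtype == w.float().dtype)  # True: the .float() call was a no-op
```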

@@ -137,7 +138,7 @@ def format_result(
 
 
 class FeatureAblation(PerturbationAttribution):
-    r"""
+    """
     A perturbation based approach to computing attribution, involving
     replacing each input feature with a given baseline / reference, and
     computing the difference in output. By default, each scalar value within
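Dropping the `r` prefix is a no-op here: the prefix only changes how backslash escapes inside the literal are interpreted, and this docstring contains none. For instance:

```python
# The r prefix matters only when the literal contains backslashes:
print(r"a\tb" == "a\tb")              # False: raw keeps the backslash
print(r"plain text" == "plain text")  # True: identical without backslashes
```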
@@ -377,7 +378,8 @@ def attribute(
             >>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask)
         """
         # Keeps track whether original input is a tuple or not before
-        # converting it into a tuple.
+        # converting it into a tuple. We return the attribution as a tuple at
+        # the end if the inputs were a tuple.
         is_inputs_tuple = _is_tuple(inputs)
 
         formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
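The expanded comment documents a convention used across Captum's attribution methods: a bare tensor input is wrapped into a one-element tuple internally, and the result is unwrapped again before returning so the output container matches the input. A generic sketch of the pattern (hypothetical helper names, not Captum's internals):

```python
from typing import Tuple, Union
import torch
from torch import Tensor

def normalize_inputs(inputs: Union[Tensor, Tuple[Tensor, ...]]):
    # Remember whether the caller passed a tuple so the result can match.
    is_inputs_tuple = isinstance(inputs, tuple)
    formatted = inputs if is_inputs_tuple else (inputs,)
    return is_inputs_tuple, formatted

def format_output(is_inputs_tuple: bool, attributions: Tuple[Tensor, ...]):
    # Unwrap the one-element tuple if the caller passed a bare tensor.
    return attributions if is_inputs_tuple else attributions[0]

was_tuple, xs = normalize_inputs(torch.rand(2, 3))
print(was_tuple, len(xs))  # False 1
print(format_output(was_tuple, (torch.zeros(2, 3),)).shape)  # torch.Size([2, 3])
```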

0 commit comments