@@ -57,16 +57,15 @@ def _parse_forward_out(forward_output: object) -> Tensor:
5757 if isinstance (forward_output , Tensor ):
5858 return forward_output
5959
60- output_type = type (forward_output )
61- assert output_type is int or output_type is float , (
60+ assert isinstance (forward_output , (int , float )), (
6261 "the return of forward_func must be a tensor, int, or float,"
6362 f" received: { forward_output } "
6463 )
6564
6665 # using python built-in type as torch dtype
6766 # int -> torch.int64, float -> torch.float64
6867 # ref: https://github.com/pytorch/pytorch/pull/21215
69- return torch .tensor (forward_output , dtype = cast (dtype , output_type ))
68+ return torch .tensor (forward_output , dtype = cast (dtype , type ( forward_output ) ))
7069
7170
7271def process_initial_eval (
@@ -78,7 +77,7 @@ def process_initial_eval(
7877 initial_eval = _parse_forward_out (initial_eval )
7978
8079 # number of elements in the output of forward_func
81- n_outputs = initial_eval .numel () if isinstance ( initial_eval , Tensor ) else 1
80+ n_outputs = initial_eval .numel ()
8281
8382 # flatten eval outputs into 1D (n_outputs)
8483 # add the leading dim for n_feature_perturbed
@@ -87,10 +86,12 @@ def process_initial_eval(
8786 # Initialize attribution totals and counts
8887 attrib_type = flattened_initial_eval .dtype
8988
89+    # Shape of attribution is the outputs * inputs dimensions,
90+    # where the inputs dimensions exclude the batch size dimension.
9091 total_attrib = [
9192 # attribute w.r.t each output element
9293 torch .zeros (
93- (n_outputs ,) + input .shape [1 :],
94+ (n_outputs , * input .shape [1 :]) ,
9495 dtype = attrib_type ,
9596 device = input .device ,
9697 )
@@ -101,7 +102,7 @@ def process_initial_eval(
101102 weights = []
102103 if use_weights :
103104 weights = [
104- torch .zeros ((n_outputs ,) + input .shape [1 :], device = input .device ). float ( )
105+ torch .zeros ((n_outputs , * input .shape [1 :]) , device = input .device )
105106 for input in inputs
106107 ]
107108
@@ -137,7 +138,7 @@ def format_result(
137138
138139
139140class FeatureAblation (PerturbationAttribution ):
140- r """
141+ """
141142 A perturbation based approach to computing attribution, involving
142143 replacing each input feature with a given baseline / reference, and
143144 computing the difference in output. By default, each scalar value within
@@ -377,7 +378,8 @@ def attribute(
377378 >>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask)
378379 """
379380 # Keeps track whether original input is a tuple or not before
380- # converting it into a tuple.
381+        # converting it into a tuple. We return the attribution as a tuple at
382+        # the end if the inputs were a tuple.
381383 is_inputs_tuple = _is_tuple (inputs )
382384
383385 formatted_inputs , baselines = _format_input_baseline (inputs , baselines )
0 commit comments