Skip to content

Commit 6fa02a2

Browse files
cyrjanometa-codesync[bot]
authored andcommitted
Simplify the Update Progress Code. (#1664)
Summary: Pull Request resolved: #1664 Create Progress Protocol and NullProgress to simplify the progress bar. Reviewed By: sarahtranfb Differential Revision: D86798999 fbshipit-source-id: dba673a3a0534a54052784ea7159a49eb2bd4700
1 parent 82c1f80 commit 6fa02a2

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Iterable,
1313
List,
1414
Optional,
15+
Protocol,
1516
Tuple,
1617
TypeVar,
1718
Union,
@@ -41,13 +42,28 @@
4142
from torch import dtype, Tensor
4243
from torch.futures import collect_all, Future
4344

44-
from tqdm.auto import tqdm
4545

4646
IterableType = TypeVar("IterableType")
4747

4848
logger: logging.Logger = logging.getLogger(__name__)
4949

5050

51+
class Progress(Protocol):
52+
def update(self, n: int = 1) -> Optional[bool]:
53+
"""TQDM Update method signature."""
54+
55+
def close(self) -> None:
56+
"""TQDM Close method signature."""
57+
58+
59+
class NullProgress:
60+
def update(self, n: int = 1) -> Optional[bool]:
61+
return None
62+
63+
def close(self) -> None:
64+
return None
65+
66+
5167
def _parse_forward_out(forward_output: object) -> Tensor:
5268
"""
5369
A temp wrapper for global _run_forward util to force forward output
@@ -392,15 +408,18 @@ def attribute(
392408
isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
393409
), "Perturbations per evaluation must be an integer and at least 1."
394410
with torch.no_grad():
395-
attr_progress = None
411+
attr_progress: Progress
396412
if show_progress:
397413
attr_progress = self._attribute_progress_setup(
398414
formatted_inputs,
399415
formatted_feature_mask,
400416
**kwargs,
401417
perturbations_per_eval=perturbations_per_eval,
402418
)
403-
attr_progress.update(0)
419+
else:
420+
attr_progress = NullProgress()
421+
422+
attr_progress.update(0)
404423

405424
# Computes initial evaluation with all features, which is compared
406425
# to each ablated result.
@@ -410,8 +429,8 @@ def attribute(
410429
target,
411430
formatted_additional_forward_args,
412431
)
413-
if attr_progress is not None:
414-
attr_progress.update()
432+
433+
attr_progress.update()
415434

416435
total_attrib: List[Tensor] = []
417436
weights: List[Tensor] = []
@@ -453,8 +472,7 @@ def attribute(
453472
**kwargs,
454473
)
455474

456-
if attr_progress is not None:
457-
attr_progress.close()
475+
attr_progress.close()
458476

459477
return cast(
460478
TensorOrTupleOfTensorsGeneric,
@@ -470,7 +488,7 @@ def _attribute_with_cross_tensor_feature_masks(
470488
target: TargetType,
471489
baselines: BaselineType,
472490
formatted_feature_mask: Tuple[Tensor, ...],
473-
attr_progress: Optional[tqdm],
491+
attr_progress: Progress,
474492
flattened_initial_eval: Tensor,
475493
initial_eval: Tensor,
476494
n_outputs: int,
@@ -564,8 +582,7 @@ def _attribute_with_cross_tensor_feature_masks(
564582
current_additional_args,
565583
)
566584

567-
if attr_progress is not None:
568-
attr_progress.update()
585+
attr_progress.update()
569586

570587
assert not isinstance(modified_eval, torch.Future), (
571588
"when use_futures is True, modified_eval should have "
@@ -728,15 +745,17 @@ def attribute_future(
728745
isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
729746
), "Perturbations per evaluation must be an integer and at least 1."
730747
with torch.no_grad():
731-
attr_progress = None
748+
attr_progress: Progress
732749
if show_progress:
733750
attr_progress = self._attribute_progress_setup(
734751
formatted_inputs,
735752
formatted_feature_mask,
736753
**kwargs,
737754
perturbations_per_eval=perturbations_per_eval,
738755
)
739-
attr_progress.update(0)
756+
else:
757+
attr_progress = NullProgress()
758+
attr_progress.update(0)
740759

741760
# Computes initial evaluation with all features, which is compared
742761
# to each ablated result.
@@ -747,8 +766,7 @@ def attribute_future(
747766
formatted_additional_forward_args,
748767
)
749768

750-
if attr_progress is not None:
751-
attr_progress.update()
769+
attr_progress.update()
752770

753771
processed_initial_eval_fut: Optional[
754772
Future[Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]]
@@ -789,7 +807,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
789807
target: TargetType,
790808
baselines: BaselineType,
791809
formatted_feature_mask: Tuple[Tensor, ...],
792-
attr_progress: Optional[tqdm],
810+
attr_progress: Progress,
793811
processed_initial_eval_fut: Future[
794812
Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
795813
],
@@ -883,8 +901,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
883901
current_additional_args,
884902
)
885903

886-
if attr_progress is not None:
887-
attr_progress.update()
904+
attr_progress.update()
888905

889906
if not isinstance(modified_eval, torch.Future):
890907
raise AssertionError(
@@ -928,8 +945,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
928945

929946
all_modified_eval_futures.append(ablated_out_fut)
930947

931-
if attr_progress is not None:
932-
attr_progress.close()
948+
attr_progress.close()
933949

934950
return self._generate_async_result_cross_tensor(
935951
all_modified_eval_futures,
@@ -959,7 +975,7 @@ def _attribute_progress_setup(
959975
feature_mask: Tuple[Tensor, ...],
960976
perturbations_per_eval: int,
961977
**kwargs: Any,
962-
) -> tqdm:
978+
) -> Progress:
963979
total_forwards = math.ceil(
964980
get_total_features_from_mask(feature_mask) / perturbations_per_eval
965981
)

0 commit comments

Comments
 (0)