1212 Iterable ,
1313 List ,
1414 Optional ,
15+ Protocol ,
1516 Tuple ,
1617 TypeVar ,
1718 Union ,
4142from torch import dtype , Tensor
4243from torch .futures import collect_all , Future
4344
44- from tqdm .auto import tqdm
4545
4646IterableType = TypeVar ("IterableType" )
4747
4848logger : logging .Logger = logging .getLogger (__name__ )
4949
5050
51+ class Progress (Protocol ):
52+ def update (self , n : int = 1 ) -> Optional [bool ]:
53+ """TQDM Update method signature."""
54+
55+ def close (self ) -> None :
56+ """TQDM Close method signature."""
57+
58+
59+ class NullProgress :
60+ def update (self , n : int = 1 ) -> Optional [bool ]:
61+ return None
62+
63+ def close (self ) -> None :
64+ return None
65+
66+
5167def _parse_forward_out (forward_output : object ) -> Tensor :
5268 """
5369 A temp wrapper for global _run_forward util to force forward output
@@ -392,15 +408,18 @@ def attribute(
392408 isinstance (perturbations_per_eval , int ) and perturbations_per_eval >= 1
393409 ), "Perturbations per evaluation must be an integer and at least 1."
394410 with torch .no_grad ():
395- attr_progress = None
411+ attr_progress : Progress
396412 if show_progress :
397413 attr_progress = self ._attribute_progress_setup (
398414 formatted_inputs ,
399415 formatted_feature_mask ,
400416 ** kwargs ,
401417 perturbations_per_eval = perturbations_per_eval ,
402418 )
403- attr_progress .update (0 )
419+ else :
420+ attr_progress = NullProgress ()
421+
422+ attr_progress .update (0 )
404423
405424 # Computes initial evaluation with all features, which is compared
406425 # to each ablated result.
@@ -410,8 +429,8 @@ def attribute(
410429 target ,
411430 formatted_additional_forward_args ,
412431 )
413- if attr_progress is not None :
414- attr_progress .update ()
432+
433+ attr_progress .update ()
415434
416435 total_attrib : List [Tensor ] = []
417436 weights : List [Tensor ] = []
@@ -453,8 +472,7 @@ def attribute(
453472 ** kwargs ,
454473 )
455474
456- if attr_progress is not None :
457- attr_progress .close ()
475+ attr_progress .close ()
458476
459477 return cast (
460478 TensorOrTupleOfTensorsGeneric ,
@@ -470,7 +488,7 @@ def _attribute_with_cross_tensor_feature_masks(
470488 target : TargetType ,
471489 baselines : BaselineType ,
472490 formatted_feature_mask : Tuple [Tensor , ...],
473- attr_progress : Optional [ tqdm ] ,
491+ attr_progress : Progress ,
474492 flattened_initial_eval : Tensor ,
475493 initial_eval : Tensor ,
476494 n_outputs : int ,
@@ -564,8 +582,7 @@ def _attribute_with_cross_tensor_feature_masks(
564582 current_additional_args ,
565583 )
566584
567- if attr_progress is not None :
568- attr_progress .update ()
585+ attr_progress .update ()
569586
570587 assert not isinstance (modified_eval , torch .Future ), (
571588 "when use_futures is True, modified_eval should have "
@@ -728,15 +745,17 @@ def attribute_future(
728745 isinstance (perturbations_per_eval , int ) and perturbations_per_eval >= 1
729746 ), "Perturbations per evaluation must be an integer and at least 1."
730747 with torch .no_grad ():
731- attr_progress = None
748+ attr_progress : Progress
732749 if show_progress :
733750 attr_progress = self ._attribute_progress_setup (
734751 formatted_inputs ,
735752 formatted_feature_mask ,
736753 ** kwargs ,
737754 perturbations_per_eval = perturbations_per_eval ,
738755 )
739- attr_progress .update (0 )
756+ else :
757+ attr_progress = NullProgress ()
758+ attr_progress .update (0 )
740759
741760 # Computes initial evaluation with all features, which is compared
742761 # to each ablated result.
@@ -747,8 +766,7 @@ def attribute_future(
747766 formatted_additional_forward_args ,
748767 )
749768
750- if attr_progress is not None :
751- attr_progress .update ()
769+ attr_progress .update ()
752770
753771 processed_initial_eval_fut : Optional [
754772 Future [Tuple [List [Tensor ], List [Tensor ], Tensor , Tensor , int , dtype ]]
@@ -789,7 +807,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
789807 target : TargetType ,
790808 baselines : BaselineType ,
791809 formatted_feature_mask : Tuple [Tensor , ...],
792- attr_progress : Optional [ tqdm ] ,
810+ attr_progress : Progress ,
793811 processed_initial_eval_fut : Future [
794812 Tuple [List [Tensor ], List [Tensor ], Tensor , Tensor , int , dtype ]
795813 ],
@@ -883,8 +901,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
883901 current_additional_args ,
884902 )
885903
886- if attr_progress is not None :
887- attr_progress .update ()
904+ attr_progress .update ()
888905
889906 if not isinstance (modified_eval , torch .Future ):
890907 raise AssertionError (
@@ -928,8 +945,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
928945
929946 all_modified_eval_futures .append (ablated_out_fut )
930947
931- if attr_progress is not None :
932- attr_progress .close ()
948+ attr_progress .close ()
933949
934950 return self ._generate_async_result_cross_tensor (
935951 all_modified_eval_futures ,
@@ -959,7 +975,7 @@ def _attribute_progress_setup(
959975 feature_mask : Tuple [Tensor , ...],
960976 perturbations_per_eval : int ,
961977 ** kwargs : Any ,
962- ) -> tqdm :
978+ ) -> Progress :
963979 total_forwards = math .ceil (
964980 get_total_features_from_mask (feature_mask ) / perturbations_per_eval
965981 )
0 commit comments