@@ -555,73 +555,9 @@ def attribute_future(
555555 ]
556556 )
557557
558- def eval_fut_to_ablated_out_fut (
559- # pyre-ignore Invalid type parameters [24]
560- eval_futs : Future [List [Future [List [object ]]]],
561- current_inputs : Tuple [Tensor , ...],
562- current_mask : Tensor ,
563- i : int ,
564- perturbations_per_eval : int ,
565- num_examples : int ,
566- formatted_inputs : Tuple [Tensor , ...],
567- ) -> Tuple [List [Tensor ], List [Tensor ]]:
568- try :
569- modified_eval = cast (Tensor , eval_futs .value ()[1 ].value ())
570- initial_eval_tuple = cast (
571- Tuple [
572- List [Tensor ],
573- List [Tensor ],
574- Tensor ,
575- Tensor ,
576- int ,
577- dtype ,
578- ],
579- eval_futs .value ()[0 ].value (),
580- )
581- if len (initial_eval_tuple ) != 6 :
582- raise AssertionError (
583- "eval_fut_to_ablated_out_fut: "
584- "initial_eval_tuple should have 6 elements: "
585- "total_attrib, weights, initial_eval, "
586- "flattened_initial_eval, n_outputs, attrib_type "
587- )
588- if not isinstance (modified_eval , Tensor ):
589- raise AssertionError (
590- "eval_fut_to_ablated_out_fut: "
591- "modified eval should be a Tensor"
592- )
593- (
594- total_attrib ,
595- weights ,
596- initial_eval ,
597- flattened_initial_eval ,
598- n_outputs ,
599- attrib_type ,
600- ) = initial_eval_tuple
601- result = self ._process_ablated_out ( # type: ignore # noqa: E501 line too long
602- modified_eval = modified_eval ,
603- current_inputs = current_inputs ,
604- current_mask = current_mask ,
605- perturbations_per_eval = perturbations_per_eval ,
606- num_examples = num_examples ,
607- initial_eval = initial_eval ,
608- flattened_initial_eval = flattened_initial_eval ,
609- inputs = formatted_inputs ,
610- n_outputs = n_outputs ,
611- total_attrib = total_attrib ,
612- weights = weights ,
613- i = i ,
614- attrib_type = attrib_type ,
615- )
616- except FeatureAblationFutureError as e :
617- raise FeatureAblationFutureError (
618- "eval_fut_to_ablated_out_fut func failed)"
619- ) from e
620- return result
621-
622558 ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
623559 eval_futs .then (
624- lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
560+ lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self . _eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
625561 eval_futs = eval_futs ,
626562 current_inputs = current_inputs ,
627563 current_mask = current_mask ,
@@ -660,6 +596,70 @@ def _attribute_progress_setup(
660596 )
661597 return attr_progress
662598
599+ def _eval_fut_to_ablated_out_fut (
600+ self ,
601+ # pyre-ignore Invalid type parameters [24]
602+ eval_futs : Future [List [Future [List [object ]]]],
603+ current_inputs : Tuple [Tensor , ...],
604+ current_mask : Tensor ,
605+ i : int ,
606+ perturbations_per_eval : int ,
607+ num_examples : int ,
608+ formatted_inputs : Tuple [Tensor , ...],
609+ ) -> Tuple [List [Tensor ], List [Tensor ]]:
610+ try :
611+ modified_eval = cast (Tensor , eval_futs .value ()[1 ].value ())
612+ initial_eval_tuple = cast (
613+ Tuple [
614+ List [Tensor ],
615+ List [Tensor ],
616+ Tensor ,
617+ Tensor ,
618+ int ,
619+ dtype ,
620+ ],
621+ eval_futs .value ()[0 ].value (),
622+ )
623+ if len (initial_eval_tuple ) != 6 :
624+ raise AssertionError (
625+ "eval_fut_to_ablated_out_fut: "
626+ "initial_eval_tuple should have 6 elements: "
627+ "total_attrib, weights, initial_eval, "
628+ "flattened_initial_eval, n_outputs, attrib_type "
629+ )
630+ if not isinstance (modified_eval , Tensor ):
631+ raise AssertionError (
632+ "eval_fut_to_ablated_out_fut: " "modified eval should be a Tensor"
633+ )
634+ (
635+ total_attrib ,
636+ weights ,
637+ initial_eval ,
638+ flattened_initial_eval ,
639+ n_outputs ,
640+ attrib_type ,
641+ ) = initial_eval_tuple
642+ result = self ._process_ablated_out ( # type: ignore # noqa: E501 line too long
643+ modified_eval = modified_eval ,
644+ current_inputs = current_inputs ,
645+ current_mask = current_mask ,
646+ perturbations_per_eval = perturbations_per_eval ,
647+ num_examples = num_examples ,
648+ initial_eval = initial_eval ,
649+ flattened_initial_eval = flattened_initial_eval ,
650+ inputs = formatted_inputs ,
651+ n_outputs = n_outputs ,
652+ total_attrib = total_attrib ,
653+ weights = weights ,
654+ i = i ,
655+ attrib_type = attrib_type ,
656+ )
657+ except FeatureAblationFutureError as e :
658+ raise FeatureAblationFutureError (
659+ "eval_fut_to_ablated_out_fut func failed)"
660+ ) from e
661+ return result
662+
663663 # pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
664664 def _ith_input_ablation_generator (
665665 self ,
0 commit comments