@@ -440,7 +440,7 @@ def check_weight_equal(
440440 except Exception :
441441 return torch .all (sd_weight == network_weight )
442442
443- @needs_refit
443+ @needs_refit # type: ignore[misc]
444444 def _save_weight_mapping (self ) -> None :
445445 """
446446 Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -577,7 +577,7 @@ def _save_weight_mapping(self) -> None:
577577 gc .collect ()
578578 torch .cuda .empty_cache ()
579579
580- @needs_refit
580+ @needs_refit # type: ignore[misc]
581581 def _insert_engine_to_cache (self , hash_val : str , serialized_engine : bytes ) -> None :
582582 # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
583583 # if not self.compilation_settings.strip_engine_weights:
@@ -605,7 +605,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
605605 ),
606606 )
607607
608- @needs_refit
608+ @needs_refit # type: ignore[misc]
609609 def _pull_cached_engine (self , hash_val : str ) -> Optional [TRTInterpreterResult ]:
610610 # query the cached TRT engine
611611 cached_data = self .engine_cache .check (hash_val ) # type: ignore[union-attr]
@@ -941,7 +941,14 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
941941 f"Specified output dtypes ({ len (self .output_dtypes )} ) differ from number of outputs ({ len (outputs )} )"
942942 )
943943
944+ marked_outputs_ids = []
944945 for i , output in enumerate (outputs ):
946+ # In some cases, the same output tensor may be marked multiple times, such as _to_copy,
947+ # so we skip marking if the output is already marked
948+ if id (output ) in marked_outputs_ids :
949+ continue
950+ marked_outputs_ids .append (id (output ))
951+
945952 name = f"output{ i } "
946953
947954 output_dtype = dtype .unknown
0 commit comments