@@ -591,13 +591,11 @@ def _save_weight_mapping(self) -> None:
591591 torch .cuda .empty_cache ()
592592
593593 @needs_refit # type: ignore[misc]
594- def _insert_engine_to_cache (self , hash_val : str , serialized_engine : bytes ) -> None :
594+ def _insert_engine_to_cache (self , hash_val : str , engine : bytes ) -> None :
595+ serialized_engine = engine .serialize ()
595596 # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
596597 # if not self.compilation_settings.strip_engine_weights:
597598 # # set EXCLUDE_WEIGHTS flag to strip weights
598- # runtime = trt.Runtime(TRT_LOGGER)
599- # engine = runtime.deserialize_cuda_engine(serialized_engine)
600-
601599 # serialization_config = engine.create_serialization_config()
602600 # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
603601 # serialized_engine = engine.serialize_with_config(
@@ -731,10 +729,6 @@ def run(
731729 if interpreter_result is not None : # hit the cache
732730 return interpreter_result # type: ignore[no-any-return]
733731
734- import psutil
735-
736- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
737- # breakpoint()
738732 self ._construct_trt_network_def ()
739733
740734 if not self .compilation_settings .immutable_weights :
@@ -753,14 +747,11 @@ def run(
753747 self ._create_timing_cache (
754748 builder_config , self .compilation_settings .timing_cache_path
755749 )
756- import psutil
757-
758- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
759- # breakpoint()
760750
761751 cuda_engine = self .builder .build_engine_with_config (
762752 self .ctx .net , builder_config
763753 )
754+ assert cuda_engine
764755
765756 _LOGGER .info (
766757 f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
@@ -772,17 +763,13 @@ def run(
772763 )
773764
774765 # Engine caching only for refittable engines
775- # if (
776- # not self.compilation_settings.immutable_weights
777- # and self.compilation_settings.cache_built_engines
778- # and self.engine_cache is not None
779- # ):
780- # self._insert_engine_to_cache(hash_val, serialized_engine)
781-
782- print ("After build_engine_with_config" )
783- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
784- # breakpoint()
785- assert cuda_engine
766+ if (
767+ not self .compilation_settings .immutable_weights
768+ and self .compilation_settings .cache_built_engines
769+ and self .engine_cache is not None
770+ ):
771+ self ._insert_engine_to_cache (hash_val , cuda_engine )
772+
786773 if self .compilation_settings .use_python_runtime :
787774 return TRTInterpreterResult (
788775 cuda_engine ,
@@ -792,16 +779,13 @@ def run(
792779 self .ctx .requires_output_allocator ,
793780 )
794781 else :
795- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
796- # breakpoint()
797782 serialized_engine = cuda_engine .serialize ()
798783 _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
799784
800785 with io .BytesIO () as engine_bytes :
801786 engine_bytes .write (serialized_engine )
802787 engine_str = engine_bytes .getvalue ()
803- print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
804- # breakpoint()
788+
805789 return TRTInterpreterResult (
806790 engine_str ,
807791 self ._input_names ,
0 commit comments