@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):
6565
6666
6767class TRTInterpreterResult (NamedTuple ):
68- serialized_engine : bytes
68+ engine : trt . ICudaEngine | bytes
6969 input_names : Sequence [str ]
7070 output_names : Sequence [str ]
7171 weight_name_map : Optional [dict [Any , Any ]]
@@ -731,6 +731,10 @@ def run(
731731 if interpreter_result is not None : # hit the cache
732732 return interpreter_result # type: ignore[no-any-return]
733733
734+ import psutil
735+
736+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
737+ # breakpoint()
734738 self ._construct_trt_network_def ()
735739
736740 if not self .compilation_settings .immutable_weights :
@@ -749,41 +753,62 @@ def run(
749753 self ._create_timing_cache (
750754 builder_config , self .compilation_settings .timing_cache_path
751755 )
752- serialized_engine = self .builder .build_serialized_network (
756+ import psutil
757+
758+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
759+ # breakpoint()
760+
761+ cuda_engine = self .builder .build_engine_with_config (
753762 self .ctx .net , builder_config
754763 )
755- assert serialized_engine
756764
757765 _LOGGER .info (
758766 f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
759767 )
760- _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
761-
762768 self .ctx .clear_cpu_weights_reference_holder ()
763769
764770 self ._save_timing_cache (
765771 builder_config , self .compilation_settings .timing_cache_path
766772 )
767773
768774 # Engine caching only for refittable engines
769- if (
770- not self .compilation_settings .immutable_weights
771- and self .compilation_settings .cache_built_engines
772- and self .engine_cache is not None
773- ):
774- self ._insert_engine_to_cache (hash_val , serialized_engine )
775-
776- with io .BytesIO () as engine_bytes :
777- engine_bytes .write (serialized_engine )
778- engine_str = engine_bytes .getvalue ()
779-
780- return TRTInterpreterResult (
781- engine_str ,
782- self ._input_names ,
783- self ._output_names ,
784- self .weight_name_map ,
785- self .ctx .requires_output_allocator ,
786- )
775+ # if (
776+ # not self.compilation_settings.immutable_weights
777+ # and self.compilation_settings.cache_built_engines
778+ # and self.engine_cache is not None
779+ # ):
780+ # self._insert_engine_to_cache(hash_val, serialized_engine)
781+
782+ print ("After build_engine_with_config" )
783+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
784+ # breakpoint()
785+ assert cuda_engine
786+ if self .compilation_settings .use_python_runtime :
787+ return TRTInterpreterResult (
788+ cuda_engine ,
789+ self ._input_names ,
790+ self ._output_names ,
791+ self .weight_name_map ,
792+ self .ctx .requires_output_allocator ,
793+ )
794+ else :
795+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
796+ # breakpoint()
797+ serialized_engine = cuda_engine .serialize ()
798+ _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
799+
800+ with io .BytesIO () as engine_bytes :
801+ engine_bytes .write (serialized_engine )
802+ engine_str = engine_bytes .getvalue ()
803+ print (psutil .Process ().memory_info ().rss / 1024 / 1024 , "MB" )
804+ # breakpoint()
805+ return TRTInterpreterResult (
806+ engine_str ,
807+ self ._input_names ,
808+ self ._output_names ,
809+ self .weight_name_map ,
810+ self .ctx .requires_output_allocator ,
811+ )
787812
788813 def run_node (self , n : torch .fx .Node ) -> torch .fx .Node :
789814 self ._cur_node_name = get_node_name (n )
0 commit comments