From b072ea16937137a2c64c9d36e9f960a455dbfb31 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 16 Oct 2025 22:22:20 +0000 Subject: [PATCH 1/5] rebased to main --- py/torch_tensorrt/dynamo/_compiler.py | 15 +++- .../dynamo/conversion/_TRTInterpreter.py | 40 ++++++----- .../dynamo/conversion/_conversion.py | 53 ++++++++++++--- py/torch_tensorrt/dynamo/debug/_Debugger.py | 1 + .../lowering/passes/constant_folding.py | 4 +- .../partitioning/_adjacency_partitioner.py | 68 +++++++++++++++++++ py/torch_tensorrt/dynamo/utils.py | 36 ++++++++++ tests/py/dynamo/models/test_models.py | 46 +++++++++++++ tools/llm/torchtrt_ext/register_sdpa.py | 7 +- 9 files changed, 238 insertions(+), 32 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..bc345947d3 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -42,6 +42,7 @@ ) from torch_tensorrt.dynamo.utils import ( deallocate_module, + get_cpu_memory_usage, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -681,7 +682,7 @@ def compile( "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, } - + logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -695,14 +696,17 @@ def compile( # Apply lowering on the graph module gm = post_lowering(gm, settings) + logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") logger.debug("Lowered Input graph: " + str(gm.graph)) # Move the weights in the state_dict to CPU if offload_module_to_cpu: + deallocate_module(gm, delete_module=False) deallocate_module(exported_program.module(), delete_module=False) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) + logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB") else: remaining_memory, total_memory = torch.cuda.mem_get_info() if remaining_memory < total_memory // 2: @@ -868,6 +872,11 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. + # This is done to release CPU memory. 
+ for attr in dir(gm): + if attr.startswith("_frozen_param"): + delattr(gm, attr) for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1243,7 +1252,7 @@ def convert_exported_program_to_serialized_trt_engine( # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) - trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) + trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -1330,7 +1339,7 @@ def convert_exported_program_to_serialized_trt_engine( ) flattened_input_list = get_flat_args_with_check( - exported_program, list(trt_arg_inputs), trt_kwarg_inputs + exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore )[0] try: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 73af09448e..2542d652bd 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -50,7 +50,12 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.observer import Observer -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device +from torch_tensorrt.dynamo.utils import ( + DYNAMIC_DIM, + deallocate_module, + get_cpu_memory_usage, + to_torch_device, +) from torch_tensorrt.logging import TRT_LOGGER _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - serialized_engine: bytes + engine: trt.ICudaEngine input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] @@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - self.module.to(torch_device) - sd = self.module.state_dict() + sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} weight_refit_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1} @@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None: torch.cuda.empty_cache() @needs_refit # type: ignore[misc] - def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: + def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None: + serialized_engine = engine.serialize() # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine # if not self.compilation_settings.strip_engine_weights: # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - # serialization_config = engine.create_serialization_config() # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) # serialized_engine = engine.serialize_with_config( @@ -733,6 +735,9 @@ def run( return interpreter_result # type: ignore[no-any-return] self._construct_trt_network_def() + _LOGGER.debug( + f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" + ) if not self.compilation_settings.immutable_weights: 
self._save_weight_mapping() @@ -750,16 +755,19 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - serialized_engine = self.builder.build_serialized_network( + + cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) - assert serialized_engine + assert cuda_engine + + _LOGGER.debug( + f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB" + ) _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - self.ctx.clear_cpu_weights_reference_holder() self._save_timing_cache( @@ -772,14 +780,10 @@ def run( and self.compilation_settings.cache_built_engines and self.engine_cache is not None ): - self._insert_engine_to_cache(hash_val, serialized_engine) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() + self._insert_engine_to_cache(hash_val, cuda_engine) return TRTInterpreterResult( - engine_str, + cuda_engine, self._input_names, self._output_names, self.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 35b6c26617..aaec8d3be8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,7 +1,8 @@ from __future__ import annotations +import io import logging -from typing import Any, List, Optional, Sequence +from typing import Any, List, NamedTuple, Optional, Sequence import torch from torch_tensorrt._enums import dtype @@ -9,16 +10,25 @@ from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( - TRTInterpreter, - TRTInterpreterResult, -) +from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_output_dtypes +from torch_tensorrt.dynamo.utils import ( + get_cpu_memory_usage, + get_output_dtypes, + release_memory, +) logger = logging.getLogger(__name__) +class SerializedInterpreterResult(NamedTuple): + serialized_engine: bytes + input_names: Sequence[str] + output_names: Sequence[str] + weight_name_map: Optional[dict[Any, Any]] + requires_output_allocator: bool + + def infer_module_output_dtypes( module: torch.fx.GraphModule, truncate_double: bool = False, @@ -29,7 +39,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] + return get_output_dtypes(outputs, truncate_double) def interpret_module_to_result( @@ -39,7 +49,7 @@ def interpret_module_to_result( arg_inputs: Optional[Sequence[Input]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, engine_cache: Optional[BaseEngineCache] = None, -) -> TRTInterpreterResult: +) -> SerializedInterpreterResult: """Interpret an FX module to a TRTInterpreterResult Args: module: FX GraphModule to interpret @@ -65,7 +75,32 @@ def interpret_module_to_result( ) interpreter_result = interpreter.run() - return interpreter_result + # Delete the frozen parameters from the module to release CPU memory + del interpreter + for attr in dir(module): + if 
attr.startswith("_frozen_param"): + delattr(module, attr) + release_memory() + logger.debug( + f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB" + ) + + serialized_engine = interpreter_result.engine.serialize() + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + serialized_engine = engine_bytes.getvalue() + logger.debug( + f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" + ) + serialized_interpreter_result = SerializedInterpreterResult( + serialized_engine=serialized_engine, + input_names=interpreter_result.input_names, + output_names=interpreter_result.output_names, + weight_name_map=interpreter_result.weight_name_map, + requires_output_allocator=interpreter_result.requires_output_allocator, + ) + + return serialized_interpreter_result def convert_module( diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index 39e4217f73..e565929861 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -220,6 +220,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: "class": "logging.FileHandler", "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", "formatter": "standard", + "mode": "w", # This will clear the previous content } config["loggers"][""]["handlers"].append("file") return config diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 5ba84b09b0..9b821df906 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -37,7 +37,9 @@ def constant_fold( # For TRT INetwork construction the constants are moved to CPU in get_attr call. for node, constant in cf.node_replacements.items(): replace_node_with_constant( - gm, node, torch.nn.Parameter(constant, requires_grad=False) + gm, + node, + torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False), ) erased_params = [] diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index e2f544c2a7..b6972c7e85 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -1,6 +1,7 @@ import logging from typing import Collection, Dict, List, Optional, Tuple +import psutil import torch import torch.fx.passes.operator_support as ops from torch.fx.node import Target @@ -225,6 +226,9 @@ def partition_graph(self) -> torch.fx.GraphModule: # Remove segments smaller than the block size (with exceptions) subgraphs = self.remove_small_acc_subgraphs(subgraphs) + num_of_break = self.calculate_num_of_break(subgraphs) + subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break) + # Set the number of TRT engines to be generated self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) @@ -232,6 +236,70 @@ def partition_graph(self) -> torch.fx.GraphModule: self.tag(subgraphs) return self.split() + def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int: + """ + This function calculates the break period based on the number of subgraphs. 
+ """ + rss = psutil.Process().memory_info().rss + available_rss = psutil.virtual_memory().available + num_of_graphs = len(subgraphs) + if rss < available_rss * 0.3: + num_of_graphs = 1 + elif rss < available_rss * 0.5: + num_of_graphs = 2 + elif rss < available_rss: + num_of_graphs = 4 + elif rss < available_rss * 1.5: + num_of_graphs = 8 + elif rss < available_rss * 2: + num_of_graphs = 16 + else: + num_of_graphs = 32 + + return max( + 1, num_of_graphs // ((len(subgraphs) + 1) // 2) + ) # If there are already graph breaks, for each TRT subgraph, we break for a few times. + + def break_subgraphs( + self, subgraphs: List[Subgraph], num_of_break: int = 1 + ) -> List[Subgraph]: + """ + This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. + """ + + num_of_sdpa_node = len( + [node for node in self.acc_nodes if "scaled_dot" in str(node.target)] + ) + break_period = num_of_sdpa_node // num_of_break + 1 + current_break_idx = 0 + current_num_break = 0 + new_subgraphs = [] + for subgraph in subgraphs: + if subgraph.is_acc: + for i, node in enumerate(subgraph.nodes): + if "scaled_dot" in str(node.target): + current_num_break += 1 + if current_num_break % break_period != 0: + continue + new_subgraphs.append( + Subgraph( + is_acc=True, + nodes=subgraph.nodes[current_break_idx : i + 1], + device_ordinal=subgraph.device_ordinal, + ) + ) + current_break_idx = i + 1 + new_subgraphs.append( + Subgraph( + is_acc=True, + nodes=subgraph.nodes[current_break_idx:], + device_ordinal=subgraph.device_ordinal, + ) + ) + else: + new_subgraphs.append(subgraph) + return new_subgraphs + def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: """Generates starter nodes for partitioning + segmentation""" # Starter accelerated nodes are all callable accelerated ops diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index f822e40e1b..0d08b620b5 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +import ctypes import gc import logging +import platform import warnings from dataclasses import fields, replace from enum import Enum @@ -17,6 +19,7 @@ ) import numpy as np +import psutil import sympy import tensorrt as trt import torch @@ -853,3 +856,36 @@ def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype] f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node" ) return output_dtypes + + +def is_tegra_platform() -> bool: + if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]: + return True + return False + + +def is_thor() -> bool: + if torch.cuda.get_device_capability() in [(11, 0)]: + return True + return False + + +def get_cpu_memory_usage() -> Any: + return psutil.Process().memory_info().rss / 1024 / 1024 + + +def release_memory() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.synchronize() + + if platform.system() == "Linux": + try: + libc = ctypes.CDLL("libc.so.6") + if libc.malloc_trim(0) != 1: + logger.warning("Failed to release CPU memory.") + except Exception: + logger.warning("Failed to release CPU memory.") diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index d4133ff4b4..a1600e46eb 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -55,6 +55,52 @@ def test_resnet18(ir): 
torch._dynamo.reset() +def compile_one(idx: int, ir: str): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((idx + 1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_resnet18_multiprocess(ir): + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) + procs = [] + for i in range(3): + p = mp.Process(target=compile_one, args=(i, ir)) + p.start() + procs.append(p) + for p in procs: + p.join() + torch._dynamo.reset() + + @pytest.mark.unit @unittest.skipIf( not importlib.util.find_spec("torchvision"), diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index a82384fda9..c86ee6f3a4 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -23,6 +23,7 @@ torch.ops.aten.scaled_dot_product_attention.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, ) @@ -43,6 +44,7 @@ def _remove_decompositions(): REPLACEABLE_ATEN_OPS = { torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, } from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( @@ -79,7 +81,10 @@ def _process_sdpa_node( ValueError: If the SDPA node has an unexpected number of arguments """ - if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if node.target in [ + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + ]: if len(node.args) == 7: ( query, From 7d8554ce8cb34cd4703a5a10b0695f8341e13c01 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 28 Oct 2025 20:12:33 +0000 Subject: [PATCH 2/5] Added the experiment files --- py/torch_tensorrt/dynamo/_compiler.py | 12 +- .../partitioning/_adjacency_partitioner.py | 128 +++++++++++++++++- 2 files changed, 130 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc345947d3..81e9c67010 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -861,10 +861,10 @@ def preserve_module_specs( dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module)) submodule_node_dict = {} - for node in partitioned_module.graph.nodes: - if "_run_on_acc" not in node.name: + for name, node in partitioned_module.named_children(): + if "_run_on_acc" not in name: continue - submodule_node_dict[node.name] = node + submodule_node_dict[name] = node 
preserve_module_specs(original_in_spec, original_out_spec, partitioned_module) # Store TRT replicas of Torch subgraphs @@ -877,6 +877,12 @@ def preserve_module_specs( for attr in dir(gm): if attr.startswith("_frozen_param"): delattr(gm, attr) + + + + from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + DYNAMO_CONVERTERS.disallowed_targets = set() + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index b6972c7e85..f9aa7090af 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -1,3 +1,4 @@ +from ast import Assert import logging from typing import Collection, Dict, List, Optional, Tuple @@ -226,16 +227,26 @@ def partition_graph(self) -> torch.fx.GraphModule: # Remove segments smaller than the block size (with exceptions) subgraphs = self.remove_small_acc_subgraphs(subgraphs) - num_of_break = self.calculate_num_of_break(subgraphs) - subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break) + # num_of_break = self.calculate_num_of_break(subgraphs) + subgraphs = self.break_subgraphs_by_node(subgraphs, num_of_break=5) # Set the number of TRT engines to be generated self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) # Tag the accelerated nodes and split the graph accordingly + print([len(s.nodes) for s in subgraphs]) self.tag(subgraphs) - return self.split() + for s in subgraphs: + print(s.nodes) + + gm = self.split() + self.weight_visited_nodes = set() + [self.size_of_subgraph(s) for s in subgraphs] + + + return gm + def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int: """ This function calculates the break period based on the number of subgraphs. @@ -260,15 +271,16 @@ def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int: 1, num_of_graphs // ((len(subgraphs) + 1) // 2) ) # If there are already graph breaks, for each TRT subgraph, we break for a few times. - def break_subgraphs( + + def break_subgraphs_by_node( self, subgraphs: List[Subgraph], num_of_break: int = 1 ) -> List[Subgraph]: """ This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. """ - + op_to_break = "add." num_of_sdpa_node = len( - [node for node in self.acc_nodes if "scaled_dot" in str(node.target)] + [node for node in self.acc_nodes if op_to_break in str(node.target)] ) break_period = num_of_sdpa_node // num_of_break + 1 current_break_idx = 0 @@ -277,7 +289,7 @@ def break_subgraphs( for subgraph in subgraphs: if subgraph.is_acc: for i, node in enumerate(subgraph.nodes): - if "scaled_dot" in str(node.target): + if op_to_break in str(node.target): current_num_break += 1 if current_num_break % break_period != 0: continue @@ -298,8 +310,110 @@ def break_subgraphs( ) else: new_subgraphs.append(subgraph) + + new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) + return new_subgraphs + def break_subgraphs( + self, subgraphs: List[Subgraph], num_of_break: int = 1 + ) -> List[Subgraph]: + """ + This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. 
+ """ + break_pos = [0, 100, 200, 300, 400] + current_break_idx = 0 + new_subgraphs = [] + for subgraph in subgraphs: + if subgraph.is_acc: + for i, node in enumerate(subgraph.nodes): + if i in break_pos: + + new_subgraphs.append( + Subgraph( + is_acc=True, + nodes=subgraph.nodes[current_break_idx : i + 1], + device_ordinal=subgraph.device_ordinal, + ) + ) + current_break_idx = i + 1 + new_subgraphs.append( + Subgraph( + is_acc=True, + nodes=subgraph.nodes[current_break_idx:], + device_ordinal=subgraph.device_ordinal, + ) + ) + else: + new_subgraphs.append(subgraph) + + new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) + return new_subgraphs + + def size_of_subgraph(self, subgraph: Subgraph) -> int: + """ + This function calculates the size of the subgraph. + """ + stack = subgraph.nodes.copy() + size = 0 + while stack: + node = stack.pop() + if node in self.weight_visited_nodes: + continue + self.weight_visited_nodes.add(node) + if node.op == "get_attr": + weight = self.module.state_dict()[node.target] + size += weight.numel() * weight.element_size() + self.weight_visited_nodes.add(node) + continue + for input_node in node._input_nodes: + if input_node not in self.weight_visited_nodes: + stack.append(input_node) + print(size) + return size + + def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + """ + This function validates the subgraphs by checking if the subgraphs are valid, and corrects the subgraphs if they are not valid. + """ + visited_nodes = {} + print([len(s.nodes) for s in subgraphs]) + for i, subgraph in enumerate(subgraphs): + if i == 0: + for node in subgraph.nodes: + visited_nodes[node] = i + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + + elif not subgraph.is_acc: + for node in subgraph.nodes: + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + else: + to_remove_nodes = [] + for j, node in enumerate(subgraph.nodes): + if j == len(subgraph.nodes) - 1: + visited_nodes[node] = i + 1 + continue + subgraph_idx = 0 + for dep in self.deps[node]: + if dep in visited_nodes: + subgraph_idx = max(subgraph_idx, visited_nodes[dep]) + else: + raise ValueError(f"Node {node} have a dependency that is not covered in the previous subgraphs. 
This is caused by a invalid subgraph segmentation.") + if subgraph_idx != i: + subgraphs[subgraph_idx].nodes.append(node) + to_remove_nodes.append(node) + visited_nodes[node] = subgraph_idx + for node in to_remove_nodes: + subgraph.nodes.remove(node) + + return subgraphs + + + def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: """Generates starter nodes for partitioning + segmentation""" # Starter accelerated nodes are all callable accelerated ops From cbfac69cf962e8163965ca61308e3925715845ed Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 4 Nov 2025 01:05:57 +0000 Subject: [PATCH 3/5] Implemeted the prototype --- .../partitioning/_adjacency_partitioner.py | 244 +++++++++++++----- 1 file changed, 174 insertions(+), 70 deletions(-) diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index f9aa7090af..fad49ae6ab 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -1,4 +1,3 @@ -from ast import Assert import logging from typing import Collection, Dict, List, Optional, Tuple @@ -26,6 +25,10 @@ ) logger = logging.getLogger(__name__) +NON_BREAKABLE_OP_LISTS = [ + ["addmm", "addmm"], + ["conv2d", "batch_norm2d", "relu"], +] class OpSupportTester(ops.OperatorSupportBase): # type: ignore @@ -227,8 +230,9 @@ def partition_graph(self) -> torch.fx.GraphModule: # Remove segments smaller than the block size (with exceptions) subgraphs = self.remove_small_acc_subgraphs(subgraphs) - # num_of_break = self.calculate_num_of_break(subgraphs) - subgraphs = self.break_subgraphs_by_node(subgraphs, num_of_break=5) + subgraphs = self.break_subgraphs( + subgraphs, size_budget=self.calculate_size_budget() + ) # Set the number of TRT engines to be generated self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) @@ -241,36 +245,19 @@ def partition_graph(self) -> torch.fx.GraphModule: print(s.nodes) gm = self.split() - self.weight_visited_nodes = set() - [self.size_of_subgraph(s) for s in subgraphs] - return gm - - def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int: + + def calculate_size_budget( + self, engine_compilation_memory_usage_multiplier: int = 4 + ) -> int: """ - This function calculates the break period based on the number of subgraphs. + This function calculates the size budget based on the available RSS. We assume that TRT compilation + needs at most 4x the memory of the model. """ - rss = psutil.Process().memory_info().rss - available_rss = psutil.virtual_memory().available - num_of_graphs = len(subgraphs) - if rss < available_rss * 0.3: - num_of_graphs = 1 - elif rss < available_rss * 0.5: - num_of_graphs = 2 - elif rss < available_rss: - num_of_graphs = 4 - elif rss < available_rss * 1.5: - num_of_graphs = 8 - elif rss < available_rss * 2: - num_of_graphs = 16 - else: - num_of_graphs = 32 - - return max( - 1, num_of_graphs // ((len(subgraphs) + 1) // 2) - ) # If there are already graph breaks, for each TRT subgraph, we break for a few times. + available_rss: int = psutil.virtual_memory().available + return available_rss // engine_compilation_memory_usage_multiplier def break_subgraphs_by_node( self, subgraphs: List[Subgraph], num_of_break: int = 1 @@ -278,7 +265,7 @@ def break_subgraphs_by_node( """ This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. """ - op_to_break = "add." + op_to_break = "addmm." 
num_of_sdpa_node = len( [node for node in self.acc_nodes if op_to_break in str(node.target)] ) @@ -312,72 +299,193 @@ def break_subgraphs_by_node( new_subgraphs.append(subgraph) new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) - + return new_subgraphs def break_subgraphs( - self, subgraphs: List[Subgraph], num_of_break: int = 1 + self, subgraphs: List[Subgraph], size_budget: int ) -> List[Subgraph]: """ - This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. + This function breaks the subgraphs into smaller subgraphs to save CPU memory. """ - break_pos = [0, 100, 200, 300, 400] - current_break_idx = 0 new_subgraphs = [] - for subgraph in subgraphs: - if subgraph.is_acc: - for i, node in enumerate(subgraph.nodes): - if i in break_pos: + # We throw an error if the remaining memory is almost empty compared to the model size. + # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. + sizes = [(subgraph, self.size_of_subgraph(subgraph)) for subgraph in subgraphs] + if sum([size for _, size in sizes]) > size_budget * 40: + raise ValueError( + f"Subgraph size {sum([size for _, size in sizes])} is too large to break. Size budget: {size_budget}" + ) + for subgraph, size in sizes: - new_subgraphs.append( - Subgraph( - is_acc=True, - nodes=subgraph.nodes[current_break_idx : i + 1], - device_ordinal=subgraph.device_ordinal, - ) - ) - current_break_idx = i + 1 - new_subgraphs.append( - Subgraph( - is_acc=True, - nodes=subgraph.nodes[current_break_idx:], - device_ordinal=subgraph.device_ordinal, - ) + while size > size_budget: + broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size( + subgraph, size_budget ) - else: - new_subgraphs.append(subgraph) + size = size_1 + new_subgraphs.append(broken_subgraphs[0]) + subgraph = broken_subgraphs[1] + new_subgraphs.append(subgraph) - new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) return new_subgraphs + def break_subgraph_by_size( + self, subgraph: Subgraph, size_to_break: int + ) -> Tuple[List[Subgraph], int, int]: + """ + This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory. 
+ """ + all_nodes = subgraph.nodes + device_ordinal = subgraph.device_ordinal + new_subgraphs = [ + Subgraph( + is_acc=True, + nodes=[], + device_ordinal=device_ordinal, + ), + Subgraph( + is_acc=True, + nodes=all_nodes, + device_ordinal=device_ordinal, + ), + ] + + while True: + new_subgraphs = self.step_and_validate(new_subgraphs) + size_0, size_1 = self.size_of_subgraph( + new_subgraphs[0] + ), self.size_of_subgraph(new_subgraphs[1]) + if size_0 > size_to_break: + break + + if len(new_subgraphs[1].nodes) == 0: + new_subgraphs.pop(1) + return new_subgraphs, size_0, size_1 + + def step_and_validate( + self, new_subgraphs: List[Subgraph], step_size: int = 1 + ) -> List[Subgraph]: + + # TODO: We can change it to binary search to find the optimal break point + for _ in range(step_size): + new_subgraphs[0].nodes.append(new_subgraphs[1].nodes.pop(0)) + + while True: + new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) + nodes_in_first_subgraph = set(new_subgraphs[0].nodes) + leaf_node = self.get_leaf_node(nodes_in_first_subgraph) + broken_fusion = self.step_if_break_fusion( + new_subgraphs, leaf_node, nodes_in_first_subgraph + ) + if not broken_fusion or len(new_subgraphs[1].nodes) == 0: + break + + return new_subgraphs + + def step_if_break_fusion( + self, + subgraphs: List[Subgraph], + leaf_nodes: set[torch.fx.Node], + nodes_in_first_subgraph: set[torch.fx.Node], + ) -> bool: + + def add_nodes(node: torch.fx.Node) -> None: + """ + This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order. + """ + if node.op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph: + nodes_in_first_subgraph.add(node) + for input_node in node._input_nodes: + add_nodes(input_node) + subgraphs[0].nodes.append(node) + subgraphs[1].nodes.remove(node) + + def match_subgraph_and_step(node: torch.fx.Node) -> bool: + added_nodes = False + for op_list in NON_BREAKABLE_OP_LISTS: + for i, op in enumerate(op_list): + if i != len(op_list) - 1 and op in str(node.target): + # Search following ops forward using BFS. We skip search previous ops because + # even if it's just a subset of fusion graph, we still want it to be fused. + + users = node.users.keys() + matching_nodes: set[torch.fx.Node] = set() + for following_op_idx in range(i + 1, len(op_list)): + matching_nodes = set() + for user in users: + if op_list[following_op_idx] in str(user.target): + matching_nodes.add(user) + if not matching_nodes: + break + users = set() + for matching_node in matching_nodes: + for next_user in matching_node.users: + users.add(next_user) + + for matching_node in matching_nodes: + added_nodes = True + add_nodes(matching_node) + + if added_nodes: + # Early terminate the search if we have found a match because preceeding matches can cover following matches + break + + return True if added_nodes else False + + found_match = False + for leaf in leaf_nodes: + if match_subgraph_and_step(leaf): + found_match = True + + return found_match + + def get_leaf_node( + self, nodes_in_first_subgraph: set[torch.fx.Node] + ) -> set[torch.fx.Node]: + leaf_node = set() + + for node in nodes_in_first_subgraph: + for user in node.users: + if user not in nodes_in_first_subgraph: + leaf_node.add(node) + break + return leaf_node + def size_of_subgraph(self, subgraph: Subgraph) -> int: """ This function calculates the size of the subgraph. 
""" + nodes_in_subgraph = set(subgraph.nodes) + weight_visited_nodes = set() stack = subgraph.nodes.copy() size = 0 while stack: node = stack.pop() - if node in self.weight_visited_nodes: + if node in weight_visited_nodes: continue - self.weight_visited_nodes.add(node) if node.op == "get_attr": weight = self.module.state_dict()[node.target] size += weight.numel() * weight.element_size() - self.weight_visited_nodes.add(node) + weight_visited_nodes.add(node) + continue + if node not in nodes_in_subgraph: + # Trace to other subgraphs continue - for input_node in node._input_nodes: - if input_node not in self.weight_visited_nodes: + for input_node in node._input_nodes: + if input_node not in weight_visited_nodes: stack.append(input_node) - print(size) + return size - def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + def validate_and_correct_subgraphs( + self, subgraphs: List[Subgraph] + ) -> List[Subgraph]: """ This function validates the subgraphs by checking if the subgraphs are valid, and corrects the subgraphs if they are not valid. """ - visited_nodes = {} - print([len(s.nodes) for s in subgraphs]) + visited_nodes = ( + {} + ) # a map from a node to the index of the subgraph it's user should belong to for i, subgraph in enumerate(subgraphs): if i == 0: for node in subgraph.nodes: @@ -385,7 +493,6 @@ def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subg visited_nodes[subgraph.nodes[-1]] = i + 1 continue - elif not subgraph.is_acc: for node in subgraph.nodes: visited_nodes[subgraph.nodes[-1]] = i + 1 @@ -401,18 +508,15 @@ def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subg for dep in self.deps[node]: if dep in visited_nodes: subgraph_idx = max(subgraph_idx, visited_nodes[dep]) - else: - raise ValueError(f"Node {node} have a dependency that is not covered in the previous subgraphs. 
This is caused by a invalid subgraph segmentation.") + if subgraph_idx != i: subgraphs[subgraph_idx].nodes.append(node) to_remove_nodes.append(node) visited_nodes[node] = subgraph_idx for node in to_remove_nodes: subgraph.nodes.remove(node) - + return subgraphs - - def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: """Generates starter nodes for partitioning + segmentation""" From 51197bc898fc4b3dbbefefe45d2c6aa8a045dc13 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 5 Nov 2025 21:23:47 +0000 Subject: [PATCH 4/5] Added cpu memory budget to the frontend --- py/torch_tensorrt/dynamo/_compiler.py | 14 +++- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 2 + .../partitioning/_adjacency_partitioner.py | 84 ++++++++++--------- .../dynamo/partitioning/fusion_subgraphs.py | 0 5 files changed, 59 insertions(+), 42 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 81e9c67010..41a24d7ab1 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -105,6 +105,7 @@ def cross_compile_for_windows( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -179,6 +180,7 @@ def cross_compile_for_windows( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory. 
**kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -334,6 +336,7 @@ def cross_compile_for_windows( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } # disable the following settings is not supported for cross compilation for windows feature @@ -435,6 +438,7 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -681,6 +685,7 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cpu_memory_budget": cpu_memory_budget, } logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) @@ -833,6 +838,7 @@ def preserve_module_specs( torch_executed_ops=settings.torch_executed_ops, require_full_compilation=settings.require_full_compilation, skip_fusion=(num_supported_ops == total_ops), + cpu_memory_budget=settings.cpu_memory_budget, ) except torch.fx.passes.splitter_base.FxNetSplitterInternalError: @@ -878,11 +884,10 @@ def preserve_module_specs( if attr.startswith("_frozen_param"): delattr(gm, attr) - - from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + DYNAMO_CONVERTERS.disallowed_targets = set() - + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1071,6 +1076,7 @@ def convert_exported_program_to_serialized_trt_engine( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1345,7 +1351,7 @@ def convert_exported_program_to_serialized_trt_engine( ) flattened_input_list = get_flat_args_with_check( - exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore + exported_program, list(trt_arg_inputs), trt_kwarg_inputs )[0] try: diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..712eeb1ba0 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,7 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +CPU_MEMORY_BUDGET = -1 if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..52ac86012c 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -7,6 +7,7 @@ from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, CACHE_BUILT_ENGINES, + CPU_MEMORY_BUDGET, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -140,6 +141,7 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = 
OFFLOAD_MODULE_TO_CPU + cpu_memory_budget: int = CPU_MEMORY_BUDGET def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index fad49ae6ab..99dd5f69ca 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -118,6 +118,7 @@ def __init__( require_full_compilation: bool = REQUIRE_FULL_COMPILATION, return_tuple: bool = False, skip_fusion: bool = False, + cpu_memory_budget: int = -1, ): """ Preprocesses graph before splitting: @@ -137,6 +138,7 @@ def __init__( skip_fusion=skip_fusion, ) self.operator_support = operator_support + self.cpu_memory_budget = cpu_memory_budget # Get all accelerated nodes based on operator support conditions self.acc_nodes = FxNetAccNodesFinder( @@ -231,19 +233,15 @@ def partition_graph(self) -> torch.fx.GraphModule: subgraphs = self.remove_small_acc_subgraphs(subgraphs) subgraphs = self.break_subgraphs( - subgraphs, size_budget=self.calculate_size_budget() + subgraphs, subgraph_size_budget=self.calculate_size_budget() ) # Set the number of TRT engines to be generated self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) # Tag the accelerated nodes and split the graph accordingly - print([len(s.nodes) for s in subgraphs]) self.tag(subgraphs) - for s in subgraphs: - print(s.nodes) - gm = self.split() return gm @@ -255,8 +253,11 @@ def calculate_size_budget( This function calculates the size budget based on the available RSS. We assume that TRT compilation needs at most 4x the memory of the model. """ - - available_rss: int = psutil.virtual_memory().available + if self.cpu_memory_budget == -1: + available_rss: int = psutil.virtual_memory().available + else: + used_rss: int = psutil.virtual_memory().used + available_rss = self.cpu_memory_budget - used_rss return available_rss // engine_compilation_memory_usage_multiplier def break_subgraphs_by_node( @@ -303,7 +304,7 @@ def break_subgraphs_by_node( return new_subgraphs def break_subgraphs( - self, subgraphs: List[Subgraph], size_budget: int + self, subgraphs: List[Subgraph], subgraph_size_budget: int ) -> List[Subgraph]: """ This function breaks the subgraphs into smaller subgraphs to save CPU memory. @@ -311,16 +312,17 @@ def break_subgraphs( new_subgraphs = [] # We throw an error if the remaining memory is almost empty compared to the model size. # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. - sizes = [(subgraph, self.size_of_subgraph(subgraph)) for subgraph in subgraphs] - if sum([size for _, size in sizes]) > size_budget * 40: + sizes = self.size_of_subgraphs(subgraphs) + if sum(sizes) > subgraph_size_budget * 40: raise ValueError( - f"Subgraph size {sum([size for _, size in sizes])} is too large to break. Size budget: {size_budget}" + f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else "All available memory"} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. " + + "Consider setting cpu_memory_budget to a larger value or disable offload_module_to_cpu to save more CPU memory." 
) - for subgraph, size in sizes: + for subgraph, size in zip(subgraphs, sizes): - while size > size_budget: + while size > subgraph_size_budget: broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size( - subgraph, size_budget + subgraph, subgraph_size_budget ) size = size_1 new_subgraphs.append(broken_subgraphs[0]) @@ -351,10 +353,11 @@ def break_subgraph_by_size( ] while True: - new_subgraphs = self.step_and_validate(new_subgraphs) - size_0, size_1 = self.size_of_subgraph( - new_subgraphs[0] - ), self.size_of_subgraph(new_subgraphs[1]) + step_size = ( + 1 if not new_subgraphs[0].nodes else max(1, len(all_nodes) // 50) + ) # Set a step size proportional to the size of the subgraph to make the algorithm more efficient + new_subgraphs = self.step_and_validate(new_subgraphs, step_size) + size_0, size_1 = self.size_of_subgraphs(new_subgraphs) if size_0 > size_to_break: break @@ -451,31 +454,34 @@ def get_leaf_node( break return leaf_node - def size_of_subgraph(self, subgraph: Subgraph) -> int: + def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]: """ This function calculates the size of the subgraph. """ - nodes_in_subgraph = set(subgraph.nodes) + state_dict = self.module.state_dict(keep_vars=True) + sizes = [] weight_visited_nodes = set() - stack = subgraph.nodes.copy() - size = 0 - while stack: - node = stack.pop() - if node in weight_visited_nodes: - continue - if node.op == "get_attr": - weight = self.module.state_dict()[node.target] - size += weight.numel() * weight.element_size() + for subgraph in subgraphs: + nodes_in_subgraph = set(subgraph.nodes) + stack = subgraph.nodes.copy() + size = 0 + while stack: + node = stack.pop() + if node in weight_visited_nodes: + continue weight_visited_nodes.add(node) - continue - if node not in nodes_in_subgraph: - # Trace to other subgraphs - continue - for input_node in node._input_nodes: - if input_node not in weight_visited_nodes: - stack.append(input_node) - - return size + if node.op == "get_attr": + weight = state_dict[node.target] + size += weight.numel() * weight.element_size() + continue + if node not in nodes_in_subgraph: + # Trace to other subgraphs + continue + for input_node in node._input_nodes: + if input_node not in weight_visited_nodes: + stack.append(input_node) + sizes.append(size) + return sizes def validate_and_correct_subgraphs( self, subgraphs: List[Subgraph] @@ -541,6 +547,7 @@ def partition( torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, skip_fusion: bool = False, + cpu_memory_budget: int = -1, ) -> Tuple[torch.fx.GraphModule, OpSupportTester]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -567,6 +574,7 @@ def partition( min_block_size=min_block_size, require_full_compilation=require_full_compilation, skip_fusion=skip_fusion, + cpu_memory_budget=cpu_memory_budget, ) partitioned_graph = partitioner.partition_graph() diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py new file mode 100644 index 0000000000..e69de29bb2 From 9ee7e678a1590f40d83a08232dd700794a959a74 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 7 Nov 2025 20:53:12 +0000 Subject: [PATCH 5/5] Added new subgraph definition paradigm and revised matching logic --- .../partitioning/_adjacency_partitioner.py | 87 +++++---- .../dynamo/partitioning/fusion_patterns.py | 183 ++++++++++++++++++ .../dynamo/partitioning/fusion_subgraphs.py | 0 
3 files changed, 230 insertions(+), 40 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py delete mode 100644 py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 99dd5f69ca..2bb67d1e1e 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -233,7 +233,8 @@ def partition_graph(self) -> torch.fx.GraphModule: subgraphs = self.remove_small_acc_subgraphs(subgraphs) subgraphs = self.break_subgraphs( - subgraphs, subgraph_size_budget=self.calculate_size_budget() + subgraphs, + subgraph_size_budget=500 * 1024 * 1024, # self.calculate_size_budget() ) # Set the number of TRT engines to be generated @@ -309,6 +310,11 @@ def break_subgraphs( """ This function breaks the subgraphs into smaller subgraphs to save CPU memory. """ + from torch_tensorrt.dynamo.partitioning.fusion_patterns import ( + get_node_in_fusion_pattern, + ) + + self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph) new_subgraphs = [] # We throw an error if the remaining memory is almost empty compared to the model size. # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation. @@ -328,9 +334,26 @@ def break_subgraphs( new_subgraphs.append(broken_subgraphs[0]) subgraph = broken_subgraphs[1] new_subgraphs.append(subgraph) - + self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs) return new_subgraphs + def _varify_all_fusion_nodes_in_same_subgraph( + self, subgraphs: List[Subgraph] + ) -> None: + node_to_subgraph = {} + for i, s in enumerate(subgraphs): + for n in s.nodes: + node_to_subgraph[n] = i + + fusion_nodes_map_list = [ + len({node_to_subgraph[n] for n in ns}) == 1 + for ns in self.fusion_patterns.values() + ] + assert all( + fusion_nodes_map_list + ), "All fusion nodes must be in the same subgraph" + logger.info("All fusion nodes are in the same subgraph.") + def break_subgraph_by_size( self, subgraph: Subgraph, size_to_break: int ) -> Tuple[List[Subgraph], int, int]: @@ -376,9 +399,13 @@ def step_and_validate( while True: new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs) nodes_in_first_subgraph = set(new_subgraphs[0].nodes) + nodes_in_second_subgraph = set(new_subgraphs[1].nodes) leaf_node = self.get_leaf_node(nodes_in_first_subgraph) broken_fusion = self.step_if_break_fusion( - new_subgraphs, leaf_node, nodes_in_first_subgraph + new_subgraphs, + leaf_node, + nodes_in_first_subgraph, + nodes_in_second_subgraph, ) if not broken_fusion or len(new_subgraphs[1].nodes) == 0: break @@ -390,57 +417,37 @@ def step_if_break_fusion( subgraphs: List[Subgraph], leaf_nodes: set[torch.fx.Node], nodes_in_first_subgraph: set[torch.fx.Node], + nodes_in_second_subgraph: set[torch.fx.Node], ) -> bool: def add_nodes(node: torch.fx.Node) -> None: """ This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order. 
""" - if node.op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph: + if ( + node.op in CALLABLE_NODE_OPS + and node not in nodes_in_first_subgraph + and node in nodes_in_second_subgraph + ): + # Exclude all nodes already in the first subgraph nodes_in_first_subgraph.add(node) + nodes_in_second_subgraph.remove(node) for input_node in node._input_nodes: add_nodes(input_node) subgraphs[0].nodes.append(node) subgraphs[1].nodes.remove(node) - def match_subgraph_and_step(node: torch.fx.Node) -> bool: - added_nodes = False - for op_list in NON_BREAKABLE_OP_LISTS: - for i, op in enumerate(op_list): - if i != len(op_list) - 1 and op in str(node.target): - # Search following ops forward using BFS. We skip search previous ops because - # even if it's just a subset of fusion graph, we still want it to be fused. - - users = node.users.keys() - matching_nodes: set[torch.fx.Node] = set() - for following_op_idx in range(i + 1, len(op_list)): - matching_nodes = set() - for user in users: - if op_list[following_op_idx] in str(user.target): - matching_nodes.add(user) - if not matching_nodes: - break - users = set() - for matching_node in matching_nodes: - for next_user in matching_node.users: - users.add(next_user) - - for matching_node in matching_nodes: - added_nodes = True - add_nodes(matching_node) - - if added_nodes: - # Early terminate the search if we have found a match because preceeding matches can cover following matches - break - - return True if added_nodes else False - - found_match = False + fusion_broken = False for leaf in leaf_nodes: - if match_subgraph_and_step(leaf): - found_match = True + for node in self.fusion_patterns.get(leaf, []): + if ( + node not in nodes_in_first_subgraph + and node in nodes_in_second_subgraph + ): + fusion_broken = True + add_nodes(node) - return found_match + return fusion_broken def get_leaf_node( self, nodes_in_first_subgraph: set[torch.fx.Node] diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py b/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py new file mode 100644 index 0000000000..86e2e901a0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py @@ -0,0 +1,183 @@ +from typing import Dict, List, Set + +import torch +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.ops import aten + + +class ConvBNReLU(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + momentum: float, + eps: float, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten._native_batch_norm_legit_no_training.default( + x, bn_weight, bn_bias, running_mean, running_var, momentum, eps + )[0] + x = aten.relu.default(x) + return x + + +class ConvReLU(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, 
+ padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten.relu.default(x) + return x + + +class ConvGelu(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten.gelu.default(x) + return x + + +class ConvSilu(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.convolution.default( + x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1 + ) + x = aten.silu.default(x) + return x + + +class MulAdd(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, weight) + x = aten.add.Tensor(x, bias) + return x + + +class MulMul(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, y) + x = aten.mul.Tensor(x, z) + return x + + +All_FUSION_PATTERNS = [ + ConvBNReLU, + ConvReLU, + ConvGelu, + ConvSilu, + MulAdd, + MulMul, +] + + +def get_node_in_fusion_pattern( + graph: torch.fx.Graph, +) -> Dict[torch.fx.Node, Set[torch.fx.Node]]: + """ + This function gets the nodes map of the fusion pattern from the graph. + Key: node that appears in the fusion pattern + Value: the list of nodes that should be fused together + """ + fusion_nodes = {} + for pattern in All_FUSION_PATTERNS: + pattern_graph = torch.fx.symbolic_trace(pattern()) + subgraph_matcher = SubgraphMatcher(pattern_graph.graph) + match_result = subgraph_matcher.match(graph) + for match in match_result: + fusion_group = { + node + for node in match.nodes_map.values() + if node + and type(node) == torch.fx.Node + and node.op == "call_function" + and node not in match.placeholder_nodes + } + for node in fusion_group: + fusion_nodes[node] = fusion_group + + return fusion_nodes diff --git a/py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py deleted file mode 100644 index e69de29bb2..0000000000
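
Note (editorial, not part of the patch series): a minimal usage sketch of the cpu_memory_budget option introduced in patch 4, under the assumption that it is passed through torch_tensorrt.compile() the same way the other settings are in tests/py/dynamo/models/test_models.py. The 8 GiB figure is an arbitrary illustration; per _adjacency_partitioner.py the budget is compared against psutil byte counts, and the default of -1 (CPU_MEMORY_BUDGET) keeps the previous behavior of using all available CPU memory.

    import torch
    import torch_tensorrt as torchtrt
    import torchvision.models as models

    model = models.resnet18(pretrained=True).eval().to("cuda")
    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

    # Cap compile-time host memory at roughly 8 GiB (value in bytes, illustrative only).
    # offload_module_to_cpu moves weights to host memory, which is what the budget guards.
    trt_mod = torchtrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={torch.float},
        offload_module_to_cpu=True,
        cpu_memory_budget=8 * 1024**3,
    )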