
Commit f03ab2c

Added cpu memory budget to the frontend
1 parent 57b04d7 commit f03ab2c

5 files changed: 59 additions, 42 deletions

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 10 additions & 4 deletions

@@ -105,6 +105,7 @@ def cross_compile_for_windows(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -179,6 +180,7 @@ def cross_compile_for_windows(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -334,6 +336,7 @@ def cross_compile_for_windows(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "cpu_memory_budget": cpu_memory_budget,
     }
 
     # disable the following settings is not supported for cross compilation for windows feature
@@ -435,6 +438,7 @@ def compile(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -681,6 +685,7 @@ def compile(
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "cpu_memory_budget": cpu_memory_budget,
     }
     logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
@@ -833,6 +838,7 @@ def preserve_module_specs(
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,
             skip_fusion=(num_supported_ops == total_ops),
+            cpu_memory_budget=settings.cpu_memory_budget,
         )
 
     except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
@@ -878,11 +884,10 @@ def preserve_module_specs(
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
 
-
-
     from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS
+
     DYNAMO_CONVERTERS.disallowed_targets = set()
-
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1071,6 +1076,7 @@ def convert_exported_program_to_serialized_trt_engine(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1345,7 +1351,7 @@ def convert_exported_program_to_serialized_trt_engine(
     )
 
     flattened_input_list = get_flat_args_with_check(
-        exported_program, list(trt_arg_inputs), trt_kwarg_inputs  # type: ignore
+        exported_program, list(trt_arg_inputs), trt_kwarg_inputs
     )[0]
 
     try:
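
For reference, a minimal usage sketch of the new keyword, assuming the usual torch_tensorrt.dynamo.compile entry point. The model, inputs, and the 8 GiB figure are placeholders; only the cpu_memory_budget parameter itself comes from this commit.

import torch
import torch_tensorrt

# Placeholder model and inputs, purely illustrative.
model = torch.nn.Linear(128, 64).eval().cuda()
example_inputs = (torch.randn(8, 128).cuda(),)
exported = torch.export.export(model, example_inputs)

# The budget is compared against psutil byte counts in the partitioner, so it is
# expressed in bytes; -1 (the default) means "use all available CPU memory".
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    example_inputs,
    cpu_memory_budget=8 * 1024**3,  # illustrative 8 GiB cap on CPU RAM during compilation
)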

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
 OFFLOAD_MODULE_TO_CPU = False
+CPU_MEMORY_BUDGET = -1
 
 if platform.system() == "Linux":
     import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
+    CPU_MEMORY_BUDGET,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -140,6 +141,7 @@ class CompilationSettings:
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
+    cpu_memory_budget: int = CPU_MEMORY_BUDGET
 
     def __getstate__(self) -> dict[str, Any]:
         from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
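
A quick sketch of the plumbing above: the dataclass field can also be set directly on CompilationSettings and otherwise defaults to CPU_MEMORY_BUDGET (-1). The 4 GiB value is arbitrary.

from torch_tensorrt.dynamo._settings import CompilationSettings

# All other fields keep their defaults; only the new knob is set here.
settings = CompilationSettings(cpu_memory_budget=4 * 1024**3)  # 4 GiB, arbitrary
assert settings.cpu_memory_budget == 4 * 1024**3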

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 46 additions & 38 deletions

@@ -118,6 +118,7 @@ def __init__(
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
         return_tuple: bool = False,
         skip_fusion: bool = False,
+        cpu_memory_budget: int = -1,
     ):
         """
         Preprocesses graph before splitting:
@@ -137,6 +138,7 @@ def __init__(
             skip_fusion=skip_fusion,
         )
         self.operator_support = operator_support
+        self.cpu_memory_budget = cpu_memory_budget
 
         # Get all accelerated nodes based on operator support conditions
         self.acc_nodes = FxNetAccNodesFinder(
@@ -231,19 +233,15 @@ def partition_graph(self) -> torch.fx.GraphModule:
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
         subgraphs = self.break_subgraphs(
-            subgraphs, size_budget=self.calculate_size_budget()
+            subgraphs, subgraph_size_budget=self.calculate_size_budget()
         )
 
         # Set the number of TRT engines to be generated
         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
 
         # Tag the accelerated nodes and split the graph accordingly
-        print([len(s.nodes) for s in subgraphs])
         self.tag(subgraphs)
 
-        for s in subgraphs:
-            print(s.nodes)
-
         gm = self.split()
 
         return gm
@@ -255,8 +253,11 @@ def calculate_size_budget(
         This function calculates the size budget based on the available RSS. We assume that TRT compilation
         needs at most 4x the memory of the model.
         """
-
-        available_rss: int = psutil.virtual_memory().available
+        if self.cpu_memory_budget == -1:
+            available_rss: int = psutil.virtual_memory().available
+        else:
+            used_rss: int = psutil.virtual_memory().used
+            available_rss = self.cpu_memory_budget - used_rss
         return available_rss // engine_compilation_memory_usage_multiplier
 
     def break_subgraphs_by_node(
@@ -303,24 +304,25 @@ def break_subgraphs_by_node(
         return new_subgraphs
 
     def break_subgraphs(
-        self, subgraphs: List[Subgraph], size_budget: int
+        self, subgraphs: List[Subgraph], subgraph_size_budget: int
     ) -> List[Subgraph]:
         """
         This function breaks the subgraphs into smaller subgraphs to save CPU memory.
         """
         new_subgraphs = []
         # We throw an error if the remaining memory is almost empty compared to the model size.
         # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation.
-        sizes = [(subgraph, self.size_of_subgraph(subgraph)) for subgraph in subgraphs]
-        if sum([size for _, size in sizes]) > size_budget * 40:
+        sizes = self.size_of_subgraphs(subgraphs)
+        if sum(sizes) > subgraph_size_budget * 40:
             raise ValueError(
-                f"Subgraph size {sum([size for _, size in sizes])} is too large to break. Size budget: {size_budget}"
+                f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else "All available memory"} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
+                + "Consider setting cpu_memory_budget to a larger value or disable offload_module_to_cpu to save more CPU memory."
             )
-        for subgraph, size in sizes:
+        for subgraph, size in zip(subgraphs, sizes):
 
-            while size > size_budget:
+            while size > subgraph_size_budget:
                 broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size(
-                    subgraph, size_budget
+                    subgraph, subgraph_size_budget
                 )
                 size = size_1
                 new_subgraphs.append(broken_subgraphs[0])
@@ -351,10 +353,11 @@ def break_subgraph_by_size(
         ]
 
         while True:
-            new_subgraphs = self.step_and_validate(new_subgraphs)
-            size_0, size_1 = self.size_of_subgraph(
-                new_subgraphs[0]
-            ), self.size_of_subgraph(new_subgraphs[1])
+            step_size = (
+                1 if not new_subgraphs[0].nodes else max(1, len(all_nodes) // 50)
+            )  # Set a step size proportional to the size of the subgraph to make the algorithm more efficient
+            new_subgraphs = self.step_and_validate(new_subgraphs, step_size)
+            size_0, size_1 = self.size_of_subgraphs(new_subgraphs)
             if size_0 > size_to_break:
                 break
 
@@ -451,31 +454,34 @@ def get_leaf_node(
                 break
         return leaf_node
 
-    def size_of_subgraph(self, subgraph: Subgraph) -> int:
+    def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]:
         """
         This function calculates the size of the subgraph.
         """
-        nodes_in_subgraph = set(subgraph.nodes)
+        state_dict = self.module.state_dict(keep_vars=True)
+        sizes = []
         weight_visited_nodes = set()
-        stack = subgraph.nodes.copy()
-        size = 0
-        while stack:
-            node = stack.pop()
-            if node in weight_visited_nodes:
-                continue
-            if node.op == "get_attr":
-                weight = self.module.state_dict()[node.target]
-                size += weight.numel() * weight.element_size()
+        for subgraph in subgraphs:
+            nodes_in_subgraph = set(subgraph.nodes)
+            stack = subgraph.nodes.copy()
+            size = 0
+            while stack:
+                node = stack.pop()
+                if node in weight_visited_nodes:
+                    continue
                weight_visited_nodes.add(node)
-                continue
-            if node not in nodes_in_subgraph:
-                # Trace to other subgraphs
-                continue
-            for input_node in node._input_nodes:
-                if input_node not in weight_visited_nodes:
-                    stack.append(input_node)
-
-        return size
+                if node.op == "get_attr":
+                    weight = state_dict[node.target]
+                    size += weight.numel() * weight.element_size()
+                    continue
+                if node not in nodes_in_subgraph:
+                    # Trace to other subgraphs
+                    continue
+                for input_node in node._input_nodes:
+                    if input_node not in weight_visited_nodes:
+                        stack.append(input_node)
+            sizes.append(size)
+        return sizes
 
     def validate_and_correct_subgraphs(
         self, subgraphs: List[Subgraph]
@@ -541,6 +547,7 @@ def partition(
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     skip_fusion: bool = False,
+    cpu_memory_budget: int = -1,
 ) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -567,6 +574,7 @@
         min_block_size=min_block_size,
         require_full_compilation=require_full_compilation,
         skip_fusion=skip_fusion,
+        cpu_memory_budget=cpu_memory_budget,
     )
 
     partitioned_graph = partitioner.partition_graph()
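
To make the budget arithmetic above concrete, here is a distilled, standalone sketch of what calculate_size_budget now does (not the class method itself; the multiplier value of 4 is assumed from the docstring's "at most 4x the memory of the model" note):

import psutil

ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER = 4  # assumed 4x, per the docstring

def per_subgraph_size_budget(cpu_memory_budget: int = -1) -> int:
    """Return the per-subgraph size budget in bytes."""
    if cpu_memory_budget == -1:
        # No explicit budget: use whatever CPU memory is currently available.
        available_rss = psutil.virtual_memory().available
    else:
        # Explicit budget: subtract the memory that is already in use.
        available_rss = cpu_memory_budget - psutil.virtual_memory().used
    return available_rss // ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER

# Example: with a 16 GiB budget and 6 GiB already in use, each subgraph must fit
# within roughly (16 - 6) / 4 = 2.5 GiB; break_subgraphs splits larger subgraphs,
# and compilation aborts if the total model size exceeds 40x this budget, as the
# new ValueError message explains.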

py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py

Whitespace-only changes.
