Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions examples/dynamo/autocast_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch_tensorrt


class MixedPytorchAutocastModel(nn.Module):
def __init__(self):
super(MixedPytorchAutocastModel, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(
in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(16 * 8 * 8, 10)

def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
x = self.flatten(x)
with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
x = self.fc1(x)
out = torch.log(
torch.abs(x) + 1
) # log is fp32 due to Pytorch Autocast requirements
return out


if __name__ == "__main__":
model = MixedPytorchAutocastModel().cuda().eval()
inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),)
ep = torch.export.export(model, inputs)
calibration_dataloader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False
)

with torch_tensorrt.dynamo.Debugger(
"graphs",
logging_dir=".",
engine_builder_monitor=False,
):
trt_autocast_mod = torch_tensorrt.compile(
ep.module(),
arg_inputs=inputs,
min_block_size=1,
use_python_runtime=True,
##### weak typing #####
# use_explicit_typing=False,
# enabled_precisions={torch.float16},
##### strong typing + autocast #####
use_explicit_typing=True,
enable_autocast=True,
autocast_low_precision_type=torch.float16,
autocast_excluded_nodes={"^conv1$", "relu"},
autocast_excluded_ops={torch.ops.aten.flatten.using_ints},
autocast_max_output_threshold=512,
autocast_max_depth_of_reduction=None,
autocast_calibration_dataloader=calibration_dataloader,
)

autocast_outs = trt_autocast_mod(*inputs)
46 changes: 45 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def cross_compile_for_windows(
disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
workspace_size (int): Maximum size of workspace given to TensorRT
Expand Down Expand Up @@ -434,6 +434,19 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
enable_autocast: bool = _defaults.ENABLE_AUTOCAST,
autocast_low_precision_type: Optional[
Union[torch.dtype, dtype]
] = _defaults.AUTOCAST_LOW_PRECISION_TYPE,
autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES,
autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS,
autocast_max_output_threshold: float = _defaults.AUTOCAST_MAX_OUTPUT_THRESHOLD,
autocast_max_depth_of_reduction: Optional[
int
] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION,
autocast_calibration_dataloader: Optional[
torch.utils.data.DataLoader
] = _defaults.AUTOCAST_CALIBRATION_DATALOADER,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -511,6 +524,13 @@ def compile(
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is [].
autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None.
autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -584,6 +604,10 @@ def compile(
"\nThis feature is unimplemented in Torch-TRT Dynamo currently."
)

if enable_autocast:
use_explicit_typing = True
logger.debug("Autocast is enabled, setting use_explicit_typing to True.")

if use_explicit_typing:
if len(enabled_precisions) != 1 or not any(
x in enabled_precisions
Expand All @@ -593,6 +617,19 @@ def compile(
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
)

if autocast_low_precision_type is not None:
if not isinstance(autocast_low_precision_type, (torch.dtype, dtype)):
raise ValueError(
f"autocast_low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(autocast_low_precision_type)}"
)
if autocast_low_precision_type not in {
torch.float16,
torch.bfloat16,
} and autocast_low_precision_type not in {dtype.f16, dtype.bf16}:
raise ValueError(
f"autocast_low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {autocast_low_precision_type}"
)

if use_fp32_acc:
logger.debug(
"FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
Expand Down Expand Up @@ -680,6 +717,13 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"enable_autocast": enable_autocast,
"autocast_low_precision_type": autocast_low_precision_type,
"autocast_excluded_nodes": autocast_excluded_nodes,
"autocast_excluded_ops": autocast_excluded_ops,
"autocast_max_output_threshold": autocast_max_output_threshold,
"autocast_max_depth_of_reduction": autocast_max_depth_of_reduction,
"autocast_calibration_dataloader": autocast_calibration_dataloader,
}

settings = CompilationSettings(**compilation_options)
Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
ENABLE_AUTOCAST = False
AUTOCAST_LOW_PRECISION_TYPE = None
AUTOCAST_EXCLUDED_NODES = set[str]()
AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]()
AUTOCAST_MAX_OUTPUT_THRESHOLD = 512
AUTOCAST_MAX_DEPTH_OF_REDUCTION = None
AUTOCAST_CALIBRATION_DATALOADER = None

if platform.system() == "Linux":
import pwd
Expand Down
29 changes: 29 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from dataclasses import dataclass, field
from typing import Any, Collection, Optional, Set, Tuple, Union

import torch
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
AUTOCAST_CALIBRATION_DATALOADER,
AUTOCAST_EXCLUDED_NODES,
AUTOCAST_EXCLUDED_OPS,
AUTOCAST_LOW_PRECISION_TYPE,
AUTOCAST_MAX_DEPTH_OF_REDUCTION,
AUTOCAST_MAX_OUTPUT_THRESHOLD,
CACHE_BUILT_ENGINES,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
Expand Down Expand Up @@ -97,6 +105,13 @@ class CompilationSettings:
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is [].
autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is [].
autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None.
autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -140,6 +155,19 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
enable_autocast: bool = ENABLE_AUTOCAST
autocast_low_precision_type: Optional[dtype] = AUTOCAST_LOW_PRECISION_TYPE
autocast_excluded_nodes: Collection[str] = field(
default_factory=lambda: AUTOCAST_EXCLUDED_NODES
)
autocast_excluded_ops: Collection[Target] = field(
default_factory=lambda: AUTOCAST_EXCLUDED_OPS
)
autocast_max_output_threshold: float = AUTOCAST_MAX_OUTPUT_THRESHOLD
autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION
autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = (
AUTOCAST_CALIBRATION_DATALOADER
)

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
Expand All @@ -157,6 +185,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)


# If any of the following setting is changed, the engine should be rebuilt.
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
"enabled_precisions",
"max_aux_streams",
Expand Down
23 changes: 19 additions & 4 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
import operator
from typing import Any, Callable, Optional, Sequence, Union

import torch
from torch_tensorrt._utils import is_tegra_platform
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
trace_intermediate_node_outputs,
)

from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
Expand All @@ -15,6 +19,13 @@
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast

pre_lowering_pass_list = [
remove_detach,
remove_assert_nodes,
rule_based_autocast,
]

post_lowering_pass_list = [
remove_input_alias_fixing_clones,
Expand All @@ -27,10 +38,6 @@
complex_graph_detection,
]

pre_lowering_pass_list = [
remove_detach,
]

if not is_tegra_platform():
from .fuse_distributed_ops import fuse_distributed_ops

Expand Down Expand Up @@ -135,6 +142,14 @@ def pre_export_lowering(
logging.debug(
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
)

# Only for rule-based autocast to collect the intermediate node outputs
if settings.enable_autocast:
settings.autocast_intermediate_node_outputs = trace_intermediate_node_outputs(
ep.module(),
settings.autocast_calibration_dataloader,
[torch.ops.higher_order.wrap_with_autocast, operator.getitem],
)
gm = ep.graph_module
gm = ATEN_PRE_LOWERING_PASSES(gm, settings)
return ep
Expand Down
Loading
Loading