diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py new file mode 100644 index 0000000000..b467ea0d31 --- /dev/null +++ b/examples/dynamo/autocast_example.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import torch_tensorrt + + +class MixedPytorchAutocastModel(nn.Module): + def __init__(self): + super(MixedPytorchAutocastModel, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.relu1(x) + x = self.pool1(x) + x = self.conv2(x) + x = self.relu2(x) + x = self.pool2(x) + x = self.flatten(x) + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + x = self.fc1(x) + out = torch.log( + torch.abs(x) + 1 + ) # log is fp32 due to Pytorch Autocast requirements + return out + + +if __name__ == "__main__": + model = MixedPytorchAutocastModel().cuda().eval() + inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),) + ep = torch.export.export(model, inputs) + calibration_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False + ) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_autocast_mod = torch_tensorrt.compile( + ep.module(), + arg_inputs=inputs, + min_block_size=1, + use_python_runtime=True, + ##### weak typing ##### + # use_explicit_typing=False, + # enabled_precisions={torch.float16}, + ##### strong typing + autocast ##### + use_explicit_typing=True, + enable_autocast=True, + autocast_low_precision_type=torch.float16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_max_output_threshold=512, + autocast_max_depth_of_reduction=None, + autocast_calibration_dataloader=calibration_dataloader, + ) + + autocast_outs = trt_autocast_mod(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..0ff86ad235 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -141,7 +141,7 @@ def cross_compile_for_windows( disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. 
- enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels + enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -434,6 +434,19 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + enable_autocast: bool = _defaults.ENABLE_AUTOCAST, + autocast_low_precision_type: Optional[ + Union[torch.dtype, dtype] + ] = _defaults.AUTOCAST_LOW_PRECISION_TYPE, + autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES, + autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS, + autocast_max_output_threshold: float = _defaults.AUTOCAST_MAX_OUTPUT_THRESHOLD, + autocast_max_depth_of_reduction: Optional[ + int + ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION, + autocast_calibration_dataloader: Optional[ + torch.utils.data.DataLoader + ] = _defaults.AUTOCAST_CALIBRATION_DATALOADER, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -511,6 +524,13 @@ def compile( l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. + autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is []. + autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None. + autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -584,6 +604,10 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
) + if enable_autocast: + use_explicit_typing = True + logger.debug("Autocast is enabled, setting use_explicit_typing to True.") + if use_explicit_typing: if len(enabled_precisions) != 1 or not any( x in enabled_precisions @@ -593,6 +617,19 @@ def compile( f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) + if autocast_low_precision_type is not None: + if not isinstance(autocast_low_precision_type, (torch.dtype, dtype)): + raise ValueError( + f"autocast_low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(autocast_low_precision_type)}" + ) + if autocast_low_precision_type not in { + torch.float16, + torch.bfloat16, + } and autocast_low_precision_type not in {dtype.f16, dtype.bf16}: + raise ValueError( + f"autocast_low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {autocast_low_precision_type}" + ) + if use_fp32_acc: logger.debug( "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \ @@ -680,6 +717,13 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "enable_autocast": enable_autocast, + "autocast_low_precision_type": autocast_low_precision_type, + "autocast_excluded_nodes": autocast_excluded_nodes, + "autocast_excluded_ops": autocast_excluded_ops, + "autocast_max_output_threshold": autocast_max_output_threshold, + "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction, + "autocast_calibration_dataloader": autocast_calibration_dataloader, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..3a238c11ee 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,13 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +ENABLE_AUTOCAST = False +AUTOCAST_LOW_PRECISION_TYPE = None +AUTOCAST_EXCLUDED_NODES = set[str]() +AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]() +AUTOCAST_MAX_OUTPUT_THRESHOLD = 512 +AUTOCAST_MAX_DEPTH_OF_REDUCTION = None +AUTOCAST_CALIBRATION_DATALOADER = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..d62c75e0da 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,17 +1,25 @@ from dataclasses import dataclass, field from typing import Any, Collection, Optional, Set, Tuple, Union +import torch from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + AUTOCAST_CALIBRATION_DATALOADER, + AUTOCAST_EXCLUDED_NODES, + AUTOCAST_EXCLUDED_OPS, + AUTOCAST_LOW_PRECISION_TYPE, + AUTOCAST_MAX_DEPTH_OF_REDUCTION, + AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, + ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLE_WEIGHT_STREAMING, @@ -97,6 +105,13 @@ class CompilationSettings: 
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. + autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match user-specified node names that should remain in FP32. Default is []. + autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None. + autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -140,6 +155,19 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + enable_autocast: bool = ENABLE_AUTOCAST + autocast_low_precision_type: Optional[dtype] = AUTOCAST_LOW_PRECISION_TYPE + autocast_excluded_nodes: Collection[str] = field( + default_factory=lambda: AUTOCAST_EXCLUDED_NODES + ) + autocast_excluded_ops: Collection[Target] = field( + default_factory=lambda: AUTOCAST_EXCLUDED_OPS + ) + autocast_max_output_threshold: float = AUTOCAST_MAX_OUTPUT_THRESHOLD + autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION + autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = ( + AUTOCAST_CALIBRATION_DATALOADER + ) def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -157,6 +185,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) +# If any of the following setting is changed, the engine should be rebuilt. 
_SETTINGS_TO_BE_ENGINE_INVARIANT = ( "enabled_precisions", "max_aux_streams", diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index e5183668ae..8ad5f2fcae 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,9 +1,13 @@ import logging +import operator from typing import Any, Callable, Optional, Sequence, Union import torch from torch_tensorrt._utils import is_tegra_platform from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + trace_intermediate_node_outputs, +) from .complex_graph_rewrite import complex_graph_detection from .constant_folding import constant_fold @@ -15,6 +19,13 @@ from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices +from .rule_based_autocast import rule_based_autocast + +pre_lowering_pass_list = [ + remove_detach, + remove_assert_nodes, + rule_based_autocast, +] post_lowering_pass_list = [ remove_input_alias_fixing_clones, @@ -27,10 +38,6 @@ complex_graph_detection, ] -pre_lowering_pass_list = [ - remove_detach, -] - if not is_tegra_platform(): from .fuse_distributed_ops import fuse_distributed_ops @@ -135,6 +142,14 @@ def pre_export_lowering( logging.debug( f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}" ) + + # Only for rule-based autocast to collect the intermediate node outputs + if settings.enable_autocast: + settings.autocast_intermediate_node_outputs = trace_intermediate_node_outputs( + ep.module(), + settings.autocast_calibration_dataloader, + [torch.ops.higher_order.wrap_with_autocast, operator.getitem], + ) gm = ep.graph_module gm = ATEN_PRE_LOWERING_PASSES(gm, settings) return ep diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py new file mode 100644 index 0000000000..9f0aa41916 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -0,0 +1,313 @@ +# Borrowed from ModelOpt AutoCast's nodeclassifier.py, modified to fit Torch-TensorRT's needs. +import abc +import logging +import operator +import re +from typing import Collection, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class NodeRuleBase: + """Base class for node classification rules. + + This class defines the interface for rules that determine whether a node + should be kept in high precision or converted to low precision. + """ + + @abc.abstractmethod + def _check_inner(self, node): + """Implement this method to check if node conversion should be skipped based on rule criteria.""" + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes.""" + logger.info(f"Skipping node {node.name}: {self.__class__.__name__}") + + def check(self, node): + """Check if a node should be skipped based on the rule. + + Args: + node: The torch.fx.Node to check. + + Returns: + bool: True if the node should be kept in high precision, False otherwise. 
+ """ + result = self._check_inner(node) + if result: + self._log_skipped(node) + return True + return False + + +class DisabledNodeNameRegexRule(NodeRuleBase): + """Rule for keeping nodes with matching user-specified names in high precision.""" + + def __init__(self, disabled_node_name_regex): + """Initialize the rule. + + Args: + disabled_node_name_regex: List of regex patterns for user-specified node names to keep in high precision. + """ + self.disabled_node_name_regex = disabled_node_name_regex + + def _check_inner(self, node): + stack = node.meta.get("nn_module_stack") + try: + # get the user specified name of the node + node_name = stack.get(next(reversed(stack)), [""])[0] + except Exception as e: + raise ValueError( + f"Failed to get the user specified name of the node {node.name} because {e}. Please file a bug with Torch-TensorRT." + ) + return any( + re.match(regex, node_name) for regex in self.disabled_node_name_regex + ) + + +class DisabledOpTypes(NodeRuleBase): + """Rule for keeping nodes with specific ATen ops in high precision.""" + + def __init__(self, excluded_ops): + """Initialize the rule. + + Args: + excluded_ops: List of ATen ops that should remain in FP32. + """ + self.excluded_ops = excluded_ops + + def _check_inner(self, node): + return node.target in self.excluded_ops + + +class IORangeRule(NodeRuleBase): + """Rule for keeping nodes with out-of-range inputs/outputs in high precision.""" + + def __init__(self, max_output_threshold, reference_data): + """Initialize the rule. + + Args: + max_output_threshold: Maximum absolute value allowed for node I/O. + reference_data: Reference data for checking I/O ranges. + """ + self.max_output_threshold = max_output_threshold + self.reference_data = reference_data + self.output_data = None + + def _check_inner(self, node): + def is_io_out_of_range(node): + tensor_name = node.name + if tensor_name not in self.reference_data: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} not found in reference data. Skipping I/O range check." + ) + return False + ref_data = self.reference_data[tensor_name] + if ref_data.numel() == 0: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} has 0 elements. Skipping I/O range check." + ) + return False + logger.debug( + f"Node {node.name}: reference data: min={ref_data.min()}, max={ref_data.max()}" + ) + if torch.any(torch.abs(ref_data) > self.max_output_threshold): + self.output_data = ref_data + return True + + if self.reference_data: + for in_node in node.all_input_nodes: + if is_io_out_of_range(in_node): + return True + for out_node in list(node.users): + if is_io_out_of_range(out_node): + return True + return False + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with I/O range violations.""" + if self.output_data is not None: + logger.info( + f"Skipping node {node.name}: reference IO out of range: min={torch.min(self.output_data)}, " + f"max={torch.max(self.output_data)}, range=[{-self.max_output_threshold}, {self.max_output_threshold}]" + ) + else: + super()._log_skipped(node, **kwargs) + + +class DepthOfReductionRule(NodeRuleBase): + """ + Rule for keeping nodes with high depth of reduction in high precision. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. 
+    Reduction ops are those that aggregate data across one or more axes, decreasing the dimensionality of the input tensor, such as convolution, gemm, etc.
+    """
+
+    def __init__(self, max_depth_of_reduction, reference_data):
+        """Initialize the rule.
+
+        Args:
+            max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
+            reference_data: Reference data for checking I/O ranges.
+        """
+        self.max_depth_of_reduction = max_depth_of_reduction
+        self.reference_data = reference_data
+        self.reduction_depth = 0
+
+    def _get_tensor_shape(self, tensor_name):
+        """Get tensor shape from reference data."""
+        if tensor_name in self.reference_data:
+            return self.reference_data[tensor_name].shape
+        return None
+
+    def _log_skipped(self, node, **kwargs):
+        """Log information about skipped nodes with depth of reduction violations."""
+        if self.reduction_depth > 0:
+            logger.info(
+                f"Skipping node {node.name}: depth of reduction {self.reduction_depth} exceeds "
+                f"{self.max_depth_of_reduction}."
+            )
+        else:
+            super()._log_skipped(node, **kwargs)
+
+    def _check_inner(self, node):
+        # All reduction ops rely on shape of input[0]
+        input_0_dims = (
+            self._get_tensor_shape(node.all_input_nodes[0].name)
+            if len(node.all_input_nodes) > 0
+            else None
+        )
+        if input_0_dims is None:
+            return False
+        self.reduction_depth = 0
+        if node.target in [
+            torch.ops.aten.scaled_dot_product_attention.default,
+        ]:
+            # Attention: input (batch_size, sequence_length, hidden_size)
+            # or (batch_size, kv_num_heads, total_sequence_length, head_size)
+            assert len(input_0_dims) == 3 or len(input_0_dims) == 4
+            hidden_size = (
+                input_0_dims[2]
+                if len(input_0_dims) == 3
+                else input_0_dims[1] * input_0_dims[3]
+            )
+            self.reduction_depth = hidden_size
+        elif node.target in [
+            torch.ops.aten.convolution.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv3d.default,
+        ]:
+            # Conv: input (N x C x D1 x D2 ... x Dn)
+            # weight (out_channels, in_channels, kD1, kD2, ... kDn)
+            # Reduction depth = in_channels * kernel_volume
+            weight_shape = (
+                self._get_tensor_shape(node.all_input_nodes[1].name)
+                if len(node.all_input_nodes) > 1
+                else None
+            )
+            if weight_shape is None:
+                return False
+            in_channels = weight_shape[1]
+            # weight_shape is a torch.Size, so compute the kernel volume as a plain int
+            # (torch.prod expects a Tensor, not a shape tuple).
+            kernel_volume = int(torch.prod(torch.tensor(weight_shape[2:])))
+            self.reduction_depth = in_channels * kernel_volume
+        elif node.target in [
+            torch.ops.aten.matmul,
+            torch.ops.aten.matmul.default,
+            torch.ops.aten.dot.default,
+            torch.ops.aten.mm.default,
+            torch.ops.aten.mv.default,
+            torch.ops.aten.bmm.default,
+        ]:
+            # GEMM: A (M, K) @ B (K, N) = C (M, N)
+            self.reduction_depth = input_0_dims[-1]
+        # TODO: Add more reduction ops here
+        return self.reduction_depth > self.max_depth_of_reduction
+
+
+class NodeClassifier:
+    """Main class for classifying nodes into high and low precision groups."""
+
+    def __init__(
+        self,
+        nodes,
+        excluded_nodes: Collection[str] | None = None,
+        excluded_ops: Collection[torch.fx.node.Target] | None = None,
+        custom_rule: NodeRuleBase | None = None,
+        max_output_threshold: float | None = 512,
+        max_depth_of_reduction: int | None = None,
+    ):
+        """Initialize the node classifier.
+
+        Args:
+            nodes: The nodes to classify.
+            excluded_nodes: Collection of regex patterns for node names to keep in high precision.
+            excluded_ops: Collection of targets (ATen ops) to keep in high precision.
+            custom_rule: Optional custom classification rule.
+            max_output_threshold: Maximum absolute value allowed for node I/O.
+ max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + """ + self.nodes = nodes + self.excluded_nodes = excluded_nodes + self.excluded_ops = excluded_ops + self.custom_rule = custom_rule + self.max_output_threshold = max_output_threshold + self.max_depth_of_reduction = max_depth_of_reduction + + def _gen_block_node_rules(self, reference_data): + """Generate list of rules for blocking nodes from precision conversion. + + Args: + reference_data: Reference data for checking I/O ranges. + + Returns: + list[NodeRuleBase]: List of rules to apply. + """ + block_node_rules: list[NodeRuleBase] = [] + if self.excluded_nodes: + block_node_rules.append(DisabledNodeNameRegexRule(self.excluded_nodes)) + if self.excluded_ops: + block_node_rules.append(DisabledOpTypes(self.excluded_ops)) + if reference_data: + block_node_rules.append( + IORangeRule(self.max_output_threshold, reference_data) + ) + if self.max_depth_of_reduction is not None: + block_node_rules.append( + DepthOfReductionRule( + self.max_depth_of_reduction, + reference_data, + ) + ) + if self.custom_rule: + block_node_rules.append(self.custom_rule) + return block_node_rules + + def run( + self, ref_outputs_dict: Optional[dict[str, torch.Tensor]] = None + ) -> tuple[list[str], list[str]]: + """Run node classification. + + Args: + ref_outputs_dict: Optional tensors' reference data. + + Returns: + tuple: Lists of node names (low_precision_nodes, high_precision_nodes). + """ + block_node_rules = self._gen_block_node_rules(ref_outputs_dict) + low_precision_nodes = [] + high_precision_nodes = [] + for node in self.nodes: + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + # If any condition is met - node will be executed in high precision + if any(rule.check(node) for rule in block_node_rules): + high_precision_nodes.append(node.name) + else: + low_precision_nodes.append(node.name) + logger.debug(f"Low Precision Nodes: {low_precision_nodes}") + logger.debug(f"High Precision Nodes: {high_precision_nodes}") + return low_precision_nodes, high_precision_nodes diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py index 1736a234a2..478c221872 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List, Sequence import torch @@ -68,3 +68,42 @@ def is_node_complex(node: torch.fx.Node, complexNodes): complexNodes[node.name] = True return True return False + + +def trace_intermediate_node_outputs( + gm: torch.fx.GraphModule, + calibration_dataloader: torch.utils.data.DataLoader, + excluded_ops: Sequence[torch.fx.node.Target] = [], +) -> Dict[str, torch.Tensor]: + """Trace the intermediate node outputs of a graph module. + + Args: + gm (torch.fx.GraphModule): The graph module to trace the intermediate node outputs of. + calibration_dataloader (torch.utils.data.DataLoader): The dataloader to use for tracing. + excluded_ops (Set[torch.fx.node.Target]): The set of ATen ops that should be excluded from the trace. For example, `{torch.ops.higher_order.wrap_with_autocast, operator.getitem}`. Default is an empty set. + + Returns: + Dict[str, torch.Tensor]: A dictionary of intermediate node outputs. The key is the node name and the value is the tensor. 
+ """ + + intermediate_node_outputs: Dict[str, torch.Tensor] = {} + + class IntermediateNodeTracer(torch.fx.Interpreter): # type: ignore[misc] + def run_node(self, n: torch.fx.Node) -> Any: + out = super().run_node(n) + if n.op == "call_function" and n.target not in excluded_ops: + if not isinstance(out, torch.Tensor): + return out + if n.name in intermediate_node_outputs: + intermediate_node_outputs[n.name] = torch.cat( + [intermediate_node_outputs[n.name], out], dim=0 + ) + else: + intermediate_node_outputs[n.name] = out + return out + + if calibration_dataloader is not None: + tracer = IntermediateNodeTracer(gm) + for batch in calibration_dataloader: + tracer.run(tuple(batch)) + return intermediate_node_outputs diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py new file mode 100644 index 0000000000..b6266e5787 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -0,0 +1,122 @@ +import logging +import operator +from typing import Any + +import torch +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo._settings import CompilationSettings + +from .nodeclassifier import NodeClassifier +from .pass_utils import clean_up_graph_after_modifications + +logger = logging.getLogger(__name__) + + +def is_tensor_node(n: torch.fx.Node) -> bool: + val = n.meta.get("val", None) + if hasattr(val, "dtype"): + return True + return False + + +def rule_based_autocast( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Rule-based autocast""" + if not settings.enable_autocast: + logger.debug("Autocast is not enabled, skipping rule-based autocast.") + return gm + + # get config from settings + autocast_low_precision_type = settings.autocast_low_precision_type + if autocast_low_precision_type is None: + return gm + if isinstance(autocast_low_precision_type, dtype): + autocast_low_precision_type = autocast_low_precision_type.to(torch.dtype) + autocast_high_precision_type = torch.float32 + autocast_excluded_nodes = settings.autocast_excluded_nodes + autocast_excluded_ops = settings.autocast_excluded_ops + autocast_max_output_threshold = settings.autocast_max_output_threshold + autocast_max_depth_of_reduction = settings.autocast_max_depth_of_reduction + reference_data: dict[str, torch.Tensor] = ( + settings.autocast_intermediate_node_outputs + ) + + node_classifier = NodeClassifier( + gm.graph.nodes, + excluded_nodes=autocast_excluded_nodes, + excluded_ops=autocast_excluded_ops, + max_output_threshold=autocast_max_output_threshold, + max_depth_of_reduction=autocast_max_depth_of_reduction, + ) + low_precision_nodes, high_precision_nodes = node_classifier.run(reference_data) + + def _cast_all_tensor_args_to_dtype( + node: torch.fx.Node, arg: Any, dtype: torch.dtype + ) -> Any: + """Cast all tensor args to the given dtype + + Args: + node: The node to insert the cast before + arg: The argument to cast + dtype: The dtype to cast to + + Returns: + The casted argument + """ + if isinstance(arg, torch.fx.Node) and is_tensor_node(arg): + val = arg.meta.get("val", None) + if isinstance(val, torch.Tensor): + if val.dtype == dtype: + return arg + else: + with gm.graph.inserting_before(node): + cast = gm.graph.call_function( + torch.ops.aten.to.dtype, args=(arg, dtype) + ) + # copy the meta of the original tensor to the casted tensor + cast.meta.update(arg.meta) + # update the dtype of the casted tensor + cast.meta["val"] = 
cast.meta["val"].to(dtype) + return cast + elif isinstance(arg, (tuple, list)): + return type(arg)( + _cast_all_tensor_args_to_dtype(node, a, dtype) for a in arg + ) + elif isinstance(arg, dict): + return { + k: _cast_all_tensor_args_to_dtype(node, v, dtype) + for k, v in arg.items() + } + else: + return arg + + for node in list(gm.graph.nodes): + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + + if node.name in low_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node, node.args, autocast_low_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node, node.kwargs, autocast_low_precision_type + ) + node.meta["val"] = node.meta["val"].to(autocast_low_precision_type) + elif node.name in high_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node, node.args, autocast_high_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node, node.kwargs, autocast_high_precision_type + ) + node.meta["val"] = node.meta["val"].to(autocast_high_precision_type) + + gm = clean_up_graph_after_modifications(gm) + logger.debug("Graph after Autocast based on the rules:\n%s", gm.graph) + + return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..a946f38761 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -279,6 +279,7 @@ def setup_engine(self) -> None: dtype._from(self.engine.get_tensor_dtype(input_name)) for input_name in self.input_names ] + self.input_shapes = [ self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] diff --git a/tests/py/dynamo/models/test_autocast.py b/tests/py/dynamo/models/test_autocast.py new file mode 100644 index 0000000000..c3ceac3bc3 --- /dev/null +++ b/tests/py/dynamo/models/test_autocast.py @@ -0,0 +1,369 @@ +import pytest +import torch +import torch.nn as nn +import torch_tensorrt + + +@pytest.mark.unit +@pytest.mark.critical +def test_no_pytorch_autocast(): + class NoPytorchAutocastModel(nn.Module): + def __init__(self): + super(NoPytorchAutocastModel, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x): + out1 = self.conv1(x) + out2 = self.relu1(out1) + out3 = self.pool1(out2) + out4 = self.conv2(out3) + out5 = self.relu2(out4) + out6 = self.pool2(out5) + out7 = self.flatten(out6) + out8 = self.fc1(out7) + out9 = torch.add(out8, out8) + return x, out1, out2, out3, out4, out5, out6, out7, out8, out9 + + model = NoPytorchAutocastModel().cuda().eval() + inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),) + ep = torch.export.export(model, inputs) + calibration_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False + ) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_autocast_mod = torch_tensorrt.compile( + ep.module(), + arg_inputs=inputs, + 
min_block_size=1, + use_python_runtime=True, + use_explicit_typing=True, + enable_autocast=True, + autocast_low_precision_type=torch.float16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_max_output_threshold=512, + autocast_max_depth_of_reduction=None, + autocast_calibration_dataloader=calibration_dataloader, + ) + + autocast_outs = trt_autocast_mod(*inputs) + pytorch_outs = model(*inputs) + + should_be_fp32 = [ + autocast_outs[0], + autocast_outs[1], + autocast_outs[2], + autocast_outs[5], + autocast_outs[7], + ] + should_be_fp16 = [ + autocast_outs[3], + autocast_outs[4], + autocast_outs[6], + autocast_outs[8], + autocast_outs[9], + ] + assert all( + a.dtype == torch.float32 for a in should_be_fp32 + ), "Some Autocast outputs are not float32!" + assert all( + a.dtype == torch.float16 for a in should_be_fp16 + ), "Some Autocast outputs are not float16!" + for i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)): + assert torch.allclose( + a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2 + ), f"Autocast and Pytorch outputs do not match! autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}" + + +@pytest.mark.unit +@pytest.mark.critical +def test_whole_pytorch_autocast(): + class WholePytorchAutocastModel(nn.Module): + def __init__(self): + super(WholePytorchAutocastModel, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x): + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + out1 = self.conv1(x) + out2 = self.relu1(out1) + out3 = self.pool1(out2) + out4 = self.conv2(out3) + out5 = self.relu2(out4) + out6 = self.pool2(out5) + out7 = self.flatten(out6) + out8 = self.fc1(out7) + out9 = torch.log( + torch.abs(out8) + 1 + ) # log is fp32 due to Pytorch Autocast requirements + return x, out1, out2, out3, out4, out5, out6, out7, out8, out9 + + model = WholePytorchAutocastModel().cuda().eval() + inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),) + ep = torch.export.export(model, inputs) + calibration_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False + ) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_autocast_mod = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=inputs, + min_block_size=1, + use_python_runtime=True, + use_explicit_typing=True, + # Torch-TensorRT's autocast doesn't affect layers inside Pytorch autocast + enable_autocast=True, + autocast_low_precision_type=torch.bfloat16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_max_output_threshold=512, + autocast_max_depth_of_reduction=None, + autocast_calibration_dataloader=calibration_dataloader, + ) + + autocast_outs = trt_autocast_mod(*inputs) + pytorch_outs = model(*inputs) + + should_be_fp32 = [autocast_outs[0], autocast_outs[9]] + should_be_fp16 = [autocast_outs[i] for i in range(1, 9)] + assert all( + a.dtype == torch.float32 for a in should_be_fp32 + ), "Some Autocast outputs are not float32!" 
+ assert all( + a.dtype == torch.float16 for a in should_be_fp16 + ), "Some Autocast outputs are not float16!" + for i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)): + assert torch.allclose( + a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2 + ), f"Autocast and Pytorch outputs do not match! autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}" + + +@pytest.mark.unit +@pytest.mark.critical +def test_mixed_pytorch_autocast(): + class MixedPytorchAutocastModel(nn.Module): + def __init__(self): + super(MixedPytorchAutocastModel, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x): + out1 = self.conv1(x) + out2 = self.relu1(out1) + out3 = self.pool1(out2) + out4 = self.conv2(out3) + out5 = self.relu2(out4) + out6 = self.pool2(out5) + out7 = self.flatten(out6) + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + out8 = self.fc1(out7) + out9 = torch.log( + torch.abs(out8) + 1 + ) # log is fp32 due to Pytorch Autocast requirements + return x, out1, out2, out3, out4, out5, out6, out7, out8, out9 + + model = MixedPytorchAutocastModel().cuda().eval() + inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),) + ep = torch.export.export(model, inputs) + calibration_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False + ) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_autocast_mod = torch_tensorrt.compile( + ep.module(), + arg_inputs=inputs, + min_block_size=1, + use_python_runtime=True, + use_explicit_typing=True, + # Torch-TensorRT's autocast doesn't affect layers inside Pytorch autocast + enable_autocast=True, + autocast_low_precision_type=torch.bfloat16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_max_output_threshold=512, + autocast_max_depth_of_reduction=None, + autocast_calibration_dataloader=calibration_dataloader, + ) + + autocast_outs = trt_autocast_mod(*inputs) + pytorch_outs = model(*inputs) + + should_be_fp32 = [ + autocast_outs[0], + autocast_outs[1], + autocast_outs[2], + autocast_outs[5], + autocast_outs[7], + autocast_outs[9], + ] + should_be_fp16 = [ + autocast_outs[8], + ] + should_be_bf16 = [autocast_outs[3], autocast_outs[4], autocast_outs[6]] + assert all( + a.dtype == torch.float32 for a in should_be_fp32 + ), "Some Autocast outputs are not float32!" + assert all( + a.dtype == torch.float16 for a in should_be_fp16 + ), "Some Autocast outputs are not float16!" + assert all( + a.dtype == torch.bfloat16 for a in should_be_bf16 + ), "Some Autocast outputs are not bfloat16!" + for i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)): + assert torch.allclose( + a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2 + ), f"Autocast and Pytorch outputs do not match! 
autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}" + + +@pytest.mark.unit +@pytest.mark.critical +def test_nested_pytorch_autocast(): + class NestedPytorchAutocastModel(nn.Module): + def __init__(self): + super(NestedPytorchAutocastModel, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x, y): + out1 = self.conv1( + x + ) # fp32 because of "^conv1$" in `autocast_excluded_nodes` + out2 = self.relu1( + out1 + ) # fp32 because of "relu" in `autocast_excluded_nodes` + out3 = self.pool1(out2) # bf16 + out4 = self.conv2(out3) # bf16 + out5 = self.relu2( + out4 + ) # fp32 because of "relu" in `autocast_excluded_nodes` + out6 = self.pool2(out5) # bf16 + out7 = self.flatten( + out6 + ) # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops` + # Respect the precisions in the pytorch autocast context + with torch.autocast(x.device.type, enabled=True, dtype=torch.float32): + out8 = self.fc1(out7) # fp32 + with torch.autocast(x.device.type, enabled=False): + out9 = torch.sub(out8.half(), y) # fp16 + out10 = torch.add(out9, out9) # fp16 + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + out11 = torch.log( + torch.abs(out10) + 1 + ) # fp32 because Pytorch Autocast requires `log` to be fp32 + return x, out1, out2, out3, out4, out5, out6, out7, out8, out9, out10, out11 + + model = NestedPytorchAutocastModel().cuda().eval() + inputs = ( + torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), + torch.randn((1,), dtype=torch.float16, device="cuda"), + ) + ep = torch.export.export(model, inputs) + calibration_dataloader = torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False + ) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_autocast_mod = torch_tensorrt.dynamo.compile( + ep, + arg_inputs=inputs, + min_block_size=1, + use_python_runtime=True, + use_explicit_typing=True, + # Torch-TensorRT's autocast doesn't affect layers inside Pytorch autocast + enable_autocast=True, + autocast_low_precision_type=torch.bfloat16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_max_output_threshold=512, + autocast_max_depth_of_reduction=None, + autocast_calibration_dataloader=calibration_dataloader, + ) + + autocast_outs = trt_autocast_mod(*inputs) + pytorch_outs = model(*inputs) + + should_be_fp32 = [ + autocast_outs[0], + autocast_outs[1], + autocast_outs[2], + autocast_outs[5], + autocast_outs[7], + autocast_outs[8], + autocast_outs[11], + ] + should_be_fp16 = [autocast_outs[9], autocast_outs[10]] + should_be_bf16 = [autocast_outs[3], autocast_outs[4], autocast_outs[6]] + assert all( + a.dtype == torch.float32 for a in should_be_fp32 + ), "Some Autocast outputs are not float32!" + assert all( + a.dtype == torch.float16 for a in should_be_fp16 + ), "Some Autocast outputs are not float16!" + assert all( + a.dtype == torch.bfloat16 for a in should_be_bf16 + ), "Some Autocast outputs are not bfloat16!" 
+ for i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)): + assert torch.allclose( + a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2 + ), f"Autocast and Pytorch outputs do not match! autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}"
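
Note: the node classification introduced in this diff can also be inspected offline, before building an engine, by driving the new helpers directly. A rough sketch, reusing `model`, `inputs`, and `calibration_dataloader` from the example script above (the import paths come from the files added in this diff; this is not an officially documented workflow):

import operator

import torch
from torch_tensorrt.dynamo.lowering.passes.nodeclassifier import NodeClassifier
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    trace_intermediate_node_outputs,
)

gm = torch.export.export(model, inputs).module()

# Record per-node reference outputs over the calibration data, mirroring what
# pre_export_lowering does when enable_autocast=True.
reference_data = trace_intermediate_node_outputs(
    gm,
    calibration_dataloader,
    [torch.ops.higher_order.wrap_with_autocast, operator.getitem],
)

classifier = NodeClassifier(
    gm.graph.nodes,
    excluded_nodes={"^conv1$", "relu"},
    excluded_ops={torch.ops.aten.flatten.using_ints},
    max_output_threshold=512,
)
low_precision_nodes, high_precision_nodes = classifier.run(reference_data)
print("FP16 candidates:", low_precision_nodes)
print("kept in FP32:", high_precision_nodes)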
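
For intuition on DepthOfReductionRule, the depths it computes for the two convolutions in the example/test model work out as below. This is a back-of-the-envelope check of the formula only; whether the rule actually reclassifies a node additionally depends on the relevant shapes being present in the calibration reference data.

# Depth of reduction for a convolution = in_channels * kernel_volume.
conv1_depth = 3 * (3 * 3)  # 27, weight shape (8, 3, 3, 3)
conv2_depth = 8 * (3 * 3)  # 72, weight shape (16, 8, 3, 3)

# With e.g. autocast_max_depth_of_reduction=50, conv2 (depth 72) would be blocked
# from the low precision type while conv1 (depth 27) would remain eligible.
assert conv1_depth == 27 and conv2_depth == 72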
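
NodeClassifier also accepts a custom_rule, which nothing in this diff exercises. A minimal sketch of what one could look like, built only on the NodeRuleBase interface added above (the softmax choice is illustrative, not part of this change):

import torch
from torch_tensorrt.dynamo.lowering.passes.nodeclassifier import (
    NodeClassifier,
    NodeRuleBase,
)


class KeepSoftmaxInFP32Rule(NodeRuleBase):
    """Illustrative custom rule: force softmax nodes to stay in high precision."""

    def _check_inner(self, node: torch.fx.Node) -> bool:
        # Returning True tells the classifier to keep this node in FP32.
        return node.target in {
            torch.ops.aten.softmax.int,
            torch.ops.aten._softmax.default,
        }


# Usage: NodeClassifier(gm.graph.nodes, custom_rule=KeepSoftmaxInFP32Rule())

Note that custom_rule is currently only reachable by constructing NodeClassifier directly; rule_based_autocast does not forward one from CompilationSettings.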