 from __future__ import annotations
 
+import functools
 import logging
 import unittest
 from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
+from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
+    modify_reshape_complex_nodes,
     post_lowering,
     remove_detach,
     remove_sym_nodes,
@@ -49,7 +52,25 @@ def aot_torch_tensorrt_aten_backend( |
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
+    if settings.use_aot_joint_export:
+        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
+    logger.debug("Wrapping the backend with aot_autograd\n")
+    _pretraced_backend_autograd = functools.partial(
+        _pretraced_backend, settings=settings, engine_cache=engine_cache
+    )
+    settings_aot_autograd = {}
+    settings_aot_autograd["decompositions"] = get_decompositions(
+        settings.enable_experimental_decompositions
+    )
+    # Drop the detach decomposition: lowering detach produces alias nodes, which
+    # fail with "View operation returned a tensor that is the same as the input
+    # base tensor" (see the nop decompositions in torch/_decomp/decompositions.py)
+    if torch.ops.aten.detach.default in settings_aot_autograd["decompositions"]:
+        del settings_aot_autograd["decompositions"][torch.ops.aten.detach.default]
+    return aot_autograd(
+        fw_compiler=_pretraced_backend_autograd,
+        decompositions=settings_aot_autograd["decompositions"],
+    )(gm, sample_inputs)
 
 
 def _pretraced_backend(
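For context, a minimal sketch of how this new path gets exercised. It assumes the `torch_tensorrt` dynamo backend name and that `use_aot_joint_export` is forwarded through the backend `options` into `CompilationSettings` (as `parse_dynamo_kwargs` does for other settings); the model itself is illustrative:

```python
import torch

class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2.0

model = TinyModel().eval().cuda()

# With use_aot_joint_export=False, aot_torch_tensorrt_aten_backend wraps
# _pretraced_backend in aot_autograd (as above) instead of letting
# _pretraced_backend call aot_export_joint_simple itself.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"use_aot_joint_export": False},
)
out = compiled(torch.randn(8, 8, device="cuda"))
```

Binding `settings` and `engine_cache` with `functools.partial` is what lets `_pretraced_backend` keep the plain `(gm, sample_inputs)` signature that `aot_autograd` expects of a `fw_compiler`.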
@@ -89,22 +110,39 @@ def _pretraced_backend( |
             # Remove detach nodes
             remove_detach(gm, settings)
 
+            # Pack each complex64 input into a real-valued tensor whose trailing
+            # dimension of size 2 holds the (real, imaginary) parts
+            complex_input_indices = []
+            for i, torch_input in enumerate(torch_inputs):
+                if torch_input.dtype == torch.complex64:
+                    complex_input_indices.append(i)
+                    torch_inputs[i] = torch.stack(
+                        (torch_input.real, torch_input.imag), dim=-1
+                    )
+
             # Invoke AOTAutograd to translate operators to aten
-            gm = aot_export_joint_simple(
-                gm,
-                sample_inputs,
-                trace_joint=False,
-                decompositions=get_decompositions(
-                    settings.enable_experimental_decompositions
-                ),
-            )
+            if settings.use_aot_joint_export:
+                gm = aot_export_joint_simple(
+                    gm,
+                    sample_inputs,
+                    trace_joint=False,
+                    decompositions=get_decompositions(
+                        settings.enable_experimental_decompositions
+                    ),
+                )
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
             gm = post_lowering(gm, settings)
 
             logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
+            if complex_input_indices:
+                modify_reshape_complex_nodes(gm, complex_input_indices)
+                logger.debug(
+                    "Input graph after modifying complex nodes:\n " + str(gm.graph)
+                )
+
             torchtrt_inputs = prepare_inputs(
                 torch_inputs, disable_memory_format_check=True
             )
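The complex-input packing above is equivalent to `torch.view_as_real`: a `complex64` tensor becomes a `float32` tensor with a trailing dimension of 2, which is why the lowered graph is then adjusted for those input indices via `modify_reshape_complex_nodes`. A quick self-contained check of the packing (plain PyTorch, independent of TensorRT):

```python
import torch

x = torch.randn(2, 3, dtype=torch.complex64)
packed = torch.stack((x.real, x.imag), dim=-1)

assert packed.dtype == torch.float32
assert packed.shape == (2, 3, 2)
# The stack matches the built-in real view, and it round-trips:
assert torch.equal(packed, torch.view_as_real(x))
assert torch.equal(torch.view_as_complex(packed), x)
```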