|
10 | 10 | from torch._dynamo.backends.common import aot_autograd |
11 | 11 | from torch._dynamo.utils import detect_fake_mode |
12 | 12 | from torch._functorch.aot_autograd import aot_export_joint_simple |
| 13 | +from torch._ops import OpOverload |
13 | 14 | from torch_tensorrt.dynamo import CompilationSettings |
14 | 15 | from torch_tensorrt.dynamo._compiler import compile_module |
15 | 16 | from torch_tensorrt.dynamo.lowering import ( |
@@ -59,17 +60,17 @@ def aot_torch_tensorrt_aten_backend( |
59 | 60 | _pretraced_backend, settings=settings, engine_cache=engine_cache |
60 | 61 | ) |
61 | 62 | settings_aot_autograd = {} |
62 | | - settings_aot_autograd["decompostions"] = get_decompositions( |
| 63 | + settings_aot_autograd["decompositions"] = get_decompositions( |
63 | 64 | settings.enable_experimental_decompositions |
64 | 65 | ) |
65 | | - # This is added since detach lowering leads to alias nodes |
66 | | - # Error - View operation returned a tensor that is the same as the input base tensor |
67 | | - # torch nop_decompositions in torch/_decomp/decompositions.py |
68 | | - if aten.detach in settings_aot_autograd["decompositions"]: |
69 | | - del settings_aot_autograd["decompositions"][aten.detach] |
| 66 | + # transpose key deleted since not desirable to lower it to permute |
| 67 | + for key in settings_aot_autograd["decompositions"]: |
| 68 | + if "transpose" in key._name: |
| 69 | + to_delete = key |
| 70 | + del settings_aot_autograd["decompositions"][to_delete] |
70 | 71 | return aot_autograd( |
71 | 72 | fw_compiler=_pretraced_backend_autograd, |
72 | | - decompositions=get_decompositions(settings.enable_experimental_decompositions), |
| 73 | + decompositions=settings_aot_autograd["decompositions"], |
73 | 74 | )(gm, sample_inputs) |
74 | 75 |
|
75 | 76 |
|
|
0 commit comments