|
8 | 8 | from torch_tensorrt import EngineCapability, Device |
9 | 9 | from torch_tensorrt.fx.utils import LowerPrecision |
10 | 10 |
|
11 | | -from torch_tensorrt.dynamo.backend._settings import CompilationSettings |
12 | 11 | from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device |
13 | 12 | from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend |
14 | 13 | from torch_tensorrt.dynamo.backend._defaults import ( |
@@ -62,6 +61,10 @@ def compile( |
62 | 61 |
|
63 | 62 | inputs = prepare_inputs(inputs, prepare_device(device)) |
64 | 63 |
|
| 64 | + if not isinstance(enabled_precisions, collections.abc.Collection): |
| 65 | + enabled_precisions = [enabled_precisions] |
| 66 | + |
| 67 | + # Parse user-specified enabled precisions |
65 | 68 | if ( |
66 | 69 | torch.float16 in enabled_precisions |
67 | 70 | or torch_tensorrt.dtype.half in enabled_precisions |
@@ -123,19 +126,12 @@ def create_backend( |
123 | 126 | Returns: |
124 | 127 | Backend for torch.compile |
125 | 128 | """ |
126 | | - if debug: |
127 | | - logger.setLevel(logging.DEBUG) |
128 | | - |
129 | | - settings = CompilationSettings( |
| 129 | + return partial( |
| 130 | + torch_tensorrt_backend, |
130 | 131 | debug=debug, |
131 | 132 | precision=precision, |
132 | 133 | workspace_size=workspace_size, |
133 | 134 | min_block_size=min_block_size, |
134 | 135 | torch_executed_ops=torch_executed_ops, |
135 | 136 | pass_through_build_failures=pass_through_build_failures, |
136 | 137 | ) |
137 | | - |
138 | | - return partial( |
139 | | - torch_tensorrt_backend, |
140 | | - settings=settings, |
141 | | - ) |
0 commit comments