from typing import Any

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx import TRTInterpreter
from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

def compile(
    module: Any,
    ir="default",
    inputs=[],
    enabled_precisions={torch.float},
    **kwargs,
):
    """Compile a PyTorch module through FX

    Takes an existing PyTorch module and a set of settings to configure the
    compiler, then lowers and compiles the module to TensorRT using the path
    specified in ``ir``, returning a PyTorch Module.

    Specifically converts the forward method of the Module.

    Arguments:
        module (torch.nn.Module): Source module

    Keyword Arguments:
        inputs (List[torch.Tensor]): Example inputs for the fixed-shape scenario; input shapes cannot change at runtime
        enabled_precisions (Set[torch.dtype]): The datatypes TensorRT may use when selecting kernels. If torch.float is chosen, kernels run in fp32; if torch.float16 is chosen, kernels run in fp16 or fp32, whichever TensorRT selects
        ir (str): The requested compilation strategy (``ts`` is the TorchScript scripting path, ``fx`` is the FX-based path)
        **kwargs: Additional settings for the specific requested strategy (see submodules for more info)

    Returns:
        torch.nn.Module: Compiled Module; when run it will execute via TensorRT
    """
    # Trace the module into an acc-op graph that fx2trt can convert.
    acc_model = acc_tracer.trace(module, inputs)

    # Split the traced graph into TRT-supported and unsupported submodules.
    splitter_setting = TRTSplitterSetting()
    splitter_setting.use_implicit_batch_dim = False
    splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
    splitter.node_support_preview()
    split_mod = splitter()
    num_piece = 0
    for name, _ in split_mod.named_children():
        print(f"graph is split into submodule: {name}")
        num_piece += 1

    # If the graph module is split into more than 8 pieces, we assume its
    # performance will be poor and fall back to the non-TRT module.
    if num_piece > 8:
        print(
            f"The graph module is split into {num_piece} pieces, which is larger "
            "than the threshold of 8. Falling back to the non-TRT module."
        )
        return None

    # Pick the TRT lowering precision from the requested dtypes
    # (torch.half is an alias of torch.float16).
    if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
        precision = LowerPrecision.FP16
    else:
        precision = LowerPrecision.FP32

    def get_submod_inputs(mod, submod, inputs):
        """Run the parent module once and capture the tensors that reach
        ``submod`` via a forward pre-hook."""
        acc_inputs = None

        def get_input(module, inputs):
            nonlocal acc_inputs
            acc_inputs = inputs

        handle = submod.register_forward_pre_hook(get_input)
        mod(*inputs)
        handle.remove()
        return acc_inputs

    # Replace each TRT-supported submodule with a compiled TRTModule.
    for name, _ in split_mod.named_children():
        if "_run_on_acc" in name:
            submod = getattr(split_mod, name)
            # Get submodule inputs for fx2trt
            acc_inputs = get_submod_inputs(split_mod, submod, inputs)

            # fx2trt replacement
            interp = TRTInterpreter(
                submod,
                InputTensorSpec.from_tensors(acc_inputs),
                explicit_batch_dimension=True,
            )
            r = interp.run(
                max_workspace_size=20 << 30,
                lower_precision=precision,
                # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # for profiling
            )
            # For profiling:
            # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
            # profile_trt_module("", trt_mod, acc_inputs)
            trt_mod = TRTModule(*r)

            setattr(split_mod, name, trt_mod)
        else:
            # Unsupported submodules are left to run in eager PyTorch.
            submod = getattr(split_mod, name)
    return split_mod
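
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way this FX `compile` entry point
# might be invoked. `MiniModel` and the shapes below are hypothetical
# placeholders, not part of the API above; a CUDA device and a working
# TensorRT installation are assumed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class MiniModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            return torch.relu(self.linear(x))

    model = MiniModel().cuda().eval()
    sample_inputs = [torch.randn(2, 3, device="cuda")]

    # Request fp16; TensorRT may still select fp32 kernels per layer.
    trt_model = compile(model, inputs=sample_inputs, enabled_precisions={torch.float16})

    if trt_model is not None:
        print(trt_model(*sample_inputs))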