
Commit 3932533

add capture and replay feature in tensorrt (#3849)
1 parent fc69b16 commit 3932533

File tree

5 files changed, +143 -0 lines changed
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
Introduction
============

This toolchain captures TensorRT network creation and build parameters at runtime via a shim, then deterministically replays them to reproduce an engine build. Use it to debug or reproduce builds independent of the originating framework.

Prerequisites
-------------

- TensorRT installed (ensure you know the absolute path to its ``lib`` and ``bin`` directories)
- ``libtensorrt_shim.so`` available in your TensorRT ``lib`` directory
- ``tensorrt_player`` available in your TensorRT ``bin`` directory
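A quick way to verify these prerequisites from Python (a minimal sketch; ``trt_lib_dir`` and ``trt_bin_dir`` are placeholders for your installation paths, not fixed locations):

.. code-block:: python

   import os

   # Placeholder paths; point these at your TensorRT installation.
   trt_lib_dir = "/usr/lib/x86_64-linux-gnu"
   trt_bin_dir = "/usr/src/tensorrt/bin"

   for required in (
       os.path.join(trt_lib_dir, "libtensorrt_shim.so"),
       os.path.join(trt_bin_dir, "tensorrt_player"),
   ):
       print(required, "found" if os.path.isfile(required) else "MISSING")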
Quick start: Capture
--------------------

.. code-block:: bash

   TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1 python test.py

You should see ``shim.json`` and ``shim.bin`` generated in ``/tmp/torch_tensorrt_{current_user}/shim``.
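To confirm the capture ran, the output location can be derived the same way the shim setup derives it (a user-specific temporary directory). A minimal sketch that mirrors that logic and lists the capture files:

.. code-block:: python

   import os
   import pwd
   import tempfile

   # Mirrors the capture path used by the shim setup: /tmp/torch_tensorrt_{current_user}/shim
   user = pwd.getpwuid(os.getuid())[0]
   shim_dir = os.path.join(tempfile.gettempdir(), f"torch_tensorrt_{user}", "shim")

   for name in ("shim.json", "shim.bin"):
       path = os.path.join(shim_dir, name)
       size = os.path.getsize(path) if os.path.isfile(path) else None
       print(path, f"{size} bytes" if size is not None else "missing")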
Replay: Build the engine from the capture
-----------------------------------------

Use ``tensorrt_player`` to replay the captured build without the original framework:

.. code-block:: bash

   tensorrt_player -j /absolute/path/to/shim.json -o /absolute/path/to/output_engine

This produces a serialized TensorRT engine at ``output_engine``.

Validate the engine
-------------------

Run the engine with ``trtexec``:

.. code-block:: bash

   trtexec --loadEngine=/absolute/path/to/output_engine
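Alternatively, the replayed engine can be deserialized directly with the TensorRT Python API. A minimal sketch, assuming a recent TensorRT release that exposes ``num_io_tensors`` / ``get_tensor_name``; the engine path is a placeholder from the replay step:

.. code-block:: python

   import tensorrt as trt

   ENGINE_PATH = "/absolute/path/to/output_engine"  # placeholder path from the replay step

   logger = trt.Logger(trt.Logger.WARNING)
   runtime = trt.Runtime(logger)
   with open(ENGINE_PATH, "rb") as f:
       engine = runtime.deserialize_cuda_engine(f.read())

   assert engine is not None, "engine failed to deserialize"
   print("I/O tensors:", [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)])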
Notes
-----

- Ensure the ``libnvinfer.so`` used by the shim matches the TensorRT version in your environment.
- If multiple TensorRT versions are installed, prefer absolute paths as shown above.
- Capturing multiple engines is not currently supported; if a graph break produces more than one engine, only the first engine is captured.
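One way to sanity-check the first note is to compare the installed Python bindings against the ``libnvinfer`` library the loader resolves (a sketch; it reports what the loader finds, not what the shim was built against):

.. code-block:: python

   import ctypes.util

   import tensorrt as trt

   print("tensorrt Python package version:", trt.__version__)
   print("libnvinfer resolved by the loader:", ctypes.util.find_library("nvinfer"))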

docsrc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ Getting Started
 getting_started/jetpack
 getting_started/quick_start
 getting_started/tensorrt_rtx
+getting_started/capture_and_replay

 User Guide
 ------------

py/torch_tensorrt/_TensorRTProxyModule.py

Lines changed: 69 additions & 0 deletions
@@ -1,12 +1,16 @@
 import ctypes
 import importlib
 import importlib.util
+import logging
 import os
 import platform
+import pwd
 import sys
+import tempfile
 from types import ModuleType
 from typing import Any, Dict, List

+_LOGGER = logging.getLogger(__name__)
 package_imported = False
 package_name = ""

@@ -28,6 +32,66 @@ def _find_lib(name: str, paths: List[str]) -> str:
     raise FileNotFoundError(f"Could not find {name}\n Search paths: {paths}")


+def enable_capture_tensorrt_api_recording() -> None:
+    os_env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
+    if os_env_flag is None or (os_env_flag != "1" and os_env_flag.lower() != "true"):
+        _LOGGER.debug("Capturing TensorRT API calls is not enabled")
+        return
+    if not sys.platform.startswith("linux"):
+        _LOGGER.warning(
+            f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}"
+        )
+        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE")
+        return
+
+    # Candidate directories to search for libtensorrt_shim.so
+    linux_lib_path = []
+    if "LD_LIBRARY_PATH" in os.environ:
+        linux_lib_path.extend(os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep))
+
+    if platform.uname().processor == "x86_64":
+        linux_lib_path.append("/usr/lib/x86_64-linux-gnu")
+    elif platform.uname().processor == "aarch64":
+        linux_lib_path.append("/usr/lib/aarch64-linux-gnu")
+
+    # Load the shim globally so it can intercept subsequent TensorRT API calls
+    tensorrt_lib_path = None  # initialized so the check below cannot raise NameError
+    for path in linux_lib_path:
+        if os.path.isfile(os.path.join(path, "libtensorrt_shim.so")):
+            try:
+                ctypes.CDLL(
+                    os.path.join(path, "libtensorrt_shim.so"), mode=ctypes.RTLD_GLOBAL
+                )
+                tensorrt_lib_path = path
+                break
+            except Exception:
+                continue
+
+    if tensorrt_lib_path is None:
+        _LOGGER.warning(
+            "Capturing TensorRT API calls is enabled, but libtensorrt_shim.so was not found; make sure the TensorRT lib directory is on LD_LIBRARY_PATH. Ignoring the capture_tensorrt_api_recording setting"
+        )
+        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE")
+    else:
+        os.environ["TRT_SHIM_NVINFER_LIB_NAME"] = os.path.join(
+            tensorrt_lib_path, "libnvinfer.so"
+        )
+        current_user = pwd.getpwuid(os.getuid())[0]
+        shim_temp_dir = os.path.join(
+            tempfile.gettempdir(), f"torch_tensorrt_{current_user}/shim"
+        )
+        os.makedirs(shim_temp_dir, exist_ok=True)
+        json_file_name = os.path.join(shim_temp_dir, "shim.json")
+        os.environ["TRT_SHIM_OUTPUT_JSON_FILE"] = json_file_name
+        bin_file_name = os.path.join(shim_temp_dir, "shim.bin")
+        # If previous capture files exist, delete them so a fresh capture is written
+        if os.path.exists(json_file_name):
+            os.remove(json_file_name)
+        if os.path.exists(bin_file_name):
+            os.remove(bin_file_name)
+        _LOGGER.info(
+            f"Capturing TensorRT API calls feature is enabled and the captured output is in the {shim_temp_dir} directory"
+        )
+
+
 # TensorRTProxyModule is a proxy module that allows us to register the tensorrt or tensorrt-rtx package
 # since tensorrt-rtx is the drop-in replacement for tensorrt, we can use the same interface to use tensorrt-rtx
 class TensorRTProxyModule(ModuleType):
@@ -86,6 +150,11 @@ def alias_tensorrt() -> None:
         if use_rtx_env_var.lower() == "true":
             use_rtx = True
     package_name = "tensorrt_rtx" if use_rtx else "tensorrt"
+
+    if not use_rtx:
+        # Enabling capture of the TensorRT API recording has to be done before importing the tensorrt library
+        enable_capture_tensorrt_api_recording()
+
     # Import the appropriate package
     try:
         target_module = importlib.import_module(package_name)
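Because the flag is read inside ``alias_tensorrt()``, i.e. while the ``tensorrt`` package is being aliased at import time, setting the environment variable before the first ``torch_tensorrt`` import should be equivalent to exporting it in the shell. A minimal sketch of that usage (the ``torch.nn.Linear`` model is only illustrative):

.. code-block:: python

   import os

   # Must be set before torch_tensorrt (and therefore tensorrt) is imported,
   # because the flag is read while the tensorrt package is being aliased.
   os.environ["TORCHTRT_ENABLE_TENSORRT_API_CAPTURE"] = "1"

   import torch
   import torch_tensorrt

   model = torch.nn.Linear(16, 4).eval().cuda()
   inputs = [torch.randn(8, 16).cuda()]

   # Compiling through the dynamo path builds a TensorRT engine, which the shim captures.
   trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)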

py/torch_tensorrt/dynamo/debug/_Debugger.py

Lines changed: 23 additions & 0 deletions
@@ -2,6 +2,7 @@
 import functools
 import logging
 import os
+import sys
 import tempfile
 from logging.config import dictConfig
 from typing import Any, List, Optional
@@ -32,6 +33,7 @@ def __init__(
         capture_fx_graph_before: Optional[List[str]] = None,
         capture_fx_graph_after: Optional[List[str]] = None,
         save_engine_profile: bool = False,
+        capture_tensorrt_api_recording: bool = False,
         profile_format: str = "perfetto",
         engine_builder_monitor: bool = True,
         logging_dir: str = DEBUG_LOGGING_DIR,
@@ -49,6 +51,9 @@ def __init__(
                 after execution of a lowering pass. Defaults to None.
             save_engine_profile (bool): Whether to save TensorRT engine profiling information.
                 Defaults to False.
+            capture_tensorrt_api_recording (bool): Whether to enable capturing of TensorRT API calls. When enabled, the captured recording is written to the /tmp/torch_tensorrt_{current_user}/shim directory.
+                It is part of the TensorRT capture-and-replay feature; the captured output can be replayed for debugging purposes.
+                Defaults to False.
             profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'.
                 If you need to generate engine graph using the profiling files, set it to 'trex' and use the C++ runtime.
                 If you need to generate cudagraph visualization, set it to 'cudagraph'.
@@ -65,6 +70,7 @@ def __init__(
         self.cfg = DebuggerConfig(
             log_level=log_level,
             save_engine_profile=save_engine_profile,
+            capture_tensorrt_api_recording=capture_tensorrt_api_recording,
             engine_builder_monitor=engine_builder_monitor,
             logging_dir=logging_dir,
             profile_format=profile_format,
@@ -92,6 +98,23 @@ def __init__(
         self.capture_fx_graph_before = capture_fx_graph_before
         self.capture_fx_graph_after = capture_fx_graph_after

+        if self.cfg.capture_tensorrt_api_recording:
+            if not sys.platform.startswith("linux"):
+                _LOGGER.warning(
+                    f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}"
+                )
+            elif ENABLED_FEATURES.tensorrt_rtx:
+                _LOGGER.warning(
+                    "Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring the capture_tensorrt_api_recording setting"
+                )
+            else:
+                env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
+                if env_flag is None or (env_flag != "1" and env_flag.lower() != "true"):
+                    _LOGGER.warning(
+                        "In order to capture TensorRT API calls, please invoke the script with environment variable TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1"
+                    )
+                _LOGGER.info("Capturing TensorRT API calls feature is enabled")
+
     def __enter__(self) -> None:
         self.original_lvl = _LOGGER.getEffectiveLevel()
         if ENABLED_FEATURES.torch_tensorrt_runtime:
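For reference, a minimal sketch of how this flag might be used together with the environment variable, assuming the context manager is exposed as ``torch_tensorrt.dynamo.Debugger`` (an assumption about the public path; the model is only illustrative):

.. code-block:: python

   import os

   os.environ["TORCHTRT_ENABLE_TENSORRT_API_CAPTURE"] = "1"  # required, per the warning above

   import torch
   import torch_tensorrt

   model = torch.nn.Linear(16, 4).eval().cuda()
   inputs = [torch.randn(8, 16).cuda()]

   # Assumes the debugger context manager is exposed as torch_tensorrt.dynamo.Debugger.
   with torch_tensorrt.dynamo.Debugger(capture_tensorrt_api_recording=True):
       trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)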

py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 class DebuggerConfig:
     log_level: str = "debug"
     save_engine_profile: bool = False
+    capture_tensorrt_api_recording: bool = False
     engine_builder_monitor: bool = True
     logging_dir: str = DEBUG_LOGGING_DIR
     profile_format: str = "perfetto"
