Skip to content

Commit 4aa8b7d

Browse files
ilyasherrajeevsrao
authored andcommitted
Update trt-engine-explorer to v0.1.6
Signed-off-by: Ilya Sherstyuk <isherstyuk@nvidia.com>
1 parent 3f293e5 commit 4aa8b7d

File tree

6 files changed

+59
-9
lines changed

6 files changed

+59
-9
lines changed

tools/experimental/trt-engine-explorer/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
Dates are in YYYY-MM-DD format.
44

5+
## v0.1.6 (2023-April)
6+
- Graph rendering:
7+
- Add node highlighting option.
8+
- Fix bug https://github.com/NVIDIA/TensorRT/issues/2779
9+
510
## v0.1.5 (2022-12-06)
611
- Updated requirements.txt for Ubuntu 20.04 and 22.04
712

tools/experimental/trt-engine-explorer/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ $ git clone https://github.com/NVIDIA/TensorRT.git
4343
```
4444

4545
### 2. Create and activate a Python virtual environment
46-
The commands listed below create and activate a Python virtual enviornment named ```env_trex``` which is stored in a directory by the same name, and configures the current shell to use it as the default python environment.
46+
The commands listed below create and activate a Python virtual environment named ```env_trex``` which is stored in a directory by the same name, and configures the current shell to use it as the default python environment.
4747

4848
```
4949
$ cd TensorRT/tools/experimental/trt-engine-explorer

tools/experimental/trt-engine-explorer/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def main():
3333

3434
setup(
3535
name="trex",
36-
version="0.1.5",
36+
version="0.1.6",
3737
description="TREX: TensorRT Engine Exploration Toolkit",
3838
long_description=open("README.md", "r", encoding="utf-8").read(),
3939
author="NVIDIA",

tools/experimental/trt-engine-explorer/trex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@
3030
from trex.compare_engines import *
3131
from trex.excel_summary import *
3232

33-
__version__ = "0.1.5"
33+
__version__ = "0.1.6"

tools/experimental/trt-engine-explorer/trex/graphing.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import os
2424
import re
25+
import warnings
2526
from enum import Enum
2627
from graphviz import Digraph
2728
from typing import Callable, NamedTuple, List, Dict
@@ -434,6 +435,17 @@ def add_io(tensors: List):
434435
return label
435436

436437

438+
def layer_node_highlighter(node_id: str, highlighted_layers_ids: List[int]
439+
) -> Dict:
440+
"""Highlight a layer node.
441+
442+
Create a yellow hailo around the node.
443+
"""
444+
should_highlight = highlighted_layers_ids and node_id in highlighted_layers_ids
445+
formatting = {'penwidth': str(6), 'color': 'yellow'}
446+
return formatting if should_highlight else {}
447+
448+
437449
def layer_node_configurable_renderer(
438450
layer: Layer,
439451
latency: float,
@@ -584,10 +596,12 @@ def handle_reformat(layer: Layer):
584596
except KeyError:
585597
layer_color = "#E5E7E9"
586598

587-
formatting = {'style': 'filled',
599+
formatting = {'shape': 'Mrecord',
600+
'style': 'filled',
588601
'tooltip': layer.tooltip(),
589602
'fillcolor': layer_color,
590-
'color': 'white',}
603+
'color': 'lightgray',
604+
'fontname': 'Helvetica'}
591605
return formatting
592606

593607

@@ -623,11 +637,16 @@ def get_latency(plan: EnginePlan, layer: Layer, latency_type) -> float:
623637
return latency
624638

625639

640+
def get_dot_id(layer_name: str) -> str:
641+
return layer_name.replace(":", "###") # f"l_{dot_node_id}"
642+
643+
626644
class DotGraph(object):
627645
"""This class converts a TensorRT plan into Graphviz DOT graphs"""
628646
def __init__(self,
629647
plan: EnginePlan,
630648
layer_node_formatter: Callable,
649+
layer_node_highlighter: Callable=layer_node_highlighter,
631650
layer_node_renderer: Callable=layer_node_configurable_renderer,
632651
region_formatter: Callable=region_precision_formatter,
633652
display_layer_names: bool=True,
@@ -642,11 +661,13 @@ def __init__(self,
642661
display_region_names: bool=False,
643662
display_edge_name: bool=False,
644663
display_edge_details: bool=True,
664+
highlight_layers: list=None,
645665
):
646666
plan_graph = PlanGraph(
647667
plan, display_regions, display_constants, display_forking_regions)
648668
self.dot = Digraph()
649669
self.layer_node_formatter = layer_node_formatter
670+
self.layer_node_highlighter = layer_node_highlighter
650671
self.layer_node_renderer = layer_node_renderer
651672
self.region_formatter = region_formatter
652673
self.expand_layer_details = expand_layer_details
@@ -659,6 +680,14 @@ def __init__(self,
659680
self.display_region_names = display_region_names
660681
self.display_edge_name = display_edge_name
661682
self.display_edge_details = display_edge_details
683+
# Get the node names of the layers to highlight
684+
self.highlighted_layers_ids = None
685+
if highlight_layers:
686+
try:
687+
highlight_layers_name = plan.df['Name'].iloc[highlight_layers].to_list()
688+
self.highlighted_layers_ids = [get_dot_id(name) for name in highlight_layers_name]
689+
except IndexError:
690+
warnings.warn("The layers indices specified for highlighting are incorrect")
662691

663692
node_name_2_node_id = {}
664693
self.__add_dot_region_nodes(plan_graph, node_name_2_node_id)
@@ -672,7 +701,7 @@ def __init__(self,
672701
def __add_dot_region_nodes(self, plan_graph, node_name_2_node_id):
673702
dot_node_id = 0
674703
for mem_node in plan_graph.memory_nodes:
675-
node_name_2_node_id[mem_node.name] = dot_id = mem_node.name.replace(":", "###") #f"r_{dot_node_id}"
704+
node_name_2_node_id[mem_node.name] = dot_id = get_dot_id(mem_node.name)
676705
self.__create_dot_region_node(dot_id, mem_node.tensor, mem_node.is_user, mem_node.region_gen)
677706
dot_node_id += 1
678707

@@ -681,7 +710,7 @@ def __add_dot_layer_nodes(self, plan, plan_graph, node_name_2_node_id):
681710
layer = layer_node.layer
682711
latency = get_latency(plan, layer, self.latency_type)
683712
if not layer.type == 'Constant' or plan_graph.include_constants:
684-
dot_id = layer.name.replace(":", "###") # f"l_{dot_node_id}"
713+
dot_id = get_dot_id(layer.name)
685714
node_name_2_node_id[layer.name] = dot_id
686715
self.__create_dot_layer_node(
687716
dot_id, layer, latency, layer_node_renderer=self.layer_node_renderer)
@@ -713,6 +742,7 @@ def __create_dot_layer_node(
713742
self, node_id: int, layer: Layer, latency: float, layer_node_renderer: Callable
714743
):
715744
formatting = self.layer_node_formatter(layer)
745+
formatting.update(self.layer_node_highlighter(node_id, self.highlighted_layers_ids))
716746
self.dot.node(
717747
str(node_id),
718748
layer_node_renderer(
@@ -721,8 +751,7 @@ def __create_dot_layer_node(
721751
expand_layer_details=self.expand_layer_details,
722752
display_layer_names=self.display_layer_names,
723753
stack_layer_names=self.stack_layer_names),
724-
shape='Mrecord',
725-
fontname="Helvetica", **formatting)
754+
**formatting)
726755

727756
def __create_dot_edge(self, src, end, tensor, region_gen):
728757
def generation_color(gen: int, line_color: str) -> str:

tools/experimental/trt-engine-explorer/trex/parser.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,23 @@ def convert_deconv(raw_layers: List) -> List:
172172
pass
173173
return raw_layers
174174

175+
def fix_metadata(raw_layers: List) -> List:
176+
"""TensorRT 8.6 introduced the Metadata field, with a non-ASCII character
177+
that triggers an SVG rendering error. This function replaces this character.
178+
179+
See: https://github.com/NVIDIA/TensorRT/issues/2779
180+
"""
181+
TRT_METADATA_DELIM = '\x1E'
182+
for l in raw_layers:
183+
try:
184+
if TRT_METADATA_DELIM in l['Metadata']:
185+
l['Metadata'] = l['Metadata'].replace(TRT_METADATA_DELIM, '+')
186+
except KeyError:
187+
pass
188+
return raw_layers
189+
175190
raw_layers, bindings = read_graph_file(graph_file)
191+
raw_layers = fix_metadata(raw_layers)
176192
raw_layers = convert_deconv(raw_layers)
177193
raw_layers = disambiguate_layer_names(raw_layers)
178194
raw_layers, bindings = filter_profiles(raw_layers, bindings, profile_id)

0 commit comments

Comments
 (0)