2222
2323import os
2424import re
25+ import warnings
2526from enum import Enum
2627from graphviz import Digraph
2728from typing import Callable , NamedTuple , List , Dict
@@ -434,6 +435,17 @@ def add_io(tensors: List):
434435 return label
435436
436437
438+ def layer_node_highlighter (node_id : str , highlighted_layers_ids : List [int ]
439+ ) -> Dict :
440+ """Highlight a layer node.
441+
442+ Create a yellow hailo around the node.
443+ """
444+ should_highlight = highlighted_layers_ids and node_id in highlighted_layers_ids
445+ formatting = {'penwidth' : str (6 ), 'color' : 'yellow' }
446+ return formatting if should_highlight else {}
447+
448+
437449def layer_node_configurable_renderer (
438450 layer : Layer ,
439451 latency : float ,
@@ -584,10 +596,12 @@ def handle_reformat(layer: Layer):
584596 except KeyError :
585597 layer_color = "#E5E7E9"
586598
587- formatting = {'style' : 'filled' ,
599+ formatting = {'shape' : 'Mrecord' ,
600+ 'style' : 'filled' ,
588601 'tooltip' : layer .tooltip (),
589602 'fillcolor' : layer_color ,
590- 'color' : 'white' ,}
603+ 'color' : 'lightgray' ,
604+ 'fontname' : 'Helvetica' }
591605 return formatting
592606
593607
@@ -623,11 +637,16 @@ def get_latency(plan: EnginePlan, layer: Layer, latency_type) -> float:
623637 return latency
624638
625639
640+ def get_dot_id (layer_name : str ) -> str :
641+ return layer_name .replace (":" , "###" ) # f"l_{dot_node_id}"
642+
643+
626644class DotGraph (object ):
627645 """This class converts a TensorRT plan into Graphviz DOT graphs"""
628646 def __init__ (self ,
629647 plan : EnginePlan ,
630648 layer_node_formatter : Callable ,
649+ layer_node_highlighter : Callable = layer_node_highlighter ,
631650 layer_node_renderer : Callable = layer_node_configurable_renderer ,
632651 region_formatter : Callable = region_precision_formatter ,
633652 display_layer_names : bool = True ,
@@ -642,11 +661,13 @@ def __init__(self,
642661 display_region_names : bool = False ,
643662 display_edge_name : bool = False ,
644663 display_edge_details : bool = True ,
664+ highlight_layers : list = None ,
645665 ):
646666 plan_graph = PlanGraph (
647667 plan , display_regions , display_constants , display_forking_regions )
648668 self .dot = Digraph ()
649669 self .layer_node_formatter = layer_node_formatter
670+ self .layer_node_highlighter = layer_node_highlighter
650671 self .layer_node_renderer = layer_node_renderer
651672 self .region_formatter = region_formatter
652673 self .expand_layer_details = expand_layer_details
@@ -659,6 +680,14 @@ def __init__(self,
659680 self .display_region_names = display_region_names
660681 self .display_edge_name = display_edge_name
661682 self .display_edge_details = display_edge_details
683+ # Get the node names of the layers to highlight
684+ self .highlighted_layers_ids = None
685+ if highlight_layers :
686+ try :
687+ highlight_layers_name = plan .df ['Name' ].iloc [highlight_layers ].to_list ()
688+ self .highlighted_layers_ids = [get_dot_id (name ) for name in highlight_layers_name ]
689+ except IndexError :
690+ warnings .warn ("The layers indices specified for highlighting are incorrect" )
662691
663692 node_name_2_node_id = {}
664693 self .__add_dot_region_nodes (plan_graph , node_name_2_node_id )
@@ -672,7 +701,7 @@ def __init__(self,
672701 def __add_dot_region_nodes (self , plan_graph , node_name_2_node_id ):
673702 dot_node_id = 0
674703 for mem_node in plan_graph .memory_nodes :
675- node_name_2_node_id [mem_node .name ] = dot_id = mem_node .name . replace ( ":" , "###" ) #f"r_{dot_node_id}"
704+ node_name_2_node_id [mem_node .name ] = dot_id = get_dot_id ( mem_node .name )
676705 self .__create_dot_region_node (dot_id , mem_node .tensor , mem_node .is_user , mem_node .region_gen )
677706 dot_node_id += 1
678707
@@ -681,7 +710,7 @@ def __add_dot_layer_nodes(self, plan, plan_graph, node_name_2_node_id):
681710 layer = layer_node .layer
682711 latency = get_latency (plan , layer , self .latency_type )
683712 if not layer .type == 'Constant' or plan_graph .include_constants :
684- dot_id = layer .name . replace ( ":" , "###" ) # f"l_{dot_node_id}"
713+ dot_id = get_dot_id ( layer .name )
685714 node_name_2_node_id [layer .name ] = dot_id
686715 self .__create_dot_layer_node (
687716 dot_id , layer , latency , layer_node_renderer = self .layer_node_renderer )
@@ -713,6 +742,7 @@ def __create_dot_layer_node(
713742 self , node_id : int , layer : Layer , latency : float , layer_node_renderer : Callable
714743 ):
715744 formatting = self .layer_node_formatter (layer )
745+ formatting .update (self .layer_node_highlighter (node_id , self .highlighted_layers_ids ))
716746 self .dot .node (
717747 str (node_id ),
718748 layer_node_renderer (
@@ -721,8 +751,7 @@ def __create_dot_layer_node(
721751 expand_layer_details = self .expand_layer_details ,
722752 display_layer_names = self .display_layer_names ,
723753 stack_layer_names = self .stack_layer_names ),
724- shape = 'Mrecord' ,
725- fontname = "Helvetica" , ** formatting )
754+ ** formatting )
726755
727756 def __create_dot_edge (self , src , end , tensor , region_gen ):
728757 def generation_color (gen : int , line_color : str ) -> str :
0 commit comments