@@ -118,6 +118,7 @@ def __init__(
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
         return_tuple: bool = False,
         skip_fusion: bool = False,
+        cpu_memory_budget: int = -1,
     ):
         """
         Preprocesses graph before splitting:
@@ -137,6 +138,7 @@ def __init__(
             skip_fusion=skip_fusion,
         )
         self.operator_support = operator_support
+        self.cpu_memory_budget = cpu_memory_budget
 
         # Get all accelerated nodes based on operator support conditions
         self.acc_nodes = FxNetAccNodesFinder(
@@ -231,19 +233,15 @@ def partition_graph(self) -> torch.fx.GraphModule:
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
         subgraphs = self.break_subgraphs(
-            subgraphs, size_budget=self.calculate_size_budget()
+            subgraphs, subgraph_size_budget=self.calculate_size_budget()
         )
 
         # Set the number of TRT engines to be generated
         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
 
         # Tag the accelerated nodes and split the graph accordingly
-        print([len(s.nodes) for s in subgraphs])
         self.tag(subgraphs)
 
-        for s in subgraphs:
-            print(s.nodes)
-
         gm = self.split()
 
         return gm
@@ -255,8 +253,11 @@ def calculate_size_budget(
         This function calculates the size budget based on the available RSS. We assume that TRT compilation
         needs at most 4x the memory of the model.
         """
-
-        available_rss: int = psutil.virtual_memory().available
+        if self.cpu_memory_budget == -1:
+            available_rss: int = psutil.virtual_memory().available
+        else:
+            used_rss: int = psutil.virtual_memory().used
+            available_rss = self.cpu_memory_budget - used_rss
         return available_rss // engine_compilation_memory_usage_multiplier
 
     def break_subgraphs_by_node(
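
A worked sketch of the budget arithmetic above may help when choosing a value for the new `cpu_memory_budget`. This is a minimal standalone rewrite, not the module's code: it assumes `engine_compilation_memory_usage_multiplier` is 4 (matching the "at most 4x" note in the docstring), and `sketch_size_budget` is a hypothetical helper named only for illustration.

```python
# Sketch of the calculate_size_budget() logic in the hunk above.
# Assumption: engine_compilation_memory_usage_multiplier == 4.
import psutil

ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER = 4  # assumed value


def sketch_size_budget(cpu_memory_budget: int = -1) -> int:
    if cpu_memory_budget == -1:
        # Default path: budget is derived from whatever RSS is currently free.
        available = psutil.virtual_memory().available
    else:
        # New path: the user-supplied cap minus what the system already uses
        # (psutil.virtual_memory().used is system-wide, not per-process).
        available = cpu_memory_budget - psutil.virtual_memory().used
    return available // ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER


# e.g. a 32 GiB cap with 8 GiB already in use leaves (32 - 8) / 4 = 6 GiB of
# subgraph weights that can be sent to TRT compilation at a time.
```
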
@@ -303,24 +304,25 @@ def break_subgraphs_by_node(
         return new_subgraphs
 
     def break_subgraphs(
-        self, subgraphs: List[Subgraph], size_budget: int
+        self, subgraphs: List[Subgraph], subgraph_size_budget: int
     ) -> List[Subgraph]:
         """
         This function breaks the subgraphs into smaller subgraphs to save CPU memory.
         """
         new_subgraphs = []
         # We throw an error if the remaining memory is almost empty compared to the model size,
         # i.e. if the remaining memory is 4G (budget is 1G) and the model size is greater than 40G, we stop the compilation.
-        sizes = [(subgraph, self.size_of_subgraph(subgraph)) for subgraph in subgraphs]
-        if sum([size for _, size in sizes]) > size_budget * 40:
+        sizes = self.size_of_subgraphs(subgraphs)
+        if sum(sizes) > subgraph_size_budget * 40:
             raise ValueError(
-                f"Subgraph size {sum([size for _, size in sizes])} is too large to break. Size budget: {size_budget}"
+                f"CPU memory budget or available memory is too small to compile the model. CPU memory budget: {self.cpu_memory_budget // (1024 * 1024) if self.cpu_memory_budget != -1 else 'All available memory'} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
+                + "Consider setting cpu_memory_budget to a larger value or disabling offload_module_to_cpu to save more CPU memory."
             )
-        for subgraph, size in sizes:
+        for subgraph, size in zip(subgraphs, sizes):
 
-            while size > size_budget:
+            while size > subgraph_size_budget:
                 broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size(
-                    subgraph, size_budget
+                    subgraph, subgraph_size_budget
                 )
                 size = size_1
                 new_subgraphs.append(broken_subgraphs[0])
@@ -351,10 +353,11 @@ def break_subgraph_by_size(
         ]
 
         while True:
-            new_subgraphs = self.step_and_validate(new_subgraphs)
-            size_0, size_1 = self.size_of_subgraph(
-                new_subgraphs[0]
-            ), self.size_of_subgraph(new_subgraphs[1])
+            step_size = (
+                1 if not new_subgraphs[0].nodes else max(1, len(all_nodes) // 50)
+            )  # Set a step size proportional to the size of the subgraph to make the algorithm more efficient
+            new_subgraphs = self.step_and_validate(new_subgraphs, step_size)
+            size_0, size_1 = self.size_of_subgraphs(new_subgraphs)
             if size_0 > size_to_break:
                 break
 
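
A quick worked number for the step size introduced above: with `all_nodes` holding 1,000 nodes, `max(1, len(all_nodes) // 50)` evaluates to 20, so each `step_and_validate` call advances the split by 20 nodes at a time rather than 1, reducing the number of bisection iterations by roughly a factor of twenty.
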
@@ -451,31 +454,34 @@ def get_leaf_node(
                 break
         return leaf_node
 
-    def size_of_subgraph(self, subgraph: Subgraph) -> int:
+    def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]:
         """
         This function calculates the size of the subgraph.
         """
-        nodes_in_subgraph = set(subgraph.nodes)
+        state_dict = self.module.state_dict(keep_vars=True)
+        sizes = []
         weight_visited_nodes = set()
-        stack = subgraph.nodes.copy()
-        size = 0
-        while stack:
-            node = stack.pop()
-            if node in weight_visited_nodes:
-                continue
-            if node.op == "get_attr":
-                weight = self.module.state_dict()[node.target]
-                size += weight.numel() * weight.element_size()
+        for subgraph in subgraphs:
+            nodes_in_subgraph = set(subgraph.nodes)
+            stack = subgraph.nodes.copy()
+            size = 0
+            while stack:
+                node = stack.pop()
+                if node in weight_visited_nodes:
+                    continue
                 weight_visited_nodes.add(node)
-                continue
-            if node not in nodes_in_subgraph:
-                # Trace to other subgraphs
-                continue
-            for input_node in node._input_nodes:
-                if input_node not in weight_visited_nodes:
-                    stack.append(input_node)
-
-        return size
+                if node.op == "get_attr":
+                    weight = state_dict[node.target]
+                    size += weight.numel() * weight.element_size()
+                    continue
+                if node not in nodes_in_subgraph:
+                    # Trace to other subgraphs
+                    continue
+                for input_node in node._input_nodes:
+                    if input_node not in weight_visited_nodes:
+                        stack.append(input_node)
+            sizes.append(size)
+        return sizes
 
     def validate_and_correct_subgraphs(
         self, subgraphs: List[Subgraph]
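
To make the `numel() * element_size()` accounting above concrete, here is a tiny self-contained example; the 4096x4096 fp16 weight is an arbitrary illustrative shape, not something taken from this diff.

```python
# Each get_attr node contributes numel() * element_size() bytes, counted once
# per weight thanks to weight_visited_nodes. For a single fp16 weight:
import torch

weight = torch.empty(4096, 4096, dtype=torch.float16)
size_bytes = weight.numel() * weight.element_size()  # 4096 * 4096 * 2 bytes
print(size_bytes // (1024 * 1024))  # 32 (MiB)
```

Hoisting `self.module.state_dict(keep_vars=True)` out of the traversal loop, as the new code does, also avoids rebuilding the state dict for every `get_attr` node visited.
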
@@ -541,6 +547,7 @@ def partition(
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     skip_fusion: bool = False,
+    cpu_memory_budget: int = -1,
 ) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -567,6 +574,7 @@ def partition(
         min_block_size=min_block_size,
         require_full_compilation=require_full_compilation,
         skip_fusion=skip_fusion,
+        cpu_memory_budget=cpu_memory_budget,
     )
 
     partitioned_graph = partitioner.partition_graph()
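
Finally, a hedged usage sketch of the new keyword end to end. The 8 GiB figure is a placeholder, and `gm` stands in for an aten-level `torch.fx.GraphModule` produced by the earlier lowering steps; `partition` here refers to the module-level function this diff extends.

```python
# Hypothetical call site: cpu_memory_budget is a byte count; -1 (the default)
# keeps the previous behaviour of budgeting against all available RSS.
partitioned_gm, op_support = partition(
    gm,  # aten-level torch.fx.GraphModule from the earlier lowering steps
    cpu_memory_budget=8 * 1024**3,  # let TRT compilation use at most ~8 GiB of CPU RAM
)
```
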