1- from ast import Assert
21import logging
32from typing import Collection , Dict , List , Optional , Tuple
43
2625)
2726
2827logger = logging .getLogger (__name__ )
28+ NON_BREAKABLE_OP_LISTS = [
29+ ["addmm" , "addmm" ],
30+ ["conv2d" , "batch_norm2d" , "relu" ],
31+ ]
2932
3033
3134class OpSupportTester (ops .OperatorSupportBase ): # type: ignore
@@ -227,8 +230,9 @@ def partition_graph(self) -> torch.fx.GraphModule:
227230 # Remove segments smaller than the block size (with exceptions)
228231 subgraphs = self .remove_small_acc_subgraphs (subgraphs )
229232
230- # num_of_break = self.calculate_num_of_break(subgraphs)
231- subgraphs = self .break_subgraphs_by_node (subgraphs , num_of_break = 5 )
233+ subgraphs = self .break_subgraphs (
234+ subgraphs , size_budget = self .calculate_size_budget ()
235+ )
232236
233237 # Set the number of TRT engines to be generated
234238 self .num_trt_accelerated_subgraphs = len ([s for s in subgraphs if s .is_acc ])
@@ -241,44 +245,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
241245 print (s .nodes )
242246
243247 gm = self .split ()
244- self .weight_visited_nodes = set ()
245- [self .size_of_subgraph (s ) for s in subgraphs ]
246-
247248
248249 return gm
249-
250- def calculate_num_of_break (self , subgraphs : List [Subgraph ]) -> int :
250+
251+ def calculate_size_budget (
252+ self , engine_compilation_memory_usage_multiplier : int = 4
253+ ) -> int :
251254 """
252- This function calculates the break period based on the number of subgraphs.
255+ This function calculates the size budget based on the available system memory. We assume that TRT compilation
256+ needs at most engine_compilation_memory_usage_multiplier times (4x by default) the memory of the model.
253257 """
254- rss = psutil .Process ().memory_info ().rss
255- available_rss = psutil .virtual_memory ().available
256- num_of_graphs = len (subgraphs )
257- if rss < available_rss * 0.3 :
258- num_of_graphs = 1
259- elif rss < available_rss * 0.5 :
260- num_of_graphs = 2
261- elif rss < available_rss :
262- num_of_graphs = 4
263- elif rss < available_rss * 1.5 :
264- num_of_graphs = 8
265- elif rss < available_rss * 2 :
266- num_of_graphs = 16
267- else :
268- num_of_graphs = 32
269-
270- return max (
271- 1 , num_of_graphs // ((len (subgraphs ) + 1 ) // 2 )
272- ) # If there are already graph breaks, for each TRT subgraph, we break for a few times.
273258
259+ available_rss : int = psutil .virtual_memory ().available
260+ return available_rss // engine_compilation_memory_usage_multiplier
274261
275262 def break_subgraphs_by_node (
276263 self , subgraphs : List [Subgraph ], num_of_break : int = 1
277264 ) -> List [Subgraph ]:
278265 """
279266 This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
280267 """
281- op_to_break = "add ."
268+ op_to_break = "addmm ."
282269 num_of_sdpa_node = len (
283270 [node for node in self .acc_nodes if op_to_break in str (node .target )]
284271 )
@@ -312,80 +299,200 @@ def break_subgraphs_by_node(
312299 new_subgraphs .append (subgraph )
313300
314301 new_subgraphs = self .validate_and_correct_subgraphs (new_subgraphs )
315-
302+
316303 return new_subgraphs
317304
318305 def break_subgraphs (
319- self , subgraphs : List [Subgraph ], num_of_break : int = 1
306+ self , subgraphs : List [Subgraph ], size_budget : int
320307 ) -> List [Subgraph ]:
321308 """
322- This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
309+ This function breaks the subgraphs into smaller subgraphs to save CPU memory.
323310 """
324- break_pos = [0 , 100 , 200 , 300 , 400 ]
325- current_break_idx = 0
326311 new_subgraphs = []
327- for subgraph in subgraphs :
328- if subgraph .is_acc :
329- for i , node in enumerate (subgraph .nodes ):
330- if i in break_pos :
312+ # We throw an error if the remaining memory is almost empty compared to the model size.
313+ # i.e. if the remaining memory is 4G (so the budget is 1G) and the model size is greater than 40G, we stop the compilation.
314+ sizes = [(subgraph , self .size_of_subgraph (subgraph )) for subgraph in subgraphs ]
315+ if sum ([size for _ , size in sizes ]) > size_budget * 40 :
316+ raise ValueError (
317+ f"Subgraph size { sum ([size for _ , size in sizes ])} is too large to break. Size budget: { size_budget } "
318+ )
319+ for subgraph , size in sizes :
331320
332- new_subgraphs .append (
333- Subgraph (
334- is_acc = True ,
335- nodes = subgraph .nodes [current_break_idx : i + 1 ],
336- device_ordinal = subgraph .device_ordinal ,
337- )
338- )
339- current_break_idx = i + 1
340- new_subgraphs .append (
341- Subgraph (
342- is_acc = True ,
343- nodes = subgraph .nodes [current_break_idx :],
344- device_ordinal = subgraph .device_ordinal ,
345- )
321+ while size > size_budget :
322+ broken_subgraphs , size_0 , size_1 = self .break_subgraph_by_size (
323+ subgraph , size_budget
346324 )
347- else :
348- new_subgraphs .append (subgraph )
325+ size = size_1
326+ new_subgraphs .append (broken_subgraphs [0 ])
327+ subgraph = broken_subgraphs [1 ]
328+ new_subgraphs .append (subgraph )
349329
350- new_subgraphs = self .validate_and_correct_subgraphs (new_subgraphs )
351330 return new_subgraphs
352331
332+ def break_subgraph_by_size (
333+ self , subgraph : Subgraph , size_to_break : int
334+ ) -> Tuple [List [Subgraph ], int , int ]:
335+ """
336+ This function splits a single subgraph into two, moving nodes from the second part into the first until the first part's size exceeds size_to_break.
337+ """
338+ all_nodes = subgraph .nodes
339+ device_ordinal = subgraph .device_ordinal
340+ new_subgraphs = [
341+ Subgraph (
342+ is_acc = True ,
343+ nodes = [],
344+ device_ordinal = device_ordinal ,
345+ ),
346+ Subgraph (
347+ is_acc = True ,
348+ nodes = all_nodes ,
349+ device_ordinal = device_ordinal ,
350+ ),
351+ ]
352+
353+ while True :
354+ new_subgraphs = self .step_and_validate (new_subgraphs )
355+ size_0 , size_1 = self .size_of_subgraph (
356+ new_subgraphs [0 ]
357+ ), self .size_of_subgraph (new_subgraphs [1 ])
358+ if size_0 > size_to_break :
359+ break
360+
361+ if len (new_subgraphs [1 ].nodes ) == 0 :
362+ new_subgraphs .pop (1 )
363+ return new_subgraphs , size_0 , size_1
364+
365+ def step_and_validate (
366+ self , new_subgraphs : List [Subgraph ], step_size : int = 1
367+ ) -> List [Subgraph ]:
368+
369+ # TODO: We can change it to binary search to find the optimal break point
370+ for _ in range (step_size ):
371+ new_subgraphs [0 ].nodes .append (new_subgraphs [1 ].nodes .pop (0 ))
372+
373+ while True :
374+ new_subgraphs = self .validate_and_correct_subgraphs (new_subgraphs )
375+ nodes_in_first_subgraph = set (new_subgraphs [0 ].nodes )
376+ leaf_node = self .get_leaf_node (nodes_in_first_subgraph )
377+ broken_fusion = self .step_if_break_fusion (
378+ new_subgraphs , leaf_node , nodes_in_first_subgraph
379+ )
380+ if not broken_fusion or len (new_subgraphs [1 ].nodes ) == 0 :
381+ break
382+
383+ return new_subgraphs
384+
385+ def step_if_break_fusion (
386+ self ,
387+ subgraphs : List [Subgraph ],
388+ leaf_nodes : set [torch .fx .Node ],
389+ nodes_in_first_subgraph : set [torch .fx .Node ],
390+ ) -> bool :
391+
392+ def add_nodes (node : torch .fx .Node ) -> None :
393+ """
394+ This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order.
395+ """
396+ if node .op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph :
397+ nodes_in_first_subgraph .add (node )
398+ for input_node in node ._input_nodes :
399+ add_nodes (input_node )
400+ subgraphs [0 ].nodes .append (node )
401+ subgraphs [1 ].nodes .remove (node )
402+
403+ def match_subgraph_and_step (node : torch .fx .Node ) -> bool :
404+ added_nodes = False
405+ for op_list in NON_BREAKABLE_OP_LISTS :
406+ for i , op in enumerate (op_list ):
407+ if i != len (op_list ) - 1 and op in str (node .target ):
408+ # Search the following ops forward using BFS. We skip searching the previous ops because
409+ # even if it's just a subset of the fusion graph, we still want it to be fused.
410+
411+ users = node .users .keys ()
412+ matching_nodes : set [torch .fx .Node ] = set ()
413+ for following_op_idx in range (i + 1 , len (op_list )):
414+ matching_nodes = set ()
415+ for user in users :
416+ if op_list [following_op_idx ] in str (user .target ):
417+ matching_nodes .add (user )
418+ if not matching_nodes :
419+ break
420+ users = set ()
421+ for matching_node in matching_nodes :
422+ for next_user in matching_node .users :
423+ users .add (next_user )
424+
425+ for matching_node in matching_nodes :
426+ added_nodes = True
427+ add_nodes (matching_node )
428+
429+ if added_nodes :
430+ # Terminate the search early if we have found a match, because preceding matches can cover the following matches
431+ break
432+
433+ return True if added_nodes else False
434+
435+ found_match = False
436+ for leaf in leaf_nodes :
437+ if match_subgraph_and_step (leaf ):
438+ found_match = True
439+
440+ return found_match
441+
442+ def get_leaf_node (
443+ self , nodes_in_first_subgraph : set [torch .fx .Node ]
444+ ) -> set [torch .fx .Node ]:
445+ leaf_node = set ()
446+
447+ for node in nodes_in_first_subgraph :
448+ for user in node .users :
449+ if user not in nodes_in_first_subgraph :
450+ leaf_node .add (node )
451+ break
452+ return leaf_node
453+
353454 def size_of_subgraph (self , subgraph : Subgraph ) -> int :
354455 """
355456 This function calculates the size of the subgraph.
356457 """
458+ nodes_in_subgraph = set (subgraph .nodes )
459+ weight_visited_nodes = set ()
357460 stack = subgraph .nodes .copy ()
358461 size = 0
359462 while stack :
360463 node = stack .pop ()
361- if node in self . weight_visited_nodes :
464+ if node in weight_visited_nodes :
362465 continue
363- self .weight_visited_nodes .add (node )
364466 if node .op == "get_attr" :
365467 weight = self .module .state_dict ()[node .target ]
366468 size += weight .numel () * weight .element_size ()
367- self .weight_visited_nodes .add (node )
469+ weight_visited_nodes .add (node )
470+ continue
471+ if node not in nodes_in_subgraph :
472+ # Trace to other subgraphs
368473 continue
369- for input_node in node ._input_nodes :
370- if input_node not in self . weight_visited_nodes :
474+ for input_node in node ._input_nodes :
475+ if input_node not in weight_visited_nodes :
371476 stack .append (input_node )
372- print ( size )
477+
373478 return size
374479
375- def validate_and_correct_subgraphs (self , subgraphs : List [Subgraph ]) -> List [Subgraph ]:
480+ def validate_and_correct_subgraphs (
481+ self , subgraphs : List [Subgraph ]
482+ ) -> List [Subgraph ]:
376483 """
377484 This function validates the subgraphs by checking if the subgraphs are valid, and corrects the subgraphs if they are not valid.
378485 """
379- visited_nodes = {}
380- print ([len (s .nodes ) for s in subgraphs ])
486+ visited_nodes = (
487+ {}
488+ ) # a map from a node to the index of the subgraph its users should belong to
381489 for i , subgraph in enumerate (subgraphs ):
382490 if i == 0 :
383491 for node in subgraph .nodes :
384492 visited_nodes [node ] = i
385493 visited_nodes [subgraph .nodes [- 1 ]] = i + 1
386494 continue
387495
388-
389496 elif not subgraph .is_acc :
390497 for node in subgraph .nodes :
391498 visited_nodes [subgraph .nodes [- 1 ]] = i + 1
@@ -401,18 +508,15 @@ def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subg
401508 for dep in self .deps [node ]:
402509 if dep in visited_nodes :
403510 subgraph_idx = max (subgraph_idx , visited_nodes [dep ])
404- else :
405- raise ValueError (f"Node { node } have a dependency that is not covered in the previous subgraphs. This is caused by a invalid subgraph segmentation." )
511+
406512 if subgraph_idx != i :
407513 subgraphs [subgraph_idx ].nodes .append (node )
408514 to_remove_nodes .append (node )
409515 visited_nodes [node ] = subgraph_idx
410516 for node in to_remove_nodes :
411517 subgraph .nodes .remove (node )
412-
518+
413519 return subgraphs
414-
415-
416520
417521 def starter_nodes (self ) -> Tuple [NodeSet , NodeSet ]:
418522 """Generates starter nodes for partitioning + segmentation"""
0 commit comments