1+ from ast import Assert
12import logging
23from typing import Collection , Dict , List , Optional , Tuple
34
@@ -226,16 +227,26 @@ def partition_graph(self) -> torch.fx.GraphModule:
226227 # Remove segments smaller than the block size (with exceptions)
227228 subgraphs = self .remove_small_acc_subgraphs (subgraphs )
228229
229- num_of_break = self .calculate_num_of_break (subgraphs )
230- subgraphs = self .break_subgraphs (subgraphs , num_of_break = num_of_break )
230+ # num_of_break = self.calculate_num_of_break(subgraphs)
231+ subgraphs = self .break_subgraphs_by_node (subgraphs , num_of_break = 5 )
231232
232233 # Set the number of TRT engines to be generated
233234 self .num_trt_accelerated_subgraphs = len ([s for s in subgraphs if s .is_acc ])
234235
235236 # Tag the accelerated nodes and split the graph accordingly
237+ print ([len (s .nodes ) for s in subgraphs ])
236238 self .tag (subgraphs )
237- return self .split ()
238239
240+ for s in subgraphs :
241+ print (s .nodes )
242+
243+ gm = self .split ()
244+ self .weight_visited_nodes = set ()
245+ [self .size_of_subgraph (s ) for s in subgraphs ]
246+
247+
248+ return gm
249+
239250 def calculate_num_of_break (self , subgraphs : List [Subgraph ]) -> int :
240251 """
241252 This function calculates the break period based on the number of subgraphs.
@@ -260,15 +271,16 @@ def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
260271 1 , num_of_graphs // ((len (subgraphs ) + 1 ) // 2 )
261272 ) # If there are already graph breaks, for each TRT subgraph, we break for a few times.
262273
263- def break_subgraphs (
274+
275+ def break_subgraphs_by_node (
264276 self , subgraphs : List [Subgraph ], num_of_break : int = 1
265277 ) -> List [Subgraph ]:
266278 """
267279 This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
268280 """
269-
281+ op_to_break = "add."
270282 num_of_sdpa_node = len (
271- [node for node in self .acc_nodes if "scaled_dot" in str (node .target )]
283+ [node for node in self .acc_nodes if op_to_break in str (node .target )]
272284 )
273285 break_period = num_of_sdpa_node // num_of_break + 1
274286 current_break_idx = 0
@@ -277,7 +289,7 @@ def break_subgraphs(
277289 for subgraph in subgraphs :
278290 if subgraph .is_acc :
279291 for i , node in enumerate (subgraph .nodes ):
280- if "scaled_dot" in str (node .target ):
292+ if op_to_break in str (node .target ):
281293 current_num_break += 1
282294 if current_num_break % break_period != 0 :
283295 continue
@@ -298,8 +310,110 @@ def break_subgraphs(
298310 )
299311 else :
300312 new_subgraphs .append (subgraph )
313+
314+ new_subgraphs = self .validate_and_correct_subgraphs (new_subgraphs )
315+
301316 return new_subgraphs
302317
318+ def break_subgraphs (
319+ self , subgraphs : List [Subgraph ], num_of_break : int = 1
320+ ) -> List [Subgraph ]:
321+ """
322+ This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
323+ """
324+ break_pos = [0 , 100 , 200 , 300 , 400 ]
325+ current_break_idx = 0
326+ new_subgraphs = []
327+ for subgraph in subgraphs :
328+ if subgraph .is_acc :
329+ for i , node in enumerate (subgraph .nodes ):
330+ if i in break_pos :
331+
332+ new_subgraphs .append (
333+ Subgraph (
334+ is_acc = True ,
335+ nodes = subgraph .nodes [current_break_idx : i + 1 ],
336+ device_ordinal = subgraph .device_ordinal ,
337+ )
338+ )
339+ current_break_idx = i + 1
340+ new_subgraphs .append (
341+ Subgraph (
342+ is_acc = True ,
343+ nodes = subgraph .nodes [current_break_idx :],
344+ device_ordinal = subgraph .device_ordinal ,
345+ )
346+ )
347+ else :
348+ new_subgraphs .append (subgraph )
349+
350+ new_subgraphs = self .validate_and_correct_subgraphs (new_subgraphs )
351+ return new_subgraphs
352+
def size_of_subgraph(self, subgraph: Subgraph) -> int:
    """
    Return the size in bytes of the weights reachable from ``subgraph``.

    Walks backwards from the subgraph's nodes through their input nodes and,
    for every ``get_attr`` node reached, adds ``numel() * element_size()`` of
    the matching ``self.module.state_dict()`` entry. Nodes already present in
    ``self.weight_visited_nodes`` are skipped, and every node visited here is
    added to that set, so a weight is only counted by the first call that
    reaches it. The caller must initialise ``self.weight_visited_nodes``
    (a set) before the first call.

    Args:
        subgraph: Subgraph whose ``nodes`` seed the traversal. The list
            itself is not modified.

    Returns:
        Total number of bytes of the newly visited ``get_attr`` weights.

    Raises:
        KeyError: If a ``get_attr`` node's target is not a key of the
            module's state dict.
    """
    visited = self.weight_visited_nodes
    # Built lazily and at most once per call; state_dict() assembles a new
    # mapping each time it is invoked, so it must not run once per node.
    state_dict = None
    pending = list(subgraph.nodes)
    size = 0
    while pending:
        node = pending.pop()
        if node in visited:
            continue
        visited.add(node)
        if node.op == "get_attr":
            if state_dict is None:
                state_dict = self.module.state_dict()
            weight = state_dict[node.target]
            size += weight.numel() * weight.element_size()
            continue
        pending.extend(
            input_node
            for input_node in node.all_input_nodes
            if input_node not in visited
        )
    logging.getLogger(__name__).debug("Subgraph weight size in bytes: %d", size)
    return size
374+
def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
    """
    Check the subgraph segmentation against node dependencies and repair it.

    The subgraphs are walked in order while a map from each seen node to a
    subgraph index is maintained; that index is the one a dependent of the
    node is placed in. For every node of an accelerated subgraph (other than
    the first subgraph, and other than the subgraph's last node) the largest
    index recorded for its dependencies in ``self.deps`` is computed, and the
    node is moved into that subgraph when it differs from its current one.

    Args:
        subgraphs: Ordered subgraphs. Both the list's elements and their
            ``nodes`` lists are modified in place.

    Returns:
        The same ``subgraphs`` list, after correction.

    Raises:
        ValueError: If a node has a dependency that has not been recorded
            while walking the preceding nodes.
    """
    logging.getLogger(__name__).debug(
        "Subgraph sizes before validation: %s", [len(s.nodes) for s in subgraphs]
    )

    # node -> index of the subgraph a dependent of that node is assigned to
    placement = {}
    for i, subgraph in enumerate(subgraphs):
        if i == 0:
            for node in subgraph.nodes:
                placement[node] = i
            # Dependents of the closing node belong to the next subgraph.
            if subgraph.nodes:
                placement[subgraph.nodes[-1]] = i + 1
            continue

        if not subgraph.is_acc:
            # Record every node of the non-accelerated subgraph, not only
            # its last one; otherwise any later node depending on one of
            # them is reported as having an uncovered dependency.
            for node in subgraph.nodes:
                placement[node] = i + 1
            continue

        last_position = len(subgraph.nodes) - 1
        moved = set()
        for j, node in enumerate(subgraph.nodes):
            if j == last_position:
                # The closing node always stays where it is.
                placement[node] = i + 1
                continue

            # NOTE(review): a node without recorded dependencies ends up
            # with index 0 and is moved into the first subgraph whatever
            # its is_acc flag is - confirm this is intended.
            target_idx = 0
            for dep in self.deps[node]:
                if dep not in placement:
                    raise ValueError(f"Node {node} have a dependency that is not covered in the previous subgraphs. This is caused by a invalid subgraph segmentation.")
                target_idx = max(target_idx, placement[dep])

            if target_idx != i:
                subgraphs[target_idx].nodes.append(node)
                moved.add(node)
            placement[node] = target_idx

        if moved:
            # Single in-place rebuild instead of one list.remove per node.
            subgraph.nodes[:] = [n for n in subgraph.nodes if n not in moved]

    return subgraphs
414+
415+
416+
303417 def starter_nodes (self ) -> Tuple [NodeSet , NodeSet ]:
304418 """Generates starter nodes for partitioning + segmentation"""
305419 # Starter accelerated nodes are all callable accelerated ops
0 commit comments