
Commit 57b04d7

Implemented the prototype
1 parent 013c772 commit 57b04d7

File tree

1 file changed: +174 −70 lines changed


py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 174 additions & 70 deletions
@@ -1,4 +1,3 @@
-from ast import Assert
 import logging
 from typing import Collection, Dict, List, Optional, Tuple
 
@@ -26,6 +25,10 @@
 )
 
 logger = logging.getLogger(__name__)
+NON_BREAKABLE_OP_LISTS = [
+    ["addmm", "addmm"],
+    ["conv2d", "batch_norm2d", "relu"],
+]
 
 
 class OpSupportTester(ops.OperatorSupportBase):  # type: ignore
@@ -227,8 +230,9 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
-        # num_of_break = self.calculate_num_of_break(subgraphs)
-        subgraphs = self.break_subgraphs_by_node(subgraphs, num_of_break=5)
+        subgraphs = self.break_subgraphs(
+            subgraphs, size_budget=self.calculate_size_budget()
+        )
 
         # Set the number of TRT engines to be generated
         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])
@@ -241,44 +245,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
             print(s.nodes)
 
         gm = self.split()
-        self.weight_visited_nodes = set()
-        [self.size_of_subgraph(s) for s in subgraphs]
-
 
         return gm
-
-    def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
+
+    def calculate_size_budget(
+        self, engine_compilation_memory_usage_multiplier: int = 4
+    ) -> int:
         """
-        This function calculates the break period based on the number of subgraphs.
+        This function calculates the size budget based on the available RSS. We assume that TRT compilation
+        needs at most 4x the memory of the model.
         """
-        rss = psutil.Process().memory_info().rss
-        available_rss = psutil.virtual_memory().available
-        num_of_graphs = len(subgraphs)
-        if rss < available_rss * 0.3:
-            num_of_graphs = 1
-        elif rss < available_rss * 0.5:
-            num_of_graphs = 2
-        elif rss < available_rss:
-            num_of_graphs = 4
-        elif rss < available_rss * 1.5:
-            num_of_graphs = 8
-        elif rss < available_rss * 2:
-            num_of_graphs = 16
-        else:
-            num_of_graphs = 32
-
-        return max(
-            1, num_of_graphs // ((len(subgraphs) + 1) // 2)
-        )  # If there are already graph breaks, for each TRT subgraph, we break for a few times.
 
+        available_rss: int = psutil.virtual_memory().available
+        return available_rss // engine_compilation_memory_usage_multiplier
 
     def break_subgraphs_by_node(
         self, subgraphs: List[Subgraph], num_of_break: int = 1
     ) -> List[Subgraph]:
         """
         This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
         """
-        op_to_break = "add."
+        op_to_break = "addmm."
         num_of_sdpa_node = len(
             [node for node in self.acc_nodes if op_to_break in str(node.target)]
         )
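Below is a minimal standalone sketch (illustrative names, not part of the diff) of what the new `calculate_size_budget` rule amounts to: the budget is the available system memory divided by the assumed compilation-overhead multiplier, and each accelerated subgraph is then split until its weights fit under that figure.

```python
import psutil


def size_budget(multiplier: int = 4) -> int:
    # Assume TRT engine compilation needs roughly `multiplier` times the memory
    # of the weights being compiled, so cap each piece at available / multiplier.
    return psutil.virtual_memory().available // multiplier


budget = size_budget()
print(f"split each accelerated subgraph until it holds <= {budget} bytes of weights")
```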
@@ -312,80 +299,200 @@ def break_subgraphs_by_node(
             new_subgraphs.append(subgraph)
 
         new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
-
+
         return new_subgraphs
 
     def break_subgraphs(
-        self, subgraphs: List[Subgraph], num_of_break: int = 1
+        self, subgraphs: List[Subgraph], size_budget: int
     ) -> List[Subgraph]:
         """
-        This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
+        This function breaks the subgraphs into smaller subgraphs to save CPU memory.
         """
-        break_pos = [0, 100, 200, 300, 400]
-        current_break_idx = 0
         new_subgraphs = []
-        for subgraph in subgraphs:
-            if subgraph.is_acc:
-                for i, node in enumerate(subgraph.nodes):
-                    if i in break_pos:
+        # We throw an error if the remaining memory is almost empty compared to the model size,
+        # i.e. if the remaining memory is 4G (the budget is 1G) and the model size is greater than 40G, we stop the compilation.
+        sizes = [(subgraph, self.size_of_subgraph(subgraph)) for subgraph in subgraphs]
+        if sum([size for _, size in sizes]) > size_budget * 40:
+            raise ValueError(
+                f"Subgraph size {sum([size for _, size in sizes])} is too large to break. Size budget: {size_budget}"
+            )
+        for subgraph, size in sizes:
 
-                        new_subgraphs.append(
-                            Subgraph(
-                                is_acc=True,
-                                nodes=subgraph.nodes[current_break_idx : i + 1],
-                                device_ordinal=subgraph.device_ordinal,
-                            )
-                        )
-                        current_break_idx = i + 1
-                new_subgraphs.append(
-                    Subgraph(
-                        is_acc=True,
-                        nodes=subgraph.nodes[current_break_idx:],
-                        device_ordinal=subgraph.device_ordinal,
-                    )
+            while size > size_budget:
+                broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size(
+                    subgraph, size_budget
                 )
-            else:
-                new_subgraphs.append(subgraph)
+                size = size_1
+                new_subgraphs.append(broken_subgraphs[0])
+                subgraph = broken_subgraphs[1]
+            new_subgraphs.append(subgraph)
 
-        new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
         return new_subgraphs
 
+    def break_subgraph_by_size(
+        self, subgraph: Subgraph, size_to_break: int
+    ) -> Tuple[List[Subgraph], int, int]:
+        """
+        This function breaks a subgraph into two smaller subgraphs around the given size to save CPU memory.
+        """
+        all_nodes = subgraph.nodes
+        device_ordinal = subgraph.device_ordinal
+        new_subgraphs = [
+            Subgraph(
+                is_acc=True,
+                nodes=[],
+                device_ordinal=device_ordinal,
+            ),
+            Subgraph(
+                is_acc=True,
+                nodes=all_nodes,
+                device_ordinal=device_ordinal,
+            ),
+        ]
+
+        while True:
+            new_subgraphs = self.step_and_validate(new_subgraphs)
+            size_0, size_1 = self.size_of_subgraph(
+                new_subgraphs[0]
+            ), self.size_of_subgraph(new_subgraphs[1])
+            if size_0 > size_to_break:
+                break
+
+        if len(new_subgraphs[1].nodes) == 0:
+            new_subgraphs.pop(1)
+        return new_subgraphs, size_0, size_1
+
+    def step_and_validate(
+        self, new_subgraphs: List[Subgraph], step_size: int = 1
+    ) -> List[Subgraph]:
+
+        # TODO: We can change it to binary search to find the optimal break point
+        for _ in range(step_size):
+            new_subgraphs[0].nodes.append(new_subgraphs[1].nodes.pop(0))
+
+        while True:
+            new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
+            nodes_in_first_subgraph = set(new_subgraphs[0].nodes)
+            leaf_node = self.get_leaf_node(nodes_in_first_subgraph)
+            broken_fusion = self.step_if_break_fusion(
+                new_subgraphs, leaf_node, nodes_in_first_subgraph
+            )
+            if not broken_fusion or len(new_subgraphs[1].nodes) == 0:
+                break
+
+        return new_subgraphs
+
+    def step_if_break_fusion(
+        self,
+        subgraphs: List[Subgraph],
+        leaf_nodes: set[torch.fx.Node],
+        nodes_in_first_subgraph: set[torch.fx.Node],
+    ) -> bool:
+
+        def add_nodes(node: torch.fx.Node) -> None:
+            """
+            This function adds a node and all its previous nodes to the first subgraph and removes them from the second subgraph in post order.
+            """
+            if node.op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph:
+                nodes_in_first_subgraph.add(node)
+                for input_node in node._input_nodes:
+                    add_nodes(input_node)
+                subgraphs[0].nodes.append(node)
+                subgraphs[1].nodes.remove(node)
+
+        def match_subgraph_and_step(node: torch.fx.Node) -> bool:
+            added_nodes = False
+            for op_list in NON_BREAKABLE_OP_LISTS:
+                for i, op in enumerate(op_list):
+                    if i != len(op_list) - 1 and op in str(node.target):
+                        # Search the following ops forward using BFS. We skip searching previous ops because
+                        # even if it's just a subset of the fusion graph, we still want it to be fused.
+
+                        users = node.users.keys()
+                        matching_nodes: set[torch.fx.Node] = set()
+                        for following_op_idx in range(i + 1, len(op_list)):
+                            matching_nodes = set()
+                            for user in users:
+                                if op_list[following_op_idx] in str(user.target):
+                                    matching_nodes.add(user)
+                            if not matching_nodes:
+                                break
+                            users = set()
+                            for matching_node in matching_nodes:
+                                for next_user in matching_node.users:
+                                    users.add(next_user)
+
+                        for matching_node in matching_nodes:
+                            added_nodes = True
+                            add_nodes(matching_node)
+
+                if added_nodes:
+                    # Terminate the search early if we have found a match, because preceding matches can cover following matches
+                    break
+
+            return True if added_nodes else False
+
+        found_match = False
+        for leaf in leaf_nodes:
+            if match_subgraph_and_step(leaf):
+                found_match = True
+
+        return found_match
+
+    def get_leaf_node(
+        self, nodes_in_first_subgraph: set[torch.fx.Node]
+    ) -> set[torch.fx.Node]:
+        leaf_node = set()
+
+        for node in nodes_in_first_subgraph:
+            for user in node.users:
+                if user not in nodes_in_first_subgraph:
+                    leaf_node.add(node)
+                    break
+        return leaf_node
+
     def size_of_subgraph(self, subgraph: Subgraph) -> int:
         """
         This function calculates the size of the subgraph.
         """
+        nodes_in_subgraph = set(subgraph.nodes)
+        weight_visited_nodes = set()
         stack = subgraph.nodes.copy()
         size = 0
         while stack:
             node = stack.pop()
-            if node in self.weight_visited_nodes:
+            if node in weight_visited_nodes:
                 continue
-            self.weight_visited_nodes.add(node)
             if node.op == "get_attr":
                 weight = self.module.state_dict()[node.target]
                 size += weight.numel() * weight.element_size()
-                self.weight_visited_nodes.add(node)
+                weight_visited_nodes.add(node)
+                continue
+            if node not in nodes_in_subgraph:
+                # Do not trace into other subgraphs
                 continue
-            for input_node in node._input_nodes:
-                if input_node not in self.weight_visited_nodes:
+            for input_node in node._input_nodes:
+                if input_node not in weight_visited_nodes:
                     stack.append(input_node)
-        print(size)
+
         return size
 
-    def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+    def validate_and_correct_subgraphs(
+        self, subgraphs: List[Subgraph]
+    ) -> List[Subgraph]:
         """
         This function validates the subgraphs by checking if the subgraphs are valid, and corrects the subgraphs if they are not valid.
         """
-        visited_nodes = {}
-        print([len(s.nodes) for s in subgraphs])
+        visited_nodes = (
+            {}
+        )  # a map from a node to the index of the subgraph its users should belong to
         for i, subgraph in enumerate(subgraphs):
             if i == 0:
                 for node in subgraph.nodes:
                     visited_nodes[node] = i
                 visited_nodes[subgraph.nodes[-1]] = i + 1
                 continue
 
-
             elif not subgraph.is_acc:
                 for node in subgraph.nodes:
                     visited_nodes[subgraph.nodes[-1]] = i + 1
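The fusion guard added above can be illustrated with a toy, dict-based stand-in for `torch.fx` nodes (hypothetical helper names, simplified from `step_if_break_fusion`): when a split boundary lands on an op that starts one of the `NON_BREAKABLE_OP_LISTS` patterns, the remaining ops of the pattern are searched breadth-first through the users and would be pulled across the boundary so the pattern is never cut.

```python
NON_BREAKABLE_OP_LISTS = [
    ["addmm", "addmm"],
    ["conv2d", "batch_norm2d", "relu"],
]

# Toy graph: node name -> (op string, list of user node names).
TOY_GRAPH = {
    "a": ("conv2d", ["b"]),
    "b": ("batch_norm2d", ["c"]),
    "c": ("relu", []),
}


def nodes_to_pull_forward(boundary: str) -> set:
    """Return the users that must stay with `boundary` so no fusion pattern is cut."""
    pulled = set()
    op = TOY_GRAPH[boundary][0]
    for op_list in NON_BREAKABLE_OP_LISTS:
        for i, pattern_op in enumerate(op_list[:-1]):
            if pattern_op not in op:
                continue
            # BFS through the users, one pattern position at a time.
            frontier = set(TOY_GRAPH[boundary][1])
            for next_op in op_list[i + 1:]:
                matches = {n for n in frontier if next_op in TOY_GRAPH[n][0]}
                if not matches:
                    break
                pulled |= matches
                frontier = {u for m in matches for u in TOY_GRAPH[m][1]}
    return pulled


print(nodes_to_pull_forward("a"))  # {'b', 'c'} (order may vary)
```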
@@ -401,18 +508,15 @@ def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
                 for dep in self.deps[node]:
                     if dep in visited_nodes:
                         subgraph_idx = max(subgraph_idx, visited_nodes[dep])
-                    else:
-                        raise ValueError(f"Node {node} have a dependency that is not covered in the previous subgraphs. This is caused by a invalid subgraph segmentation.")
+
                 if subgraph_idx != i:
                     subgraphs[subgraph_idx].nodes.append(node)
                     to_remove_nodes.append(node)
                     visited_nodes[node] = subgraph_idx
             for node in to_remove_nodes:
                 subgraph.nodes.remove(node)
-
+
         return subgraphs
-
-
 
     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
         """Generates starter nodes for partitioning + segmentation"""
