
Commit 2389d34

Added new subgraph definition paradigm and revised matching logic
1 parent f03ab2c commit 2389d34

File tree

3 files changed: +180, -40 lines changed


py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 47 additions & 40 deletions
@@ -233,7 +233,8 @@ def partition_graph(self) -> torch.fx.GraphModule:
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
         subgraphs = self.break_subgraphs(
-            subgraphs, subgraph_size_budget=self.calculate_size_budget()
+            subgraphs,
+            subgraph_size_budget=500 * 1024 * 1024,  # self.calculate_size_budget()
         )
 
         # Set the number of TRT engines to be generated
@@ -309,6 +310,11 @@ def break_subgraphs(
         """
         This function breaks the subgraphs into smaller subgraphs to save CPU memory.
         """
+        from torch_tensorrt.dynamo.partitioning.fusion_patterns import (
+            get_node_in_fusion_pattern,
+        )
+
+        self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph)
         new_subgraphs = []
         # We throw an error if the remaining memory is almost empty compared to the model size,
         # i.e. if the remaining memory is 4G (budget is 1G) and the model size is greater than 40G, we stop the compilation.
@@ -328,9 +334,26 @@ def break_subgraphs(
             new_subgraphs.append(broken_subgraphs[0])
             subgraph = broken_subgraphs[1]
         new_subgraphs.append(subgraph)
-
+        self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
         return new_subgraphs
 
+    def _varify_all_fusion_nodes_in_same_subgraph(
+        self, subgraphs: List[Subgraph]
+    ) -> None:
+        node_to_subgraph = {}
+        for i, s in enumerate(subgraphs):
+            for n in s.nodes:
+                node_to_subgraph[n] = i
+
+        fusion_nodes_map_list = [
+            len({node_to_subgraph[n] for n in ns}) == 1
+            for ns in self.fusion_patterns.values()
+        ]
+        assert all(
+            fusion_nodes_map_list
+        ), "All fusion nodes must be in the same subgraph"
+        logger.info("All fusion nodes are in the same subgraph.")
+
     def break_subgraph_by_size(
         self, subgraph: Subgraph, size_to_break: int
     ) -> Tuple[List[Subgraph], int, int]:
@@ -376,9 +399,13 @@ def step_and_validate(
         while True:
             new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
             nodes_in_first_subgraph = set(new_subgraphs[0].nodes)
+            nodes_in_second_subgraph = set(new_subgraphs[1].nodes)
            leaf_node = self.get_leaf_node(nodes_in_first_subgraph)
             broken_fusion = self.step_if_break_fusion(
-                new_subgraphs, leaf_node, nodes_in_first_subgraph
+                new_subgraphs,
+                leaf_node,
+                nodes_in_first_subgraph,
+                nodes_in_second_subgraph,
             )
             if not broken_fusion or len(new_subgraphs[1].nodes) == 0:
                 break
@@ -390,57 +417,37 @@ def step_if_break_fusion(
         subgraphs: List[Subgraph],
         leaf_nodes: set[torch.fx.Node],
         nodes_in_first_subgraph: set[torch.fx.Node],
+        nodes_in_second_subgraph: set[torch.fx.Node],
     ) -> bool:
 
         def add_nodes(node: torch.fx.Node) -> None:
             """
             This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order.
             """
-            if node.op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph:
+            if (
+                node.op in CALLABLE_NODE_OPS
+                and node not in nodes_in_first_subgraph
+                and node in nodes_in_second_subgraph
+            ):
+                # Exclude all nodes already in the first subgraph
                 nodes_in_first_subgraph.add(node)
+                nodes_in_second_subgraph.remove(node)
                 for input_node in node._input_nodes:
                     add_nodes(input_node)
                 subgraphs[0].nodes.append(node)
                 subgraphs[1].nodes.remove(node)
 
-        def match_subgraph_and_step(node: torch.fx.Node) -> bool:
-            added_nodes = False
-            for op_list in NON_BREAKABLE_OP_LISTS:
-                for i, op in enumerate(op_list):
-                    if i != len(op_list) - 1 and op in str(node.target):
-                        # Search following ops forward using BFS. We skip search previous ops because
-                        # even if it's just a subset of fusion graph, we still want it to be fused.
-
-                        users = node.users.keys()
-                        matching_nodes: set[torch.fx.Node] = set()
-                        for following_op_idx in range(i + 1, len(op_list)):
-                            matching_nodes = set()
-                            for user in users:
-                                if op_list[following_op_idx] in str(user.target):
-                                    matching_nodes.add(user)
-                            if not matching_nodes:
-                                break
-                            users = set()
-                            for matching_node in matching_nodes:
-                                for next_user in matching_node.users:
-                                    users.add(next_user)
-
-                        for matching_node in matching_nodes:
-                            added_nodes = True
-                            add_nodes(matching_node)
-
-                if added_nodes:
-                    # Early terminate the search if we have found a match because preceeding matches can cover following matches
-                    break
-
-            return True if added_nodes else False
-
-        found_match = False
+        fusion_broken = False
         for leaf in leaf_nodes:
-            if match_subgraph_and_step(leaf):
-                found_match = True
+            for node in self.fusion_patterns.get(leaf, []):
+                if (
+                    node not in nodes_in_first_subgraph
+                    and node in nodes_in_second_subgraph
+                ):
+                    fusion_broken = True
+                    add_nodes(node)
 
-        return found_match
+        return fusion_broken
 
     def get_leaf_node(
         self, nodes_in_first_subgraph: set[torch.fx.Node]
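
The new _varify_all_fusion_nodes_in_same_subgraph step asserts that every fusion group found by get_node_in_fusion_pattern ends up inside a single subgraph after the split. The following is a minimal standalone sketch of that invariant, using plain strings in place of torch.fx.Node objects; the fusion_groups_intact helper is hypothetical and not part of this commit.

from typing import Dict, List, Set


def fusion_groups_intact(
    subgraphs: List[Set[str]], fusion_patterns: Dict[str, Set[str]]
) -> bool:
    # Map every node to the index of the subgraph that owns it.
    node_to_subgraph = {n: i for i, nodes in enumerate(subgraphs) for n in nodes}
    # A fusion group is intact when all of its nodes land in the same subgraph.
    return all(
        len({node_to_subgraph[n] for n in group}) == 1
        for group in fusion_patterns.values()
    )


# conv/bn/relu stay together in subgraph 0, so the check passes.
assert fusion_groups_intact(
    [{"conv", "bn", "relu"}, {"mul", "add"}],
    {"conv": {"conv", "bn", "relu"}, "mul": {"mul", "add"}},
)

# Splitting relu off into the second subgraph breaks the fusion group.
assert not fusion_groups_intact(
    [{"conv", "bn"}, {"relu"}],
    {"conv": {"conv", "bn", "relu"}},
)
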
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+from typing import Dict, Set
+
+import torch
+from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+from torch.ops import aten
+
+
+class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        running_mean: torch.Tensor,
+        running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
+        )
+        x = aten._native_batch_norm_legit_no_training.default(
+            x, bn_weight, bn_bias, running_mean, running_var, momentum=0.1, eps=1e-05
+        )[0]
+        x = aten.relu.default(x)
+        return x
+
+
+class ConvReLU(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
+        )
+        x = aten.relu.default(x)
+        return x
+
+
+class ConvGelu(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
+        )
+        x = aten.gelu.default(x)
+        return x
+
+
+class ConvSilu(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
+        )
+        x = aten.silu.default(x)
+        return x
+
+
+class MulAdd(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.mul.Tensor(x, weight)
+        x = aten.add.Tensor(x, bias)
+        return x
+
+
+class MulMul(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.mul.Tensor(x, y)
+        x = aten.mul.Tensor(x, z)
+        return x
+
+
+All_FUSION_PATTERNS = [
+    ConvBNReLU,
+    ConvReLU,
+    ConvGelu,
+    ConvSilu,
+    MulAdd,
+    MulMul,
+]
+
+
+def get_node_in_fusion_pattern(
+    graph: torch.fx.Graph,
+) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
+    """
+    This function gets the nodes map of the fusion pattern from the graph.
+    Key: node that appears in the fusion pattern
+    Value: the list of nodes that should be fused together
+    """
+    fusion_nodes = {}
+    for pattern in All_FUSION_PATTERNS:
+        pattern_graph = torch.fx.symbolic_trace(pattern())
+        subgraph_matcher = SubgraphMatcher(pattern_graph.graph)
+        match_result = subgraph_matcher.match(graph)
+        for match in match_result:
+            fusion_group = {
+                node
+                for node in match.nodes_map.values()
+                if node
+                and type(node) == torch.fx.Node
+                and node.op == "call_function"
+                and node not in match.placeholder_nodes
+            }
+            for node in fusion_group:
+                fusion_nodes[node] = fusion_group
+
+    return fusion_nodes
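
Assuming the new 133-line file above is the fusion_patterns module imported by break_subgraphs (its path matches the new import in _adjacency_partitioner.py), the sketch below shows how get_node_in_fusion_pattern could be exercised on its own. Tracing one of the pattern modules itself is only a convenient way to obtain a graph that is known to contain a fusion pattern; it is not something this commit does.

import torch

from torch_tensorrt.dynamo.partitioning.fusion_patterns import (
    ConvReLU,
    get_node_in_fusion_pattern,
)

# Trace a graph that contains exactly the conv -> relu pattern, so the
# SubgraphMatcher should report at least one fusion group.
gm = torch.fx.symbolic_trace(ConvReLU())

# Map every call_function node that participates in a fusion pattern
# to the full set of nodes that must stay in the same subgraph.
fusion_map = get_node_in_fusion_pattern(gm.graph)

for node, group in fusion_map.items():
    print(node.name, "->", sorted(n.name for n in group))
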

py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py

Whitespace-only changes.
