Commit 7d8554c

Added the experiment files
1 parent b072ea1 commit 7d8554c

2 files changed: 130 additions, 10 deletions


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 3 deletions
@@ -861,10 +861,10 @@ def preserve_module_specs(
     dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

     submodule_node_dict = {}
-    for node in partitioned_module.graph.nodes:
-        if "_run_on_acc" not in node.name:
+    for name, node in partitioned_module.named_children():
+        if "_run_on_acc" not in name:
             continue
-        submodule_node_dict[node.name] = node
+        submodule_node_dict[name] = node

     preserve_module_specs(original_in_spec, original_out_spec, partitioned_module)
     # Store TRT replicas of Torch subgraphs
@@ -877,6 +877,12 @@ def preserve_module_specs(
     for attr in dir(gm):
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
+
+    from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS
+
+    DYNAMO_CONVERTERS.disallowed_targets = set()
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
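A note on the submodule-collection change above: partitioned_module.named_children() yields (name, submodule) pairs, so the dictionary now stores the submodules themselves rather than call_module graph nodes. Below is a minimal standalone sketch of the pattern, assuming a toy module whose children follow the partitioner's "_run_on_acc_*" naming convention (the Toy class and its attribute names are hypothetical, not part of this commit):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._run_on_acc_0 = torch.nn.Linear(4, 4)  # would become a TRT engine
        self._run_on_gpu_1 = torch.nn.ReLU()        # stays in Torch

    def forward(self, x):
        return self._run_on_gpu_1(self._run_on_acc_0(x))

gm = torch.fx.symbolic_trace(Toy())
# filter the children by name, as the new loop in _compiler.py does
submodule_node_dict = {
    name: mod for name, mod in gm.named_children() if "_run_on_acc" in name
}
print(list(submodule_node_dict))  # ['_run_on_acc_0']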

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 121 additions & 7 deletions
@@ -1,3 +1,4 @@
+from ast import Assert
 import logging
 from typing import Collection, Dict, List, Optional, Tuple

@@ -226,16 +227,26 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)

-        num_of_break = self.calculate_num_of_break(subgraphs)
-        subgraphs = self.break_subgraphs(subgraphs, num_of_break=num_of_break)
+        # num_of_break = self.calculate_num_of_break(subgraphs)
+        subgraphs = self.break_subgraphs_by_node(subgraphs, num_of_break=5)

         # Set the number of TRT engines to be generated
         self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc])

         # Tag the accelerated nodes and split the graph accordingly
+        print([len(s.nodes) for s in subgraphs])
         self.tag(subgraphs)
-        return self.split()

+        for s in subgraphs:
+            print(s.nodes)
+
+        gm = self.split()
+        self.weight_visited_nodes = set()
+        [self.size_of_subgraph(s) for s in subgraphs]
+
+        return gm
+
     def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
         """
         This function calculates the break period based on the number of subgraphs.
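To make the breaking step above concrete: the break_subgraphs variant further down in this diff cuts an accelerated subgraph at fixed node positions, each chunk then becoming its own Subgraph (and hence its own TRT engine), which is how the commit aims to save CPU memory during engine building. A toy sketch of that slicing over a plain list (split_at_positions is a hypothetical helper, not in the codebase):

def split_at_positions(nodes, break_pos):
    # cut after each index in break_pos, inclusive; the tail forms the last chunk
    chunks, start = [], 0
    for i in range(len(nodes)):
        if i in break_pos:
            chunks.append(nodes[start : i + 1])
            start = i + 1
    chunks.append(nodes[start:])
    return chunks

print(split_at_positions(list(range(10)), break_pos=[0, 4]))
# [[0], [1, 2, 3, 4], [5, 6, 7, 8, 9]]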
@@ -260,15 +271,16 @@ def calculate_num_of_break(self, subgraphs: List[Subgraph]) -> int:
             1, num_of_graphs // ((len(subgraphs) + 1) // 2)
         )  # If there are already graph breaks, we break each TRT subgraph a few times.

-    def break_subgraphs(
+
+    def break_subgraphs_by_node(
         self, subgraphs: List[Subgraph], num_of_break: int = 1
     ) -> List[Subgraph]:
         """
         This function breaks the subgraphs into smaller subgraphs at the specified frequency to save CPU memory.
         """
-
+        op_to_break = "add."
         num_of_sdpa_node = len(
-            [node for node in self.acc_nodes if "scaled_dot" in str(node.target)]
+            [node for node in self.acc_nodes if op_to_break in str(node.target)]
         )
         break_period = num_of_sdpa_node // num_of_break + 1
         current_break_idx = 0
@@ -277,7 +289,7 @@ def break_subgraphs(
         for subgraph in subgraphs:
             if subgraph.is_acc:
                 for i, node in enumerate(subgraph.nodes):
-                    if "scaled_dot" in str(node.target):
+                    if op_to_break in str(node.target):
                         current_num_break += 1
                         if current_num_break % break_period != 0:
                             continue
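The counting logic of break_subgraphs_by_node, reduced to a standalone sketch over strings instead of fx.Nodes (split_at_op is a hypothetical helper): count the occurrences of the target op, derive break_period with the same formula as above, and cut after every break_period-th occurrence.

from typing import List

def split_at_op(nodes: List[str], op_to_break: str, num_of_break: int) -> List[List[str]]:
    total = sum(op_to_break in n for n in nodes)  # how many target ops exist
    break_period = total // num_of_break + 1      # same formula as the diff
    chunks, start, seen = [], 0, 0
    for i, n in enumerate(nodes):
        if op_to_break in n:
            seen += 1
            if seen % break_period == 0:
                chunks.append(nodes[start : i + 1])  # break right after this op
                start = i + 1
    chunks.append(nodes[start:])  # remainder forms the last chunk
    return chunks

nodes = ["matmul", "add.Tensor", "relu", "add.Tensor", "matmul", "add.Tensor"]
print(split_at_op(nodes, "add.", num_of_break=2))
# [['matmul', 'add.Tensor', 'relu', 'add.Tensor'], ['matmul', 'add.Tensor']]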
@@ -298,8 +310,110 @@
                 )
             else:
                 new_subgraphs.append(subgraph)
+
+        new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
+
         return new_subgraphs

+    def break_subgraphs(
+        self, subgraphs: List[Subgraph], num_of_break: int = 1
+    ) -> List[Subgraph]:
+        """
+        This function breaks the subgraphs into smaller subgraphs at fixed node positions to save CPU memory.
+        """
+        break_pos = [0, 100, 200, 300, 400]
+        current_break_idx = 0
+        new_subgraphs = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                for i, node in enumerate(subgraph.nodes):
+                    if i in break_pos:
+                        new_subgraphs.append(
+                            Subgraph(
+                                is_acc=True,
+                                nodes=subgraph.nodes[current_break_idx : i + 1],
+                                device_ordinal=subgraph.device_ordinal,
+                            )
+                        )
+                        current_break_idx = i + 1
+                new_subgraphs.append(
+                    Subgraph(
+                        is_acc=True,
+                        nodes=subgraph.nodes[current_break_idx:],
+                        device_ordinal=subgraph.device_ordinal,
+                    )
+                )
+            else:
+                new_subgraphs.append(subgraph)
+
+        new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
+        return new_subgraphs
+
+    def size_of_subgraph(self, subgraph: Subgraph) -> int:
+        """
+        This function calculates the total size in bytes of the weights used by the subgraph.
+        """
+        stack = subgraph.nodes.copy()
+        size = 0
+        while stack:
+            node = stack.pop()
+            if node in self.weight_visited_nodes:
+                continue
+            self.weight_visited_nodes.add(node)
+            if node.op == "get_attr":
+                weight = self.module.state_dict()[node.target]
+                size += weight.numel() * weight.element_size()
+                continue
+            for input_node in node._input_nodes:
+                if input_node not in self.weight_visited_nodes:
+                    stack.append(input_node)
+        print(size)
+        return size
+
+    def validate_and_correct_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+        """
+        This function validates the subgraphs by checking that every dependency of a node is covered by an earlier subgraph, and corrects the segmentation when it is not.
+        """
+        visited_nodes = {}
+        print([len(s.nodes) for s in subgraphs])
+        for i, subgraph in enumerate(subgraphs):
+            if i == 0:
+                for node in subgraph.nodes:
+                    visited_nodes[node] = i
+                visited_nodes[subgraph.nodes[-1]] = i + 1
+                continue
+
+            elif not subgraph.is_acc:
+                for node in subgraph.nodes:
+                    visited_nodes[node] = i + 1
+                continue
+
+            else:
+                to_remove_nodes = []
+                for j, node in enumerate(subgraph.nodes):
+                    if j == len(subgraph.nodes) - 1:
+                        visited_nodes[node] = i + 1
+                        continue
+                    subgraph_idx = 0
+                    for dep in self.deps[node]:
+                        if dep in visited_nodes:
+                            subgraph_idx = max(subgraph_idx, visited_nodes[dep])
+                        else:
+                            raise ValueError(
+                                f"Node {node} has a dependency that is not covered by the previous subgraphs. This is caused by an invalid subgraph segmentation."
+                            )
+                    if subgraph_idx != i:
+                        subgraphs[subgraph_idx].nodes.append(node)
+                        to_remove_nodes.append(node)
+                        visited_nodes[node] = subgraph_idx
+                for node in to_remove_nodes:
+                    subgraph.nodes.remove(node)
+
+        return subgraphs
+
     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
         """Generates starter nodes for partitioning + segmentation"""
         # Starter accelerated nodes are all callable accelerated ops
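For size_of_subgraph above, the per-weight byte count is numel() * element_size(). A quick standalone sanity check with a stock fp32 Linear layer:

import torch

lin = torch.nn.Linear(1024, 1024)  # fp32 weight and bias
total = sum(t.numel() * t.element_size() for t in lin.state_dict().values())
print(total)  # 1024*1024*4 (weight) + 1024*4 (bias) = 4198400 bytes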

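And the correction rule in validate_and_correct_subgraphs, reduced to plain dictionaries (deps and visited are hypothetical stand-ins for self.deps and visited_nodes): a node is reassigned to the latest subgraph that any of its dependencies resolves to, so every input is materialized before the node runs.

deps = {"c": ["a"], "d": ["b", "c"]}   # node -> its dependencies
visited = {"a": 0, "b": 1, "c": 1}     # node -> index of the owning subgraph

def target_subgraph(node: str) -> int:
    # a node can live no earlier than the latest subgraph among its deps
    return max(visited[d] for d in deps[node])

print(target_subgraph("c"))  # 0 -> "c" only needs "a", so it may move back to subgraph 0
print(target_subgraph("d"))  # 1 -> "d" needs "b" from subgraph 1, so it stays put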