Commit 7e98c4c

yushangdi authored and Silv3S committed
Fix no source name in backward kernel names; Add flex_attention HOP to "original_aten" node meta (pytorch#167749)

Fixes pytorch#167706

- Add `torch.fx.experimental.proxy_tensor.set_original_aten_op()` around flex_attention HOP dispatch so we have `original_aten` populated for flex_attention
- Update the usages of `original_aten` to also expect HOP in addition to OpOverload

Pull Request resolved: pytorch#167749
Approved by: https://github.com/drisspg
1 parent 6058747 commit 7e98c4c
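
The core of the change is the context-manager pattern below, applied at both the Dynamo dispatch site and the HOP tracing sites in the diffs that follow. This is a minimal sketch rather than the patched code itself; it assumes the flex_attention HOP object is importable from torch._higher_order_ops.flex_attention.

import torch
from torch._higher_order_ops.flex_attention import flex_attention  # the HOP object
from torch.fx.experimental.proxy_tensor import set_original_aten_op

# While the context manager is active and FX node meta is being preserved,
# nodes created by the tracer record the HOP under node.meta["original_aten"].
with set_original_aten_op(flex_attention):
    # The real code calls tracer.create_proxy("call_function", flex_attention, ...)
    # inside this scope; see the diffs below.
    pass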

File tree

5 files changed: +90, -46 lines

test/inductor/test_flex_attention.py

Lines changed: 8 additions & 1 deletion
@@ -3249,7 +3249,14 @@ def test_strided_backwards(self, device):
         V_sliced = V[:, :, :-128]

         out_eager = flex_attention(Q, K_sliced, V_sliced)
-        out_compiled = func(Q, K_sliced, V_sliced)
+
+        out_compiled, code = run_and_get_code(func, Q, K_sliced, V_sliced)
+
+        # Make sure flex attention kernels have flex_attention in name
+        FileCheck().check_regex("triton_tem_fused_flex_attention.*").run(code[0])
+        FileCheck().check_regex("triton_tem_fused_flex_attention_backward.*").run(
+            code[1]
+        )

         grad = torch.rand_like(out_eager)
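
For reference, a hedged sketch of the verification pattern used in the new test lines: run_and_get_code returns the call result together with the generated source strings, and FileCheck asserts that the descriptive kernel names appear. The helper name and the code[0]/code[1] indexing mirror the test above and are illustrative, not part of the change.

from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def assert_flex_kernel_names(compiled_fn, *args):
    # Run the compiled function once and capture the generated wrapper sources.
    out, code = run_and_get_code(compiled_fn, *args)
    # Forward and backward kernels should both carry "flex_attention" in their names.
    FileCheck().check_regex("triton_tem_fused_flex_attention.*").run(code[0])
    FileCheck().check_regex("triton_tem_fused_flex_attention_backward.*").run(code[1])
    return out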

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 19 additions & 17 deletions
@@ -3652,24 +3652,26 @@ def _call_function(
         # - lifted args from tracing subgraph: [score_mod_other_buffers, mask_fn_other_buffers]
         _, _, _, inp_arg_block_mask, inp_arg_scale, inp_arg_kernel_options = inp_args
         block_mask = tuple(inp_arg_block_mask + (mask_fn_node,))
-        return wrap_fx_proxy(
-            tx=tx,
-            proxy=tx.output.create_proxy(
-                "call_function",
-                self.value,
-                args=inp_args[:3]
-                + (
-                    score_mod_node,
-                    block_mask,
-                    inp_arg_scale,
-                    inp_arg_kernel_options,
-                    score_mod_lifted_args,
-                    mask_fn_lifted_args,
+        with torch.fx.experimental.proxy_tensor.set_original_aten_op(self.value):
+            proxy = wrap_fx_proxy(
+                tx=tx,
+                proxy=tx.output.create_proxy(
+                    "call_function",
+                    self.value,
+                    args=inp_args[:3]
+                    + (
+                        score_mod_node,
+                        block_mask,
+                        inp_arg_scale,
+                        inp_arg_kernel_options,
+                        score_mod_lifted_args,
+                        mask_fn_lifted_args,
+                    ),
+                    kwargs={},
                 ),
-                kwargs={},
-            ),
-            example_value=None,
-        )
+                example_value=None,
+            )
+        return proxy


 class AutogradFunctionApplyVariable(VariableTracker):

torch/_higher_order_ops/flex_attention.py

Lines changed: 24 additions & 20 deletions
@@ -356,9 +356,10 @@ def trace_flex_attention(
     )
     # pyrefly: ignore [missing-attribute]
     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
-    out_proxy = proxy_mode.tracer.create_proxy(
-        "call_function", flex_attention, proxy_args, {}
-    )
+    with torch.fx.experimental.proxy_tensor.set_original_aten_op(flex_attention):
+        out_proxy = proxy_mode.tracer.create_proxy(
+            "call_function", flex_attention, proxy_args, {}
+        )
     return track_tensor_tree(
         example_out,
         out_proxy,
@@ -1114,23 +1115,26 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
     torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
 ]:
     assert mode is not None, "Mode should always be enabled for python fallback key"
-    return trace_flex_attention_backward(
-        mode,
-        query,
-        key,
-        value,
-        out,
-        logsumexp,
-        grad_out,
-        grad_logsumexp,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
-    )
+    with torch.fx.experimental.proxy_tensor.set_original_aten_op(
+        flex_attention_backward
+    ):
+        return trace_flex_attention_backward(
+            mode,
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            grad_logsumexp,
+            fw_graph,
+            joint_graph,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )


 @flex_attention_backward.py_functionalize_impl
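
Both operators wrapped above are HigherOrderOperator instances, and their .name() strings are what the Inductor changes below read out of node.meta["original_aten"]. A small sketch, assuming the HOP objects are importable as shown:

import torch
from torch._higher_order_ops.flex_attention import flex_attention, flex_attention_backward

assert isinstance(flex_attention, torch._ops.HigherOrderOperator)
assert isinstance(flex_attention_backward, torch._ops.HigherOrderOperator)
# Expected (assumption): prints "flex_attention flex_attention_backward"
print(flex_attention.name(), flex_attention_backward.name())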

torch/_inductor/utils.py

Lines changed: 36 additions & 7 deletions
@@ -781,9 +781,19 @@ def get_fused_kernel_name(
 ) -> str:
     all_origins = aggregate_origins(node_schedule)
     if descriptive_names == "original_aten":
+
+        def get_origin_meta_str(origin):
+            original_aten = origin.meta["original_aten"]
+            key = ""
+            if isinstance(original_aten, torch._ops.OpOverload):
+                key = original_aten._overloadpacket.__name__
+            elif isinstance(original_aten, torch._ops.HigherOrderOperator):
+                key = str(original_aten.name())
+            return key
+
         # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
         sources = [
-            origin.meta["original_aten"]._overloadpacket.__name__
+            get_origin_meta_str(origin)
             for origin in all_origins
             if origin.op == "call_function"
             and "original_aten" in origin.meta
@@ -794,12 +804,22 @@ def get_fused_kernel_name(
         # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
         sources = []
         for origin in all_origins:
-            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
-                source_fn = origin.meta["source_fn_stack"][-1]
+            if origin.op == "call_function":
+                source_fn = None
+                suffix = ""
+                if "source_fn_stack" in origin.meta:
+                    source_fn = origin.meta["source_fn_stack"][-1]
+                elif "fwd_source_fn_stack" in origin.meta:
+                    # backward nodes have "fwd_source_fn_stack" instead
+                    source_fn = origin.meta["fwd_source_fn_stack"][-1]
+                    suffix = "backward"
+                if not source_fn:
+                    continue
                 if isinstance(source_fn[1], str):
-                    sources.append(source_fn[1])
+                    sources.append(source_fn[1] + suffix)
                 else:
-                    sources.append(source_fn[1].__name__)
+                    sources.append(source_fn[1].__name__ + suffix)
+
         sources = sorted(OrderedSet(sources))
     elif descriptive_names == "inductor_node":
         sources = [
@@ -852,11 +872,20 @@ def get_kernel_metadata(

     for node in inductor_nodes:
         if "original_aten" in node.meta and node.meta["original_aten"] is not None:
-            key = str(node.meta["original_aten"]._overloadpacket)
-            original_aten_dict[key].append(node.name)
+            original_aten = node.meta["original_aten"]
+            key = None
+            if isinstance(original_aten, torch._ops.OpOverload):
+                key = str(original_aten._overloadpacket)
+            elif isinstance(original_aten, torch._ops.HigherOrderOperator):
+                key = str(original_aten.name())
+            if key:
+                original_aten_dict[key].append(node.name)
         if "from_node" in node.meta:
             key = node.meta["from_node"][0].name
             from_node_dict[key].append(node.name)
+        elif node.meta.get("partitioner_tag") == "is_backward":
+            # backward nodes currently don't have a "from node"
+            from_node_dict[node.name].append(node.name)
     sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
     metadata = (
         f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "

torch/fx/experimental/proxy_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -1543,7 +1543,9 @@ def get_sym_proxy_slot(t: PySymType) -> Proxy:


 @contextmanager
-def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
+def set_original_aten_op(
+    func: OpOverload | torch._ops.HigherOrderOperator,
+) -> Generator[None, None, None]:
     global ORIGINAL_ATEN
     if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
         ORIGINAL_ATEN = func
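
A usage note in sketch form: the context manager only records func while FX node-meta preservation is active (the has_preserved_node_meta() guard above), so callers typically sit inside fx_traceback.preserve_node_meta(); the pairing and the choice of aten.mm below are illustrative assumptions.

import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import set_original_aten_op

with fx_traceback.preserve_node_meta():
    # Accepts an OpOverload or, after this change, a HigherOrderOperator.
    with set_original_aten_op(torch.ops.aten.mm.default):
        pass  # nodes traced here would carry aten.mm in meta["original_aten"]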
