
Commit 52f1fbc

Added AtenFlexAttentionOp
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent 6178d07 commit 52f1fbc

File tree

1 file changed: +56 -0 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 56 additions & 0 deletions
@@ -16194,6 +16194,62 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
   let hasFolder = 1;
 }
 
+
+def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::flex_attention : (Tensor, Tensor, Tensor, Any?, Any?, float?, bool, Any?, bool) -> (Tensor, Tensor)`";
+  let description = [{
+    Flexible attention operator that supports custom score modification and masking.
+
+    Args:
+      query: Query tensor [B, H, M, E]
+      key: Key tensor [B, H, N, E]
+      value: Value tensor [B, H, N, Ev]
+      score_mod: Optional callable to modify attention scores (represented as None or an opaque type)
+      block_mask: Optional BlockMask tuple for sparse attention patterns
+      scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
+      enable_gqa: bool for grouped query attention support
+      kernel_options: Optional dict of kernel configuration options
+      return_lse: bool to return log-sum-exp values
+
+    Returns:
+      - If return_lse=False: just the output tensor [B, H, M, Ev]
+      - If return_lse=True: a tuple of (output [B, H, M, Ev], logsumexp [B, H, M])
+
+    Note: score_mod and block_mask are higher-order/complex types in PyTorch.
+    For the MLIR representation, score_mod is represented as None (identity) or an opaque type,
+    and block_mask is represented as None or a tuple/list of tensors containing the block indices.
+  }];
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    AnyType:$score_mod,
+    AnyType:$block_mask,
+    AnyTorchOptionalFloatType:$scale,
+    Torch_BoolType:$enable_gqa,
+    AnyType:$kernel_options,
+    Torch_BoolType:$return_lse
+  );
+  let results = (outs
+    AnyTorchTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 2);
+    }
+    void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 2);
+    }
+  }];
+}
+
+
 def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
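
For context beyond the diff: this generated op mirrors PyTorch's flex_attention API (torch.nn.attention.flex_attention, available since PyTorch 2.5). Below is a minimal sketch of the eager-mode call that the op abstracts, showing how score_mod and scale line up with the operands above; the shapes and the bias function are illustrative assumptions, not taken from the commit.

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, M, N, E = 2, 4, 128, 128, 64  # batch, heads, query len, kv len, head dim
    query = torch.randn(B, H, M, E)
    key = torch.randn(B, H, N, E)
    value = torch.randn(B, H, N, E)

    # score_mod receives each raw attention score plus its (batch, head,
    # q_idx, kv_idx) position and returns a modified score; here a simple
    # relative-position bias.
    def score_mod(score, b, h, q_idx, kv_idx):
        return score + (q_idx - kv_idx)

    # scale=None requests the default 1/sqrt(head_dim) scaling described in
    # the op's docstring.
    out = flex_attention(query, key, value, score_mod=score_mod, scale=None)
    print(out.shape)  # torch.Size([2, 4, 128, 64])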

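A similarly hedged sketch of the block_mask and return_lse paths, assuming create_block_mask from the same PyTorch module and an eager fallback that runs on the chosen device; the pair returned with return_lse=True corresponds to the op's $output and $logsumexp results.

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    B, H, L, E = 2, 4, 256, 64
    q = torch.randn(B, H, L, E)
    k = torch.randn(B, H, L, E)
    v = torch.randn(B, H, L, E)

    # mask_mod returns True where attention is allowed; create_block_mask
    # turns it into the sparse BlockMask that the op's block_mask operand
    # stands in for. B=None/H=None broadcast the mask over batch and heads.
    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=L, KV_LEN=L,
                                   device="cpu")

    # return_lse=True yields (output [B, H, M, Ev], logsumexp [B, H, M]),
    # matching the op's two results.
    out, lse = flex_attention(q, k, v, block_mask=block_mask, return_lse=True)
    print(out.shape, lse.shape)  # torch.Size([2, 4, 256, 64]) torch.Size([2, 4, 256])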