@@ -16194,6 +16194,62 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
   let hasFolder = 1;
 }
 
+
+def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::flex_attention : (Tensor, Tensor, Tensor, Any?, Any?, float?, bool, Any?, bool) -> (Tensor, Tensor)`";
+  let description = [{
+    Flexible attention operator that supports custom score modification and masking.
+
+    Args:
+      query: Query tensor [B, H, M, E]
+      key: Key tensor [B, H, N, E]
+      value: Value tensor [B, H, N, Ev]
+      score_mod: Optional callable to modify attention scores (represented as none or an opaque type)
+      block_mask: Optional BlockMask tuple for sparse attention patterns
+      scale: Optional float for scaling attention scores (none means 1/sqrt(head_dim))
+      enable_gqa: bool for grouped query attention support
+      kernel_options: Optional dict of kernel configuration options
+      return_lse: bool to return log-sum-exp values
+
+    Returns:
+      - If return_lse=False: the output tensor [B, H, M, Ev] (the logsumexp result is none)
+      - If return_lse=True: output [B, H, M, Ev] and logsumexp [B, H, M]
+
+    Note: score_mod and block_mask are higher-order/complex types in PyTorch.
+    For the MLIR representation, score_mod is represented as none (identity) or an
+    opaque type, and block_mask as none or a tuple/list of tensors containing the
+    block indices.
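+
+    A minimal usage sketch (illustrative only; the vtensor shapes and the
+    !torch.none placeholders for score_mod, block_mask, and kernel_options
+    are assumptions, not a canonical lowering):
+
+    %out, %lse = torch.aten.flex_attention %q, %k, %v, %none, %none, %scale,
+        %false, %none, %true : !torch.vtensor<[1,8,128,64],f32>,
+        !torch.vtensor<[1,8,128,64],f32>, !torch.vtensor<[1,8,128,64],f32>,
+        !torch.none, !torch.none, !torch.float, !torch.bool, !torch.none,
+        !torch.bool -> !torch.vtensor<[1,8,128,64],f32>, !torch.vtensor<[1,8,128],f32>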
+  }];
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    AnyType:$score_mod,
+    AnyType:$block_mask,
+    AnyTorchOptionalFloatType:$scale,
+    Torch_BoolType:$enable_gqa,
+    AnyType:$kernel_options,
+    Torch_BoolType:$return_lse
+  );
+  let results = (outs
+    AnyTorchTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
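+      // Generic form: 9 operands (query through return_lse), 2 results (output, logsumexp).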
+      return parseDefaultTorchOp(parser, result, 9, 2);
+    }
+    void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 2);
+    }
+  }];
+}
+
+
 def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [
     AllowsTypeRefinement,
     HasValueSemantics,