From c8c711c7547a989bf56c69c82888076ca1479b53 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 22 Oct 2025 09:41:11 -0700 Subject: [PATCH 01/38] Modified fx_importer to support hop_while_loop Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 213 +++++++++++++++++++++++- test/python/fx_importer/basic_test.py | 40 ++++- 2 files changed, 247 insertions(+), 6 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 35501741149d..2856376bf922 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -106,6 +106,7 @@ Context, DenseElementsAttr, DenseResourceElementsAttr, + FlatSymbolRefAttr, FloatAttr, BF16Type, ComplexType, @@ -834,8 +835,50 @@ def import_program( node_importer.return_node_values(loc, user_outputs, constant_output_values) self.symbol_table.insert(func_op) + + # Import all child graph modules recursively for HOPs + # Even though import_stateless_graph is deprecated as an entrypoint mechanism, + # HOP operator graphs are stateless graphs with no mutation, and it is correct + # to import them as stateless graphs. + self._import_all_child_modules( + prog, + func_name, + import_symbolic_shape_expressions + ) + return func_op + def _import_all_child_modules( + self, + prog: torch.export.ExportedProgram, + parent_name: str, + import_symbolic_shape_expressions: bool = False + ): + """Recursively import all child modules that have graphs. + + This simple approach imports all submodules recursively, which is sufficient + for HOP operations since they only reference existing submodules. + """ + for child_name, child_module in prog.graph.owning_module.named_children(): + if isinstance(child_module, GraphModule) and hasattr(child_module, 'graph'): + # Generate function name: parent_childname + child_func_name = f"{parent_name}_{child_name}_{id(child_module)}" + + # Import the child as a stateless graph (private function) + self.import_stateless_graph( + child_module.graph, + func_name=child_func_name, + func_visibility="private", + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) + + # Recursively import its children + self._import_all_child_modules( + child_module, + child_func_name, + import_symbolic_shape_expressions + ) + def import_frozen_program( self, prog: torch.export.ExportedProgram, @@ -996,9 +1039,17 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: if node.op == "placeholder": input_types.append(self._cc.node_val_to_type(node)) elif node.op == "output": - # An output node's args[0] is the return value. This seems to - # always be "boxed" as a tuple, which we emit as multi-results. - for result_node in node.args[0]: + # An output node's args[0] is the return value. This is usually + # "boxed" as a tuple, which we emit as multi-results. However, + # for single returns it might be a single Node. + output_arg = node.args[0] + # Handle both single Node and tuple/list of Nodes + if isinstance(output_arg, (list, tuple)): + result_nodes = output_arg + else: + result_nodes = [output_arg] + + for result_node in result_nodes: if result_node is None: result_types.append( IrType.parse("!torch.none", context=self._c) @@ -1509,7 +1560,14 @@ def import_nodes( elif op == "output" and not skip_placeholders_outputs: # args[0] is a singleton tuple that we flatten into multiple # results. 
- operands = [self._import_argument(loc, arg) for arg in node.args[0]] + output_arg = node.args[0] + # Handle both single Node and tuple/list of Nodes + if isinstance(output_arg, (list, tuple)): + result_nodes = output_arg + else: + result_nodes = [output_arg] + + operands = [self._import_argument(loc, arg) for arg in result_nodes] func_dialect.ReturnOp(operands, loc=loc) if import_symbolic_shape_expressions: @@ -1612,6 +1670,139 @@ def _import_hop(self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperat ) handler(loc, node, hop) + def _import_hop_while_loop( + self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator + ): + """Imports the torch._higher_order_ops.while_loop HOP. + + Args format: (cond_fn, body_fn, carries) + The cond_fn and body_fn are get_attr nodes pointing to submodule graphs + that have already been imported by import_program(). + + Emits torch.prim.Loop with proper control flow structure. + """ + # while_loop HOP args: (cond_fn, body_fn, carries...) + # Unpack the first two args and the rest as carries + cond_fn_node, body_fn_node, *carries = node.args + + # Extract function names from get_attr nodes + # The subgraphs were imported with names like "main_{target}" + assert cond_fn_node.op == "get_attr", f"Expected get_attr for cond_fn, got {cond_fn_node.op}" + assert body_fn_node.op == "get_attr", f"Expected get_attr for body_fn, got {body_fn_node.op}" + + root_module = node.graph.owning_module + cond_fn_module = getattr(root_module, cond_fn_node.target, None) + body_fn_module = getattr(root_module, body_fn_node.target, None) + + # Generate function names with module IDs for uniqueness + cond_fn_name = f"main_{cond_fn_node.target}_{id(cond_fn_module)}" + body_fn_name = f"main_{body_fn_node.target}_{id(body_fn_module)}" + + # Import the carries (loop state variables) + carry_values = [] + for carry in carries: + if isinstance(carry, tuple): + # Handle tuple carries by importing each element + carry_values.extend(self._import_tuple_argument(loc, carry, None)) + else: + carry_values.append(self._import_argument(loc, carry)) + + # Determine result types from node metadata + node_val = node.meta.get("val") + if isinstance(node_val, (list, tuple)) and len(node_val) > 1: + result_types = [self._cc.value_info_to_type(v) for v in node_val] + self._multi_result_nodes.add(node) + else: + result_types = [self._cc.node_val_to_type(node)] + + # Call the condition function with initial carries to get initial condition + cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool) + + initial_cond_call = Operation.create( + "func.call", + attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)}, + results=[cond_result_type], + operands=carry_values, + loc=loc, + ) + + # Convert vtensor to torch.bool + bool_conv = Operation.create( + name="torch.aten.Bool.Tensor", + results=[self._cc.torch_bool_type], + operands=[initial_cond_call.results[0]], + loc=loc, + ) + + # Create max iterations constant (INT64_MAX) + with loc: + max_iter = _make_constant_op( + "torch.constant.int", + self._cc.integer_attr(9223372036854775807, 64), + self._cc.torch_int_type, + ) + + # Create torch.prim.Loop operation with region + loop_op = Operation.create( + name="torch.prim.Loop", + results=result_types, + operands=[max_iter.results[0], bool_conv.results[0]] + carry_values, + regions=1, + loc=loc, + ) + + # Create loop body region with block arguments + # Block args: iteration counter (!torch.int) + all carry values + loop_region = loop_op.regions[0] + block_arg_types = 
[self._cc.torch_int_type] + result_types + with loc: + loop_block = Block.create_at_start(loop_region, block_arg_types) + + # Inside the loop body, call body function and condition function + with InsertionPoint(loop_block): + # Call body function with current carry values (skip iteration counter) + body_results_op = Operation.create( + name="func.call", + attributes={"callee": FlatSymbolRefAttr.get(body_fn_name)}, + results=result_types, + operands=list(loop_block.arguments[1:]), # Skip iteration counter + loc=loc, + ) + body_results = list(body_results_op.results) + + # Call condition function with updated carries + cond_result_loop = Operation.create( + name="func.call", + attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)}, + results=[IrType.parse("!torch.vtensor<[],i1>", context=self._c)], + operands=body_results, + loc=loc, + ).result + + # Convert to bool + cond_bool = Operation.create( + name="torch.aten.Bool.Tensor", + results=[self._cc.torch_bool_type], + operands=[cond_result_loop], + loc=loc, + ).result + + # Emit loop condition with updated carries + Operation.create( + name="torch.prim.Loop.condition", + results=[], + operands=[cond_bool] + body_results, + loc=loc, + ) + + # Bind the loop results to the node + if len(result_types) > 1: + self._multi_result_nodes.add(node) + for i, value in enumerate(loop_op.results): + self.bind_node_value(node, value, i) + else: + self.bind_node_value(node, loop_op.results[0]) + def _import_hop_auto_functionalized( self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator ): @@ -1823,6 +2014,9 @@ def _import_argument( argument_value = self.resolve_node_value(arg) elif isinstance(arg, torch_fx.immutable_collections.immutable_list): argument_value = self._import_list_argument(loc, arg, expected_jit_type) + elif isinstance(arg, tuple): + # Handle tuples of tensors (common in while_loop carries) + argument_value = self._import_tuple_argument(loc, arg, expected_jit_type) elif isinstance(expected_jit_type, torch.TensorType) and not isinstance( arg, torch.Tensor ): @@ -1930,6 +2124,13 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: loc=loc, ).result + def _import_tuple_argument( + self, loc: Location, arg: tuple, expected_jit_type + ) -> List[Value]: + """Import a tuple argument by importing each element separately.""" + # For tuples in while_loop carries, treat each element as a separate argument + return [self._import_argument(loc, elem, expected_jit_type) for elem in arg] + def _import_list_argument( self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type ) -> Value: @@ -2040,6 +2241,8 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node): # NOTE: the length of the list must be knowable at compile time. 
if ref_node not in self._unpack_list_values: node_result = self.resolve_node_value(ref_node, 0) + node_val = ref_node.meta.get("val") + if str(node_result.type) in TORCH_LIST_TYPES: result_types = [ self._cc.value_info_to_type(v) for v in ref_node.meta["val"] @@ -2510,4 +2713,4 @@ def aten__embedding_bag_forward_only_default(node: torch_fx.Node): def node_canonicalize(node: torch_fx.Node): if node.target in NODE_CANONICALIZE: return NODE_CANONICALIZE[node.target](node) - return node + return node \ No newline at end of file diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 7a5660b028b3..d6df5b04381f 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -205,6 +205,44 @@ def forward(self): ) print(m) +@run +# CHECK-LABEL: test_while_loop_two_returns +# CHECK: func.func @test_while_loop_two_returns +# CHECK-SAME: -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>) + +# Validate literal/init plumbing: +# CHECK: %[[ZERO:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> +# CHECK: %[[NONE:.*]] = torch.constant.none +# CHECK: %[[CLONE:.*]] = torch.aten.clone %[[ZERO]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + +# CHECK: %[[COND:.*]] = call @while_loop_cond_graph_{{[0-9]+}}(%[[CLONE]] +# CHECK: torch.aten.Bool.Tensor %[[COND]] +# CHECK: %[[MAX_ITER:.*]] = torch.constant.int 9223372036854775807 +# CHECK: torch.prim.Loop %[[MAX_ITER]] + +# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}} +# CHECK: torch.aten.lt.Scalar + +# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}} +# CHECK: torch.aten.add.Scalar +# CHECK: torch.aten.mul.Tensor +def test_while_loop_two_returns(): + class M(nn.Module): + def forward(self, x): + # Simple while_loop that carries a scalar and a tensor. 
+ def body(i, x): + return i + 1, x * x + i0 = torch.tensor(0) + from torch._higher_order_ops.while_loop import while_loop + + out_i, out_x = while_loop( + lambda i, x: i < 3, body, (i0, x) + ) + return out_i, out_x + + # Export -> import to Torch-MLIR + m = fx.export_and_import(M(), torch.randn(4, 4), func_name="test_while_loop_two_returns") + print(m) @run # CHECK-LABEL: test_stack_trace @@ -229,4 +267,4 @@ def foo(x, y): y = torch.randn(128, 128) m = fx.export_and_import(Basic(), x, y, func_name="test_stack_trace") mlir_asm = m.operation.get_asm(enable_debug_info=True) - print(mlir_asm) + print(mlir_asm) \ No newline at end of file From b250583a194be0557d2df0e663d8765a5e129c9b Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 23 Oct 2025 06:26:41 -0700 Subject: [PATCH 02/38] Addressed Comments | Simplified unique child_func_name creation Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 79 +++++++++++++++++-------- test/python/fx_importer/basic_test.py | 10 ++-- 2 files changed, 58 insertions(+), 31 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 2856376bf922..bfb19158979b 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -537,6 +537,8 @@ class FxImporter: "_py_attr_tracker", "_hooks", "symbol_table", + "_graph_module_to_func_name", + "_func_name_counter", ] def __init__( @@ -564,6 +566,8 @@ def __init__( self._hooks = hooks or FxImporterHooks() self.symbol_table = SymbolTable(self._m.operation) self._hooks.prepare_module(self._m.operation) + self._graph_module_to_func_name: Dict[int, str] = {} + self._func_name_counter: int = 0 def _config_check(self): for dname in REQUIRED_DIALCTS: @@ -824,6 +828,15 @@ def import_program( for node, (buffer_value, info) in buffer_bindings.items(): node_importer.lazy_import_buffer(loc, node, buffer_value, info) + # Import all child graph modules recursively for HOPs BEFORE importing nodes + # This is necessary because HOP nodes need to reference these functions. + # Even though import_stateless_graph is deprecated as an entrypoint mechanism, + # HOP operator graphs are stateless graphs with no mutation, and it is correct + # to import them as stateless graphs. + self._import_all_child_modules( + prog, func_name, import_symbolic_shape_expressions + ) + # Import all nodes and return. node_importer.import_nodes( all_producer_nodes.values(), @@ -836,34 +849,40 @@ def import_program( self.symbol_table.insert(func_op) - # Import all child graph modules recursively for HOPs - # Even though import_stateless_graph is deprecated as an entrypoint mechanism, - # HOP operator graphs are stateless graphs with no mutation, and it is correct - # to import them as stateless graphs. - self._import_all_child_modules( - prog, - func_name, - import_symbolic_shape_expressions - ) - return func_op def _import_all_child_modules( self, - prog: torch.export.ExportedProgram, + prog_or_module: Union[torch.export.ExportedProgram, GraphModule], parent_name: str, - import_symbolic_shape_expressions: bool = False + import_symbolic_shape_expressions: bool = False, ): """Recursively import all child modules that have graphs. This simple approach imports all submodules recursively, which is sufficient for HOP operations since they only reference existing submodules. 
""" - for child_name, child_module in prog.graph.owning_module.named_children(): - if isinstance(child_module, GraphModule) and hasattr(child_module, 'graph'): - # Generate function name: parent_childname - child_func_name = f"{parent_name}_{child_name}_{id(child_module)}" - + # Get the owning module from either ExportedProgram or GraphModule + if isinstance(prog_or_module, GraphModule): + owning_module = prog_or_module + else: + owning_module = prog_or_module.graph.owning_module + + for child_name, child_module in owning_module.named_children(): + if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"): + # Check if we've already assigned a name to this module + module_id = id(child_module) + # Module already imported, skip it + if module_id in self._graph_module_to_func_name: + continue + # Use the child_name directly - PyTorch already provides unique names + child_func_name = child_name + # Handle collision by adding counter suffix if name already exists + if child_func_name in self._graph_module_to_func_name.values(): + child_func_name = f"{child_name}_{self._func_name_counter}" + self._func_name_counter += 1 + # Store the mapping for future lookups + self._graph_module_to_func_name[module_id] = child_func_name # Import the child as a stateless graph (private function) self.import_stateless_graph( child_module.graph, @@ -874,9 +893,7 @@ def _import_all_child_modules( # Recursively import its children self._import_all_child_modules( - child_module, - child_func_name, - import_symbolic_shape_expressions + child_module, child_func_name, import_symbolic_shape_expressions ) def import_frozen_program( @@ -1015,6 +1032,13 @@ def import_stateless_graph( self._cc, entry_block, ) + + # Import child modules (for HOPs) before importing nodes + if hasattr(g, 'owning_module') and g.owning_module is not None: + self._import_all_child_modules( + g.owning_module, func_name, import_symbolic_shape_expressions + ) + node_importer.import_nodes( g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions ) @@ -1686,17 +1710,20 @@ def _import_hop_while_loop( cond_fn_node, body_fn_node, *carries = node.args # Extract function names from get_attr nodes - # The subgraphs were imported with names like "main_{target}" - assert cond_fn_node.op == "get_attr", f"Expected get_attr for cond_fn, got {cond_fn_node.op}" - assert body_fn_node.op == "get_attr", f"Expected get_attr for body_fn, got {body_fn_node.op}" + assert ( + cond_fn_node.op == "get_attr" + ), f"Expected get_attr for cond_fn, got {cond_fn_node.op}" + assert ( + body_fn_node.op == "get_attr" + ), f"Expected get_attr for body_fn, got {body_fn_node.op}" root_module = node.graph.owning_module cond_fn_module = getattr(root_module, cond_fn_node.target, None) body_fn_module = getattr(root_module, body_fn_node.target, None) # Generate function names with module IDs for uniqueness - cond_fn_name = f"main_{cond_fn_node.target}_{id(cond_fn_module)}" - body_fn_name = f"main_{body_fn_node.target}_{id(body_fn_module)}" + cond_fn_name = self.fx_importer._graph_module_to_func_name.get(id(cond_fn_module)) + body_fn_name = self.fx_importer._graph_module_to_func_name.get(id(body_fn_module)) # Import the carries (loop state variables) carry_values = [] @@ -2713,4 +2740,4 @@ def aten__embedding_bag_forward_only_default(node: torch_fx.Node): def node_canonicalize(node: torch_fx.Node): if node.target in NODE_CANONICALIZE: return NODE_CANONICALIZE[node.target](node) - return node \ No newline at end of file + return node diff --git 
a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index d6df5b04381f..fbfa9df69de0 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -235,13 +235,13 @@ def body(i, x): i0 = torch.tensor(0) from torch._higher_order_ops.while_loop import while_loop - out_i, out_x = while_loop( - lambda i, x: i < 3, body, (i0, x) - ) + out_i, out_x = while_loop(lambda i, x: i < 3, body, (i0, x)) return out_i, out_x # Export -> import to Torch-MLIR - m = fx.export_and_import(M(), torch.randn(4, 4), func_name="test_while_loop_two_returns") + m = fx.export_and_import( + M(), torch.randn(4, 4), func_name="test_while_loop_two_returns" + ) print(m) @run @@ -267,4 +267,4 @@ def foo(x, y): y = torch.randn(128, 128) m = fx.export_and_import(Basic(), x, y, func_name="test_stack_trace") mlir_asm = m.operation.get_asm(enable_debug_info=True) - print(mlir_asm) \ No newline at end of file + print(mlir_asm) From db1e7e9f544edf356d3f5df2bc95a5c0073a10fa Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 24 Oct 2025 01:53:37 -0700 Subject: [PATCH 03/38] Addressed comments Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 129 ++++++++++++++---------- 1 file changed, 75 insertions(+), 54 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index bfb19158979b..198ffa86abef 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -566,7 +566,9 @@ def __init__( self._hooks = hooks or FxImporterHooks() self.symbol_table = SymbolTable(self._m.operation) self._hooks.prepare_module(self._m.operation) + # Used specifically in HOPs to map module IDs to function names self._graph_module_to_func_name: Dict[int, str] = {} + # Handles collision of function names in the same module self._func_name_counter: int = 0 def _config_check(self): @@ -834,7 +836,7 @@ def import_program( # HOP operator graphs are stateless graphs with no mutation, and it is correct # to import them as stateless graphs. self._import_all_child_modules( - prog, func_name, import_symbolic_shape_expressions + prog.graph.owning_module, func_name, import_symbolic_shape_expressions ) # Import all nodes and return. @@ -853,49 +855,27 @@ def import_program( def _import_all_child_modules( self, - prog_or_module: Union[torch.export.ExportedProgram, GraphModule], + module: GraphModule, parent_name: str, import_symbolic_shape_expressions: bool = False, ): - """Recursively import all child modules that have graphs. + """Import all child modules by delegating to import_graph_module. - This simple approach imports all submodules recursively, which is sufficient - for HOP operations since they only reference existing submodules. + This is a thin wrapper that extracts the owning module and delegates to + import_graph_module for each child. + + Note: This only imports children, not the parent module itself. 
""" - # Get the owning module from either ExportedProgram or GraphModule - if isinstance(prog_or_module, GraphModule): - owning_module = prog_or_module - else: - owning_module = prog_or_module.graph.owning_module - for child_name, child_module in owning_module.named_children(): + for child_name, child_module in module.named_children(): if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"): - # Check if we've already assigned a name to this module - module_id = id(child_module) - # Module already imported, skip it - if module_id in self._graph_module_to_func_name: - continue - # Use the child_name directly - PyTorch already provides unique names - child_func_name = child_name - # Handle collision by adding counter suffix if name already exists - if child_func_name in self._graph_module_to_func_name.values(): - child_func_name = f"{child_name}_{self._func_name_counter}" - self._func_name_counter += 1 - # Store the mapping for future lookups - self._graph_module_to_func_name[module_id] = child_func_name - # Import the child as a stateless graph (private function) - self.import_stateless_graph( - child_module.graph, - func_name=child_func_name, + self.import_graph_module( + child_module, + func_name=child_name, func_visibility="private", import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) - # Recursively import its children - self._import_all_child_modules( - child_module, child_func_name, import_symbolic_shape_expressions - ) - def import_frozen_program( self, prog: torch.export.ExportedProgram, @@ -993,13 +973,56 @@ def import_frozen_program( import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) - def import_graph_module(self, gm: GraphModule) -> Operation: + def import_graph_module( + self, + gm: GraphModule, + *, + func_name: str = "main", + func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, + ) -> Operation: """Low-level import of a GraphModule assuming that it has been functionalized. + This method recursively imports all child GraphModules first, then imports + the provided GraphModule itself. This ensures that any higher-order operations + that reference child modules will find them already imported. + TODO: This mechanism is deprecated by the `import_program` entry-point and it should be removed when no longer required for backwards compatibility. + + Note: This method should only be used for HOPs. 
""" - return self.import_stateless_graph(gm.graph) + # Store the mapping for this module itself (HOPs will need to look this up) + module_id = id(gm) + if module_id not in self._graph_module_to_func_name: + # Ensure the func_name is unique + final_func_name = func_name + if func_name in self._graph_module_to_func_name.values(): + final_func_name = f"{func_name}_{self._func_name_counter}" + self._func_name_counter += 1 + self._graph_module_to_func_name[module_id] = final_func_name + else: + # Module already imported, use existing name + final_func_name = self._graph_module_to_func_name[module_id] + + # First, recursively import all child modules + for child_name, child_module in gm.named_children(): + if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"): + # Recursively import this child (which will handle its own mapping) + self.import_graph_module( + child_module, + func_name=child_name, + func_visibility="private", + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) + + # Then import this module's own graph + return self.import_stateless_graph( + gm.graph, + func_name=final_func_name, + func_visibility=func_visibility, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) def import_stateless_graph( self, @@ -1033,11 +1056,9 @@ def import_stateless_graph( entry_block, ) - # Import child modules (for HOPs) before importing nodes - if hasattr(g, 'owning_module') and g.owning_module is not None: - self._import_all_child_modules( - g.owning_module, func_name, import_symbolic_shape_expressions - ) + # Note: Child module importing is handled by import_graph_module, which is + # the recommended entry point. This method is deprecated and should only be + # used for stateless graphs that truly have no child modules. node_importer.import_nodes( g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions @@ -1068,10 +1089,11 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: # for single returns it might be a single Node. output_arg = node.args[0] # Handle both single Node and tuple/list of Nodes - if isinstance(output_arg, (list, tuple)): - result_nodes = output_arg - else: - result_nodes = [output_arg] + result_nodes = ( + output_arg + if isinstance(output_arg, (list, tuple)) + else [output_arg] + ) for result_node in result_nodes: if result_node is None: @@ -1586,11 +1608,11 @@ def import_nodes( # results. 
output_arg = node.args[0] # Handle both single Node and tuple/list of Nodes - if isinstance(output_arg, (list, tuple)): - result_nodes = output_arg - else: - result_nodes = [output_arg] - + result_nodes = ( + output_arg + if isinstance(output_arg, (list, tuple)) + else [output_arg] + ) operands = [self._import_argument(loc, arg) for arg in result_nodes] func_dialect.ReturnOp(operands, loc=loc) @@ -1722,8 +1744,8 @@ def _import_hop_while_loop( body_fn_module = getattr(root_module, body_fn_node.target, None) # Generate function names with module IDs for uniqueness - cond_fn_name = self.fx_importer._graph_module_to_func_name.get(id(cond_fn_module)) - body_fn_name = self.fx_importer._graph_module_to_func_name.get(id(body_fn_module)) + cond_fn_name = self.fx_importer._graph_module_to_func_name[id(cond_fn_module)] + body_fn_name = self.fx_importer._graph_module_to_func_name[id(body_fn_module)] # Import the carries (loop state variables) carry_values = [] @@ -1765,7 +1787,7 @@ def _import_hop_while_loop( with loc: max_iter = _make_constant_op( "torch.constant.int", - self._cc.integer_attr(9223372036854775807, 64), + torch.iinfo(torch.int64).max, self._cc.torch_int_type, ) @@ -2153,7 +2175,7 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: def _import_tuple_argument( self, loc: Location, arg: tuple, expected_jit_type - ) -> List[Value]: + ) -> list[Value]: """Import a tuple argument by importing each element separately.""" # For tuples in while_loop carries, treat each element as a separate argument return [self._import_argument(loc, elem, expected_jit_type) for elem in arg] @@ -2268,7 +2290,6 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node): # NOTE: the length of the list must be knowable at compile time. if ref_node not in self._unpack_list_values: node_result = self.resolve_node_value(ref_node, 0) - node_val = ref_node.meta.get("val") if str(node_result.type) in TORCH_LIST_TYPES: result_types = [ From d9646c6108cdb2517e158cc7fb42046a2f035070 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 24 Oct 2025 01:55:27 -0700 Subject: [PATCH 04/38] Formatting Signed-off-by: Keshav Vinayak Jha --- test/python/fx_importer/basic_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index fbfa9df69de0..7f52be78a433 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -205,6 +205,7 @@ def forward(self): ) print(m) + @run # CHECK-LABEL: test_while_loop_two_returns # CHECK: func.func @test_while_loop_two_returns @@ -223,6 +224,7 @@ def forward(self): # CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}} # CHECK: torch.aten.lt.Scalar + # CHECK: func.func private @while_loop_body_graph_{{[0-9]+}} # CHECK: torch.aten.add.Scalar # CHECK: torch.aten.mul.Tensor @@ -232,6 +234,7 @@ def forward(self, x): # Simple while_loop that carries a scalar and a tensor. 
def body(i, x): return i + 1, x * x + i0 = torch.tensor(0) from torch._higher_order_ops.while_loop import while_loop @@ -244,6 +247,7 @@ def body(i, x): ) print(m) + @run # CHECK-LABEL: test_stack_trace # CHECK: #loc[[LOC1:.+]] = loc( From cc03291c9e333e732d9cc414f534badc139463e5 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 24 Oct 2025 02:16:26 -0700 Subject: [PATCH 05/38] Added children module imports to import_frozen_program flow Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 198ffa86abef..5dbe91da6c26 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -966,6 +966,14 @@ def import_frozen_program( node.replace_all_uses_with(replacement) g.erase_node(node) + # Import child modules for HOPs before importing the main graph + # This ensures that any higher-order operations (like while_loop) can + # reference the already-imported child module functions + if hasattr(g, "owning_module") and g.owning_module is not None: + self._import_all_child_modules( + g.owning_module, func_name, import_symbolic_shape_expressions + ) + return self.import_stateless_graph( g, func_name=func_name, From 6a70e1c2fc06ae289391c36dc63d7e1fbbbef3ab Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 24 Oct 2025 02:29:02 -0700 Subject: [PATCH 06/38] Formatting and reordered CHECKs Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 4 ++-- test/python/fx_importer/basic_test.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 5dbe91da6c26..c5b9ffdad851 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1735,7 +1735,7 @@ def _import_hop_while_loop( Emits torch.prim.Loop with proper control flow structure. """ - # while_loop HOP args: (cond_fn, body_fn, carries...) + # while_loop HOP args: (cond_fn, body_fn, car`ries...) 
# Unpack the first two args and the rest as carries cond_fn_node, body_fn_node, *carries = node.args @@ -1795,7 +1795,7 @@ def _import_hop_while_loop( with loc: max_iter = _make_constant_op( "torch.constant.int", - torch.iinfo(torch.int64).max, + self._cc.integer_attr(torch.iinfo(torch.int64).max, 64), self._cc.torch_int_type, ) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 7f52be78a433..3e51152bc46c 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -208,26 +208,23 @@ def forward(self): @run # CHECK-LABEL: test_while_loop_two_returns +# Check that helper functions are emitted first +# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}} +# CHECK: torch.aten.lt.Scalar +# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}} +# CHECK: torch.aten.add.Scalar +# CHECK: torch.aten.mul.Tensor +# Then check the main function # CHECK: func.func @test_while_loop_two_returns # CHECK-SAME: -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>) - # Validate literal/init plumbing: # CHECK: %[[ZERO:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> # CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[CLONE:.*]] = torch.aten.clone %[[ZERO]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - # CHECK: %[[COND:.*]] = call @while_loop_cond_graph_{{[0-9]+}}(%[[CLONE]] # CHECK: torch.aten.Bool.Tensor %[[COND]] # CHECK: %[[MAX_ITER:.*]] = torch.constant.int 9223372036854775807 # CHECK: torch.prim.Loop %[[MAX_ITER]] - -# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}} -# CHECK: torch.aten.lt.Scalar - - -# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}} -# CHECK: torch.aten.add.Scalar -# CHECK: torch.aten.mul.Tensor def test_while_loop_two_returns(): class M(nn.Module): def forward(self, x): From 85e3acd790cd3837645ce06fc6a9ca68e279c0fe Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 24 Oct 2025 09:08:50 -0700 Subject: [PATCH 07/38] =?UTF-8?q?Changes=20done=20to=20TorchToScf:=20Chang?= =?UTF-8?q?e=201:=20Converts=20builtin=20tensors=20=E2=86=92=20Torch=20ten?= =?UTF-8?q?sors=20when=20entering=20the=20loop=20body=20Change=202:=20Ensu?= =?UTF-8?q?res=20Torch=20tensors=20=E2=86=92=20builtin=20tensors=20when=20?= =?UTF-8?q?yielding=20back=20to=20the=20loop=20condition=20Without=20these?= =?UTF-8?q?=20fixes,=20the=20conversion=20would=20fail=20when=20while=20lo?= =?UTF-8?q?ops=20carry=20tensor=20values?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also modified basic_test.py FILECHECK statements. 
Signed-off-by: Keshav Vinayak Jha --- lib/Conversion/TorchToSCF/TorchToSCF.cpp | 12 ++++-------- test/python/fx_importer/basic_test.py | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 27e0a61f4b31..6f970de06592 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -150,6 +150,10 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { targetType = Torch::IntType::get(op->getContext()); torchArg = typeConverter->materializeSourceConversion( rewriter, scfWhileOp.getLoc(), targetType, {to}); + } else if (auto tty = dyn_cast(targetType)) { + targetType = op.getIterArgsInit()[barg.index()].getType(); + torchArg = typeConverter->materializeSourceConversion( + rewriter, scfWhileOp.getLoc(), targetType, {to}); } if (!torchArg) return rewriter.notifyMatchFailure(op, @@ -173,14 +177,6 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { "unsupported type of the operand"); loopConditionIterArgs.push_back(shouldContinue); for (auto torchArg : primLoopConditionOp.getIterArgs()) { - Type torchType = torchArg.getType(); - - // If the argument is a torch tensor, directly add it in the list of - // iter args. - if (isa(torchType)) { - loopConditionIterArgs.push_back(torchArg); - continue; - } Value arg = typeConverter->materializeTargetConversion( rewriter, scfWhileOp->getLoc(), typeConverter->convertType(torchArg.getType()), {torchArg}); diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 3e51152bc46c..e490a8d3636c 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -209,22 +209,28 @@ def forward(self): @run # CHECK-LABEL: test_while_loop_two_returns # Check that helper functions are emitted first -# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}} +# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}}(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[],i1> # CHECK: torch.aten.lt.Scalar -# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}} +# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}}(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[4,4],f32>) -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>) # CHECK: torch.aten.add.Scalar # CHECK: torch.aten.mul.Tensor # Then check the main function -# CHECK: func.func @test_while_loop_two_returns +# CHECK: func.func @test_while_loop_two_returns(%arg0: !torch.vtensor<[4,4],f32>) # CHECK-SAME: -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>) # Validate literal/init plumbing: # CHECK: %[[ZERO:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> # CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[CLONE:.*]] = torch.aten.clone %[[ZERO]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> -# CHECK: %[[COND:.*]] = call @while_loop_cond_graph_{{[0-9]+}}(%[[CLONE]] -# CHECK: torch.aten.Bool.Tensor %[[COND]] +# CHECK: %[[COND:.*]] = call @while_loop_cond_graph_{{[0-9]+}}(%[[CLONE]], %arg0) +# CHECK: %[[BOOL:.*]] = torch.aten.Bool.Tensor %[[COND]] # CHECK: %[[MAX_ITER:.*]] = torch.constant.int 9223372036854775807 -# CHECK: torch.prim.Loop %[[MAX_ITER]] +# CHECK: %[[RESULT:.*]]:2 = torch.prim.Loop %[[MAX_ITER]], %[[BOOL]], init(%[[CLONE]], %arg0) +# CHECK: ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[],si64>, %arg3: 
!torch.vtensor<[4,4],f32>): +# CHECK: %[[BODY_RESULT:.*]]:2 = func.call @while_loop_body_graph_{{[0-9]+}}(%arg2, %arg3) +# CHECK: %[[COND_RESULT:.*]] = func.call @while_loop_cond_graph_{{[0-9]+}}(%[[BODY_RESULT]]#0, %[[BODY_RESULT]]#1) +# CHECK: %[[BOOL_RESULT:.*]] = torch.aten.Bool.Tensor %[[COND_RESULT]] +# CHECK: torch.prim.Loop.condition %[[BOOL_RESULT]], iter(%[[BODY_RESULT]]#0, %[[BODY_RESULT]]#1 : !torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>) +# CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 def test_while_loop_two_returns(): class M(nn.Module): def forward(self, x): From e1ff87daedf67b8cec846a125397865c7b1d4292 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 27 Oct 2025 08:41:02 -0700 Subject: [PATCH 08/38] Added Control flow test Signed-off-by: Keshav Vinayak Jha --- .../test_suite/control_flow.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py index a04114043583..451bb21a17eb 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -78,3 +78,36 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils): x_test = torch.zeros([7, 9]).float() module.forward(x_test) + + +# ============================================================================== + + +class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([7, 9], torch.float32, True), + ] + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + from torch._higher_order_ops.while_loop import while_loop + + def body_fn(i, x): + return i + 1, x + 1 + + i0 = torch.tensor(0) + + out_i, out_x = while_loop(lambda i, x: i < 3, body_fn, (i0, x)) + return out_i, out_x + + +@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule()) +def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils): + x_test = torch.zeros([7, 9]).float() + + module.forward(x_test) From 558c7db96b6a92ab3f6469e30ffedaf896ecf25f Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 28 Oct 2025 00:20:21 -0700 Subject: [PATCH 09/38] Cannot FX trace HOP Signed-off-by: Keshav Vinayak Jha --- .../test_suite/control_flow.py | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py index 451bb21a17eb..a04114043583 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -78,36 +78,3 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils): x_test = torch.zeros([7, 9]).float() module.forward(x_test) - - -# ============================================================================== - - -class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([7, 9], torch.float32, True), - ] - ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - from torch._higher_order_ops.while_loop import while_loop - - def body_fn(i, x): - return i + 1, x + 1 - - i0 = torch.tensor(0) - - out_i, out_x = while_loop(lambda i, x: i < 3, body_fn, (i0, x)) - return out_i, out_x - - -@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule()) 
-def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils): - x_test = torch.zeros([7, 9]).float() - - module.forward(x_test) From 39d5b24e3d1e4384428c608f146348777b611e13 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 28 Oct 2025 00:29:49 -0700 Subject: [PATCH 10/38] Added flex_attention hop function --- python/torch_mlir/extras/fx_importer.py | 215 +++++++++++++++++++++++- 1 file changed, 207 insertions(+), 8 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c5b9ffdad851..c6d70cd0b71c 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1771,7 +1771,6 @@ def _import_hop_while_loop( self._multi_result_nodes.add(node) else: result_types = [self._cc.node_val_to_type(node)] - # Call the condition function with initial carries to get initial condition cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool) @@ -1782,7 +1781,6 @@ def _import_hop_while_loop( operands=carry_values, loc=loc, ) - # Convert vtensor to torch.bool bool_conv = Operation.create( name="torch.aten.Bool.Tensor", @@ -1790,7 +1788,6 @@ def _import_hop_while_loop( operands=[initial_cond_call.results[0]], loc=loc, ) - # Create max iterations constant (INT64_MAX) with loc: max_iter = _make_constant_op( @@ -1814,7 +1811,6 @@ def _import_hop_while_loop( block_arg_types = [self._cc.torch_int_type] + result_types with loc: loop_block = Block.create_at_start(loop_region, block_arg_types) - # Inside the loop body, call body function and condition function with InsertionPoint(loop_block): # Call body function with current carry values (skip iteration counter) @@ -1826,7 +1822,6 @@ def _import_hop_while_loop( loc=loc, ) body_results = list(body_results_op.results) - # Call condition function with updated carries cond_result_loop = Operation.create( name="func.call", @@ -1835,7 +1830,6 @@ def _import_hop_while_loop( operands=body_results, loc=loc, ).result - # Convert to bool cond_bool = Operation.create( name="torch.aten.Bool.Tensor", @@ -1843,7 +1837,6 @@ def _import_hop_while_loop( operands=[cond_result_loop], loc=loc, ).result - # Emit loop condition with updated carries Operation.create( name="torch.prim.Loop.condition", @@ -1851,7 +1844,6 @@ def _import_hop_while_loop( operands=[cond_bool] + body_results, loc=loc, ) - # Bind the loop results to the node if len(result_types) > 1: self._multi_result_nodes.add(node) @@ -1917,6 +1909,213 @@ def _import_hop_auto_functionalized( for i, value in enumerate(operation.results): self.bind_node_value(node, value, i + bind_none) + def _import_hop_flex_attention( + self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator + ): + """Imports the torch._higher_order_ops.flex_attention HOP. + + Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...) + The score_mod is a submodule/callable that has been imported as a private function. + The block_mask is a tuple: (kv_num_blocks, kv_indices, ..., mask_mod) + + This creates a call to aten.flex_attention with function symbol references. + """ + # flex_attention HOP args from PyTorch: + # (query, key, value, score_mod, block_mask, scale, kernel_options, return_lse_tuple, ...) 
+ if len(node.args) < 6: + raise ValueError(f"flex_attention expects at least 6 arguments, got {len(node.args)}") + + query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = node.args[:6] + kernel_options = node.args[6] if len(node.args) > 6 else {} + + # Import Q, K, V tensors + query = self._import_argument(loc, query_arg, None) + key = self._import_argument(loc, key_arg, None) + value = self._import_argument(loc, value_arg, None) + + # Handle score_mod: extract function reference from submodule + score_mod_ref = None + if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node): + # score_mod is a GraphModule reference from get_attr + root_module = node.graph.owning_module + if hasattr(score_mod_arg, 'target'): + score_mod_name = score_mod_arg.target + score_mod_module = getattr(root_module, score_mod_name, None) + if score_mod_module is not None: + # The function was imported by _import_all_child_modules with this naming convention + score_mod_func_name = f"main_{score_mod_name}_{id(score_mod_module)}" + score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name) + + # Handle block_mask: extract mask_mod function and tensor components + # block_mask tuple format: (kv_num_blocks, kv_indices, q_num_blocks, q_indices, + # kv_block_size, q_block_size, ..., mask_mod) + mask_mod_ref = None + block_mask_tensors = [] + kv_block_size = None + q_block_size = None + + if block_mask_arg is not None and isinstance(block_mask_arg, tuple): + # Parse the block_mask tuple structure + # First two entries: kv_num_blocks (int), kv_indices (tensor) + # Next two: q_num_blocks (tensor), q_indices (tensor) + # Then: scalar dimensions and the mask_mod function at the end + root_module = node.graph.owning_module + + for i, component in enumerate(block_mask_arg): + if isinstance(component, torch_fx.Node): + # Check if it's a tensor or a submodule reference + if component.op == "get_attr" and hasattr(root_module, component.target): + obj = getattr(root_module, component.target) + # Check if it's a GraphModule (mask_mod) or a tensor + if isinstance(obj, GraphModule): + # This is the mask_mod function + mask_mod_func_name = f"main_{component.target}_{id(obj)}" + mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name) + else: + # It's a tensor (block indices) + block_mask_tensors.append(self._import_argument(loc, component, None)) + else: + # Regular tensor argument + block_mask_tensors.append(self._import_argument(loc, component, None)) + elif isinstance(component, int): + # Scalar dimensions (KV_BLOCK_SIZE, Q_BLOCK_SIZE) + if kv_block_size is None: + kv_block_size = component + elif q_block_size is None: + q_block_size = component + + # Import scale (float or None) + if scale_arg is None: + scale = Operation.create( + "torch.constant.none", + results=[self._cc.torch_none_type], + loc=loc, + ).result + elif isinstance(scale_arg, (int, float)): + with loc: + scale = _make_constant_op( + "torch.constant.float", + FloatAttr.get_f64(float(scale_arg)), + self._cc.torch_float_type, + ).result + else: + scale = self._import_argument(loc, scale_arg, None) + + # Get enable_gqa from kernel_options if present + enable_gqa = False + if isinstance(kernel_options, dict) and "enable_gqa" in kernel_options: + enable_gqa = kernel_options["enable_gqa"] + with loc: + enable_gqa_value = _make_constant_op( + "torch.constant.bool", + self._cc.integer_attr(1 if enable_gqa else 0, 1), + self._cc.torch_bool_type, + ).result + + # Determine result types from node metadata + node_val = node.meta.get("val") + 
if isinstance(node_val, (list, tuple)) and len(node_val) >= 2: + # flex_attention returns (output, logsumexp) + result_types = [self._cc.value_info_to_type(v) for v in node_val] + self._multi_result_nodes.add(node) + else: + # Single output + result_types = [self._cc.node_val_to_type(node)] + + # Build operands list for aten.flex_attention + # We'll pass tensors as operands and functions as attributes + operands = [query, key, value] + + # Add block_mask tensors if present + operands.extend(block_mask_tensors) + + # Add scale and enable_gqa + operands.append(scale) + operands.append(enable_gqa_value) + + # Create aten.flex_attention op directly. + with loc: + return_lse = _make_constant_op( + "torch.constant.bool", + self._cc.integer_attr( + 1 if (getattr(node_val, "return_lse", False) or ( + isinstance(node_val, (list, tuple)) and len(node_val) >= 2 + )) else 0, 1 + ), + self._cc.torch_bool_type, + ).result + + # Build operands for aten.flex_attention + # Note: score_mod and block_mask function references go as ATTRIBUTES, not operands + + # Handle block_mask: wrap tensors in a list construct if present + if block_mask_tensors: + # Wrap block_mask tensors in torch.prim.ListConstruct + block_mask_list = Operation.create( + "torch.prim.ListConstruct", + results=[IrType.parse("!torch.list", context=self._c)], + operands=block_mask_tensors, + loc=loc, + ).result + else: + # No block mask, use None + block_mask_list = Operation.create( + "torch.constant.none", + results=[self._cc.torch_none_type], + loc=loc, + ).result + + flat_operands = [ + query, + key, + value, + # score_mod placeholder (None) + Operation.create( + "torch.constant.none", + results=[self._cc.torch_none_type], + loc=loc, + ).result, + # block_mask as single list operand + block_mask_list, + scale, + enable_gqa_value, + # Kernel options as None + Operation.create( + "torch.constant.none", + results=[self._cc.torch_none_type], + loc=loc, + ).result, + # return_lse + return_lse + ] + + # Build attributes with function references + attributes = {} + if score_mod_ref is not None: + attributes["score_mod_fn"] = score_mod_ref + if mask_mod_ref is not None: + attributes["mask_mod_fn"] = mask_mod_ref + if kv_block_size is not None: + attributes["kv_block_size"] = self._cc.integer_attr(kv_block_size, 64) + if q_block_size is not None: + attributes["q_block_size"] = self._cc.integer_attr(q_block_size, 64) + + operation = Operation.create( + "torch.aten.flex_attention", + results=result_types, + operands=flat_operands, + attributes=attributes if attributes else None, + loc=loc, + ) + + # Bind results + if len(result_types) > 1: + self._multi_result_nodes.add(node) + for i, value in enumerate(operation.results): + self.bind_node_value(node, value, i) + else: + self.bind_node_value(node, operation.results[0]) + def _import_torch_op_overload( self, loc: Location, From dfdca759347375e49ddb328a70ee68c00475b2e0 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 28 Oct 2025 00:32:45 -0700 Subject: [PATCH 11/38] Formatting Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 87 +++++++++++++++---------- 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c6d70cd0b71c..dd60cde7c981 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1913,58 +1913,66 @@ def _import_hop_flex_attention( self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator ): 
"""Imports the torch._higher_order_ops.flex_attention HOP. - + Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...) The score_mod is a submodule/callable that has been imported as a private function. The block_mask is a tuple: (kv_num_blocks, kv_indices, ..., mask_mod) - + This creates a call to aten.flex_attention with function symbol references. """ # flex_attention HOP args from PyTorch: # (query, key, value, score_mod, block_mask, scale, kernel_options, return_lse_tuple, ...) if len(node.args) < 6: - raise ValueError(f"flex_attention expects at least 6 arguments, got {len(node.args)}") - - query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = node.args[:6] + raise ValueError( + f"flex_attention expects at least 6 arguments, got {len(node.args)}" + ) + + query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = ( + node.args[:6] + ) kernel_options = node.args[6] if len(node.args) > 6 else {} - + # Import Q, K, V tensors query = self._import_argument(loc, query_arg, None) key = self._import_argument(loc, key_arg, None) value = self._import_argument(loc, value_arg, None) - + # Handle score_mod: extract function reference from submodule score_mod_ref = None if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node): # score_mod is a GraphModule reference from get_attr root_module = node.graph.owning_module - if hasattr(score_mod_arg, 'target'): + if hasattr(score_mod_arg, "target"): score_mod_name = score_mod_arg.target score_mod_module = getattr(root_module, score_mod_name, None) if score_mod_module is not None: # The function was imported by _import_all_child_modules with this naming convention - score_mod_func_name = f"main_{score_mod_name}_{id(score_mod_module)}" + score_mod_func_name = ( + f"main_{score_mod_name}_{id(score_mod_module)}" + ) score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name) - + # Handle block_mask: extract mask_mod function and tensor components - # block_mask tuple format: (kv_num_blocks, kv_indices, q_num_blocks, q_indices, + # block_mask tuple format: (kv_num_blocks, kv_indices, q_num_blocks, q_indices, # kv_block_size, q_block_size, ..., mask_mod) mask_mod_ref = None block_mask_tensors = [] kv_block_size = None q_block_size = None - + if block_mask_arg is not None and isinstance(block_mask_arg, tuple): # Parse the block_mask tuple structure # First two entries: kv_num_blocks (int), kv_indices (tensor) - # Next two: q_num_blocks (tensor), q_indices (tensor) + # Next two: q_num_blocks (tensor), q_indices (tensor) # Then: scalar dimensions and the mask_mod function at the end root_module = node.graph.owning_module - + for i, component in enumerate(block_mask_arg): if isinstance(component, torch_fx.Node): # Check if it's a tensor or a submodule reference - if component.op == "get_attr" and hasattr(root_module, component.target): + if component.op == "get_attr" and hasattr( + root_module, component.target + ): obj = getattr(root_module, component.target) # Check if it's a GraphModule (mask_mod) or a tensor if isinstance(obj, GraphModule): @@ -1973,17 +1981,21 @@ def _import_hop_flex_attention( mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name) else: # It's a tensor (block indices) - block_mask_tensors.append(self._import_argument(loc, component, None)) + block_mask_tensors.append( + self._import_argument(loc, component, None) + ) else: # Regular tensor argument - block_mask_tensors.append(self._import_argument(loc, component, None)) + block_mask_tensors.append( + 
self._import_argument(loc, component, None) + ) elif isinstance(component, int): # Scalar dimensions (KV_BLOCK_SIZE, Q_BLOCK_SIZE) if kv_block_size is None: kv_block_size = component elif q_block_size is None: q_block_size = component - + # Import scale (float or None) if scale_arg is None: scale = Operation.create( @@ -2000,7 +2012,7 @@ def _import_hop_flex_attention( ).result else: scale = self._import_argument(loc, scale_arg, None) - + # Get enable_gqa from kernel_options if present enable_gqa = False if isinstance(kernel_options, dict) and "enable_gqa" in kernel_options: @@ -2011,7 +2023,7 @@ def _import_hop_flex_attention( self._cc.integer_attr(1 if enable_gqa else 0, 1), self._cc.torch_bool_type, ).result - + # Determine result types from node metadata node_val = node.meta.get("val") if isinstance(node_val, (list, tuple)) and len(node_val) >= 2: @@ -2021,33 +2033,42 @@ def _import_hop_flex_attention( else: # Single output result_types = [self._cc.node_val_to_type(node)] - + # Build operands list for aten.flex_attention # We'll pass tensors as operands and functions as attributes operands = [query, key, value] - + # Add block_mask tensors if present operands.extend(block_mask_tensors) - + # Add scale and enable_gqa operands.append(scale) operands.append(enable_gqa_value) - + # Create aten.flex_attention op directly. with loc: return_lse = _make_constant_op( "torch.constant.bool", self._cc.integer_attr( - 1 if (getattr(node_val, "return_lse", False) or ( - isinstance(node_val, (list, tuple)) and len(node_val) >= 2 - )) else 0, 1 + ( + 1 + if ( + getattr(node_val, "return_lse", False) + or ( + isinstance(node_val, (list, tuple)) + and len(node_val) >= 2 + ) + ) + else 0 + ), + 1, ), self._cc.torch_bool_type, ).result # Build operands for aten.flex_attention # Note: score_mod and block_mask function references go as ATTRIBUTES, not operands - + # Handle block_mask: wrap tensors in a list construct if present if block_mask_tensors: # Wrap block_mask tensors in torch.prim.ListConstruct @@ -2064,7 +2085,7 @@ def _import_hop_flex_attention( results=[self._cc.torch_none_type], loc=loc, ).result - + flat_operands = [ query, key, @@ -2086,9 +2107,9 @@ def _import_hop_flex_attention( loc=loc, ).result, # return_lse - return_lse + return_lse, ] - + # Build attributes with function references attributes = {} if score_mod_ref is not None: @@ -2099,7 +2120,7 @@ def _import_hop_flex_attention( attributes["kv_block_size"] = self._cc.integer_attr(kv_block_size, 64) if q_block_size is not None: attributes["q_block_size"] = self._cc.integer_attr(q_block_size, 64) - + operation = Operation.create( "torch.aten.flex_attention", results=result_types, @@ -2107,7 +2128,7 @@ def _import_hop_flex_attention( attributes=attributes if attributes else None, loc=loc, ) - + # Bind results if len(result_types) > 1: self._multi_result_nodes.add(node) From 6178d07275a4bba52c2a51e05a9ba8be9cf9ea14 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 28 Oct 2025 00:35:50 -0700 Subject: [PATCH 12/38] Fixed merge newline removals Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index dd60cde7c981..d7cbdc7453ff 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1771,6 +1771,7 @@ def _import_hop_while_loop( self._multi_result_nodes.add(node) else: result_types = [self._cc.node_val_to_type(node)] + # 
Call the condition function with initial carries to get initial condition cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool) @@ -1781,6 +1782,7 @@ def _import_hop_while_loop( operands=carry_values, loc=loc, ) + # Convert vtensor to torch.bool bool_conv = Operation.create( name="torch.aten.Bool.Tensor", @@ -1788,6 +1790,7 @@ def _import_hop_while_loop( operands=[initial_cond_call.results[0]], loc=loc, ) + # Create max iterations constant (INT64_MAX) with loc: max_iter = _make_constant_op( @@ -1811,6 +1814,7 @@ def _import_hop_while_loop( block_arg_types = [self._cc.torch_int_type] + result_types with loc: loop_block = Block.create_at_start(loop_region, block_arg_types) + # Inside the loop body, call body function and condition function with InsertionPoint(loop_block): # Call body function with current carry values (skip iteration counter) @@ -1822,6 +1826,7 @@ def _import_hop_while_loop( loc=loc, ) body_results = list(body_results_op.results) + # Call condition function with updated carries cond_result_loop = Operation.create( name="func.call", @@ -1830,6 +1835,7 @@ def _import_hop_while_loop( operands=body_results, loc=loc, ).result + # Convert to bool cond_bool = Operation.create( name="torch.aten.Bool.Tensor", @@ -1837,6 +1843,7 @@ def _import_hop_while_loop( operands=[cond_result_loop], loc=loc, ).result + # Emit loop condition with updated carries Operation.create( name="torch.prim.Loop.condition", @@ -1844,6 +1851,7 @@ def _import_hop_while_loop( operands=[cond_bool] + body_results, loc=loc, ) + # Bind the loop results to the node if len(result_types) > 1: self._multi_result_nodes.add(node) From 52f1fbc4984033ee228e4fff0a325b3b925e8ea7 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 29 Oct 2025 04:47:57 -0700 Subject: [PATCH 13/38] Added AtenFluxAttentionOp Signed-off-by: Keshav Vinayak Jha --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4ad03f54313f..f7805dcbf63b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16194,6 +16194,62 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ let hasFolder = 1; } + +def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::flex_attention : (Tensor, Tensor, Tensor, Any?, Any?, float?, bool, Any?, bool) -> (Tensor, Tensor)`"; + let description = [{ + Flexible attention operator that supports custom score modification and masking. + + Args: + query: Query tensor [B, H, M, E] + key: Key tensor [B, H, N, E] + value: Value tensor [B, H, N, Ev] + score_mod: Optional callable to modify attention scores (represented as None or opaque type) + block_mask: Optional BlockMask tuple for sparse attention patterns + scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim)) + enable_gqa: bool for grouped query attention support + kernel_options: Optional dict of kernel configuration options + return_lse: bool to return log-sum-exp values + + Returns: + - If return_lse=False: Just the output tensor [B, H, M, Ev] + - If return_lse=True: Tuple of (output [B, H, M, Ev], logsumexp [B, H, M]) + + Note: score_mod and block_mask are higher-order/complex types in PyTorch. 
+ For MLIR representation, score_mod is represented as None (identity) or an opaque type, + and block_mask is represented as None or a tuple/list of tensors containing the block indices. + }]; + let arguments = (ins + AnyTorchTensorType:$query, + AnyTorchTensorType:$key, + AnyTorchTensorType:$value, + AnyType:$score_mod, + AnyType:$block_mask, + AnyTorchOptionalFloatType:$scale, + Torch_BoolType:$enable_gqa, + AnyType:$kernel_options, + Torch_BoolType:$return_lse + ); + let results = (outs + AnyTorchTensorType:$output, + AnyTorchOptionalTensorType:$logsumexp + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 2); + } + void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 2); + } + }]; +} + + def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [ AllowsTypeRefinement, HasValueSemantics, From a56433a17b5b2c2bdd6fc936940c6d962a1fec56 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 30 Oct 2025 00:19:31 -0700 Subject: [PATCH 14/38] Added changes for correct functional references Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 35 +++++++------------------ 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index d7cbdc7453ff..873023493e63 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1954,10 +1954,7 @@ def _import_hop_flex_attention( score_mod_name = score_mod_arg.target score_mod_module = getattr(root_module, score_mod_name, None) if score_mod_module is not None: - # The function was imported by _import_all_child_modules with this naming convention - score_mod_func_name = ( - f"main_{score_mod_name}_{id(score_mod_module)}" - ) + score_mod_func_name = score_mod_name score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name) # Handle block_mask: extract mask_mod function and tensor components @@ -1985,7 +1982,7 @@ def _import_hop_flex_attention( # Check if it's a GraphModule (mask_mod) or a tensor if isinstance(obj, GraphModule): # This is the mask_mod function - mask_mod_func_name = f"main_{component.target}_{id(obj)}" + mask_mod_func_name = component.target mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name) else: # It's a tensor (block indices) @@ -2042,18 +2039,6 @@ def _import_hop_flex_attention( # Single output result_types = [self._cc.node_val_to_type(node)] - # Build operands list for aten.flex_attention - # We'll pass tensors as operands and functions as attributes - operands = [query, key, value] - - # Add block_mask tensors if present - operands.extend(block_mask_tensors) - - # Add scale and enable_gqa - operands.append(scale) - operands.append(enable_gqa_value) - - # Create aten.flex_attention op directly. 
with loc: return_lse = _make_constant_op( "torch.constant.bool", @@ -2119,15 +2104,13 @@ def _import_hop_flex_attention( ] # Build attributes with function references - attributes = {} - if score_mod_ref is not None: - attributes["score_mod_fn"] = score_mod_ref - if mask_mod_ref is not None: - attributes["mask_mod_fn"] = mask_mod_ref - if kv_block_size is not None: - attributes["kv_block_size"] = self._cc.integer_attr(kv_block_size, 64) - if q_block_size is not None: - attributes["q_block_size"] = self._cc.integer_attr(q_block_size, 64) + attributes = { + "score_mod_fn": score_mod_ref, + "mask_mod_fn": mask_mod_ref, + "kv_block_size": self._cc.integer_attr(kv_block_size, 64), + "q_block_size": self._cc.integer_attr(q_block_size, 64), + } + attributes = {k: v for k, v in attributes.items() if v is not None} operation = Operation.create( "torch.aten.flex_attention", From b0e8585e4e1f19e174bff1c19976acb7207f11f2 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 4 Nov 2025 11:33:39 -0800 Subject: [PATCH 15/38] QOL changes: 1. Better documentation for AtenFlexAttentionOp 2. Function referece added as attributes to aten.flex_attention 3. Updates to _import_hop_flex_attention reflecting latest changes of module import. 4. Removed discardable attributes; scored_mod_fn and mask_mod_fn added as optionalAttr Signed-off-by: Keshav Vinayak Jha --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 48 +++-- python/torch_mlir/extras/fx_importer.py | 169 +++++++----------- 2 files changed, 84 insertions(+), 133 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f7805dcbf63b..a3ca0eabe2a0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16194,62 +16194,60 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ let hasFolder = 1; } - def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::flex_attention : (Tensor, Tensor, Tensor, Any?, Any?, float?, bool, Any?, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::flex_attention`"; let description = [{ - Flexible attention operator that supports custom score modification and masking. - + FlexAttention operation with flexible block-sparse attention patterns. + Args: - query: Query tensor [B, H, M, E] - key: Key tensor [B, H, N, E] + query: Query tensor [B, H, M, K] + key: Key tensor [B, H, N, K] value: Value tensor [B, H, N, Ev] - score_mod: Optional callable to modify attention scores (represented as None or opaque type) - block_mask: Optional BlockMask tuple for sparse attention patterns scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim)) - enable_gqa: bool for grouped query attention support - kernel_options: Optional dict of kernel configuration options - return_lse: bool to return log-sum-exp values + return_lse: Bool to return log-sum-exp values - Returns: - - If return_lse=False: Just the output tensor [B, H, M, Ev] - - If return_lse=True: Tuple of (output [B, H, M, Ev], logsumexp [B, H, M]) + Attributes: + score_mod_fn: Optional function symbol reference for score modification + mask_mod_fn: Optional function symbol reference for mask modification + + # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.) 
- Note: score_mod and block_mask are higher-order/complex types in PyTorch. - For MLIR representation, score_mod is represented as None (identity) or an opaque type, - and block_mask is represented as None or a tuple/list of tensors containing the block indices. + Returns: + output: Result tensor [B, H, M, Ev] + logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True) }]; + let arguments = (ins AnyTorchTensorType:$query, AnyTorchTensorType:$key, AnyTorchTensorType:$value, - AnyType:$score_mod, - AnyType:$block_mask, AnyTorchOptionalFloatType:$scale, - Torch_BoolType:$enable_gqa, - AnyType:$kernel_options, - Torch_BoolType:$return_lse + Torch_BoolType:$enable_gqa + Torch_BoolType:$return_lse, + OptionalAttr:$score_mod_fn, + OptionalAttr:$mask_mod_fn ); + let results = (outs AnyTorchTensorType:$output, AnyTorchOptionalTensorType:$logsumexp ); + let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 9, 2); + return parseDefaultTorchOp(parser, result, 5, 2); } void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 9, 2); + printDefaultTorchOp(printer, *this, 5, 2); } }]; } - def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 873023493e63..8faebc13b321 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1922,14 +1922,19 @@ def _import_hop_flex_attention( ): """Imports the torch._higher_order_ops.flex_attention HOP. - Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...) - The score_mod is a submodule/callable that has been imported as a private function. - The block_mask is a tuple: (kv_num_blocks, kv_indices, ..., mask_mod) - - This creates a call to aten.flex_attention with function symbol references. + Args format: (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, ...) + - query, key, value: Attention input tensors + - score_mod: Optional submodule/callable for score modification (imported as function) + - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors + - scale: Optional float for attention score scaling + - enable_gqa: Boolean for grouped query attention support (TODO: NYI) + - kernel_options: Dict of performance tuning options (TODO: NYI) + + This creates a call to aten.flex_attention with function symbol references for + score_mod and mask_mod. """ # flex_attention HOP args from PyTorch: - # (query, key, value, score_mod, block_mask, scale, kernel_options, return_lse_tuple, ...) + # (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, return_lse_tuple, ...) 
if len(node.args) < 6: raise ValueError( f"flex_attention expects at least 6 arguments, got {len(node.args)}" @@ -1938,68 +1943,51 @@ def _import_hop_flex_attention( query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = ( node.args[:6] ) - kernel_options = node.args[6] if len(node.args) > 6 else {} + + # TODO: Add support for enable_gqa (grouped query attention) + # This is a boolean flag that enables GQA optimization + enable_gqa = node.args[6] if len(node.args) > 6 else False + + # TODO: Add support for kernel_options (performance tuning parameters) + # This is a dict containing options like block sizes, num_warps, etc. + kernel_options = node.args[7] if len(node.args) > 7 else {} # Import Q, K, V tensors query = self._import_argument(loc, query_arg, None) key = self._import_argument(loc, key_arg, None) value = self._import_argument(loc, value_arg, None) - # Handle score_mod: extract function reference from submodule score_mod_ref = None if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node): - # score_mod is a GraphModule reference from get_attr + assert ( + score_mod_arg.op == "get_attr" + ), f"Expected get_attr for score_mod, got {score_mod_arg.op}" root_module = node.graph.owning_module - if hasattr(score_mod_arg, "target"): - score_mod_name = score_mod_arg.target - score_mod_module = getattr(root_module, score_mod_name, None) - if score_mod_module is not None: - score_mod_func_name = score_mod_name - score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name) - - # Handle block_mask: extract mask_mod function and tensor components - # block_mask tuple format: (kv_num_blocks, kv_indices, q_num_blocks, q_indices, - # kv_block_size, q_block_size, ..., mask_mod) - mask_mod_ref = None - block_mask_tensors = [] - kv_block_size = None - q_block_size = None + score_mod_module = getattr(root_module, score_mod_arg.target, None) + if score_mod_module is not None: + score_mod_func_name = self.fx_importer._graph_module_to_func_name[ + id(score_mod_module) + ] + score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name) + # Handle block_mask: extract only mask_mod function reference + # Note: BlockMask contains runtime tensors (kv_num_blocks, kv_indices, etc.) + # that are materialized by evaluating mask_mod(b, h, q_idx, kv_idx). 
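+        # Illustrative sketch only (the helper name `causal` is hypothetical,
+        # not part of this importer): a user-level mask_mod such as
+        #     def causal(b, h, q_idx, kv_idx):
+        #         return q_idx >= kv_idx
+        # is captured by export as a get_attr submodule, while the remaining
+        # BlockMask tuple entries (kv_num_blocks, kv_indices, q_num_blocks,
+        # q_indices, block sizes, ...) hold tensors materialized from it; only
+        # the function symbol reference is attached to the emitted op.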
+ mask_mod_ref = None if block_mask_arg is not None and isinstance(block_mask_arg, tuple): - # Parse the block_mask tuple structure - # First two entries: kv_num_blocks (int), kv_indices (tensor) - # Next two: q_num_blocks (tensor), q_indices (tensor) - # Then: scalar dimensions and the mask_mod function at the end root_module = node.graph.owning_module - - for i, component in enumerate(block_mask_arg): - if isinstance(component, torch_fx.Node): - # Check if it's a tensor or a submodule reference - if component.op == "get_attr" and hasattr( - root_module, component.target - ): - obj = getattr(root_module, component.target) - # Check if it's a GraphModule (mask_mod) or a tensor - if isinstance(obj, GraphModule): - # This is the mask_mod function - mask_mod_func_name = component.target - mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name) - else: - # It's a tensor (block indices) - block_mask_tensors.append( - self._import_argument(loc, component, None) - ) - else: - # Regular tensor argument - block_mask_tensors.append( - self._import_argument(loc, component, None) - ) - elif isinstance(component, int): - # Scalar dimensions (KV_BLOCK_SIZE, Q_BLOCK_SIZE) - if kv_block_size is None: - kv_block_size = component - elif q_block_size is None: - q_block_size = component + # The mask_mod function is the last element in the BlockMask tuple + mask_mod_arg = block_mask_arg[-1] + if mask_mod_arg is not None and isinstance(mask_mod_arg, torch_fx.Node): + assert ( + mask_mod_arg.op == "get_attr" + ), f"Expected get_attr for mask_mod, got {mask_mod_arg.op}" + mask_mod_module = getattr(root_module, mask_mod_arg.target, None) + if mask_mod_module is not None: + mask_mod_func_name = self.fx_importer._graph_module_to_func_name[ + id(mask_mod_module) + ] + mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name) # Import scale (float or None) if scale_arg is None: @@ -2018,17 +2006,6 @@ def _import_hop_flex_attention( else: scale = self._import_argument(loc, scale_arg, None) - # Get enable_gqa from kernel_options if present - enable_gqa = False - if isinstance(kernel_options, dict) and "enable_gqa" in kernel_options: - enable_gqa = kernel_options["enable_gqa"] - with loc: - enable_gqa_value = _make_constant_op( - "torch.constant.bool", - self._cc.integer_attr(1 if enable_gqa else 0, 1), - self._cc.torch_bool_type, - ).result - # Determine result types from node metadata node_val = node.meta.get("val") if isinstance(node_val, (list, tuple)) and len(node_val) >= 2: @@ -2039,6 +2016,13 @@ def _import_hop_flex_attention( # Single output result_types = [self._cc.node_val_to_type(node)] + with loc: + enable_gqa_value = _make_constant_op( + "torch.constant.bool", + self._cc.integer_attr(1 if enable_gqa else 0, 1), + self._cc.torch_bool_type, + ).result + with loc: return_lse = _make_constant_op( "torch.constant.bool", @@ -2059,58 +2043,27 @@ def _import_hop_flex_attention( self._cc.torch_bool_type, ).result - # Build operands for aten.flex_attention - # Note: score_mod and block_mask function references go as ATTRIBUTES, not operands - - # Handle block_mask: wrap tensors in a list construct if present - if block_mask_tensors: - # Wrap block_mask tensors in torch.prim.ListConstruct - block_mask_list = Operation.create( - "torch.prim.ListConstruct", - results=[IrType.parse("!torch.list", context=self._c)], - operands=block_mask_tensors, - loc=loc, - ).result - else: - # No block mask, use None - block_mask_list = Operation.create( - "torch.constant.none", - results=[self._cc.torch_none_type], - loc=loc, - 
).result + # Build operands for aten.flex_attention. + # Op expects exactly 5 operands: query, key, value, scale, return_lse. + # Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands. + # Note: block_mask tensors are handled by mask_mod_fn, not passed as operands. flat_operands = [ query, key, value, - # score_mod placeholder (None) - Operation.create( - "torch.constant.none", - results=[self._cc.torch_none_type], - loc=loc, - ).result, - # block_mask as single list operand - block_mask_list, scale, enable_gqa_value, - # Kernel options as None - Operation.create( - "torch.constant.none", - results=[self._cc.torch_none_type], - loc=loc, - ).result, - # return_lse return_lse, ] # Build attributes with function references - attributes = { - "score_mod_fn": score_mod_ref, - "mask_mod_fn": mask_mod_ref, - "kv_block_size": self._cc.integer_attr(kv_block_size, 64), - "q_block_size": self._cc.integer_attr(q_block_size, 64), - } - attributes = {k: v for k, v in attributes.items() if v is not None} + # Only include attributes if they're not None (OptionalAttr in TableGen) + attributes = {} + if score_mod_ref is not None: + attributes["score_mod_fn"] = score_mod_ref + if mask_mod_ref is not None: + attributes["mask_mod_fn"] = mask_mod_ref operation = Operation.create( "torch.aten.flex_attention", From 4470978d403e93866e3e221712fee1640187aa8e Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Wed, 5 Nov 2025 01:13:16 +0530 Subject: [PATCH 16/38] Update fx_importer.py to remove deprecated note Remove note about method usage for HOPs. --- python/torch_mlir/extras/fx_importer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 78a39235fe98..f109a9f0c79a 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -997,8 +997,6 @@ def import_graph_module( TODO: This mechanism is deprecated by the `import_program` entry-point and it should be removed when no longer required for backwards compatibility. - - Note: This method should only be used for HOPs. """ # Store the mapping for this module itself (HOPs will need to look this up) module_id = id(gm) From 719fe5ac9dde3a89da553d7716a91ec10f20f8ff Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Wed, 5 Nov 2025 01:18:40 +0530 Subject: [PATCH 17/38] Clarify enable_gqa support in fx_importer.py Removed TODO note for grouped query attention support in the docstring and comments. 
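
For reference, a minimal sketch (assuming the public
torch.nn.attention.flex_attention API; shapes and head counts are
illustrative only) of how a caller requests grouped query attention in
the graphs this importer consumes:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    # 8 query heads share 2 key/value heads; enable_gqa broadcasts the
    # KV heads across the query-head groups.
    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 2, 128, 64)
    v = torch.randn(1, 2, 128, 64)
    out = flex_attention(q, k, v, enable_gqa=True)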
--- python/torch_mlir/extras/fx_importer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index f109a9f0c79a..f531ec2f4b42 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1915,7 +1915,7 @@ def _import_hop_flex_attention( - score_mod: Optional submodule/callable for score modification (imported as function) - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors - scale: Optional float for attention score scaling - - enable_gqa: Boolean for grouped query attention support (TODO: NYI) + - enable_gqa: Boolean for grouped query attention support - kernel_options: Dict of performance tuning options (TODO: NYI) This creates a call to aten.flex_attention with function symbol references for @@ -1932,7 +1932,6 @@ def _import_hop_flex_attention( node.args[:6] ) - # TODO: Add support for enable_gqa (grouped query attention) # This is a boolean flag that enables GQA optimization enable_gqa = node.args[6] if len(node.args) > 6 else False From 5e024f6c337a001dd4d80f859d88013e4525eaf3 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Wed, 5 Nov 2025 01:23:48 +0530 Subject: [PATCH 18/38] Fix formatting in GeneratedTorchOps.td --- include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index a3ca0eabe2a0..750226c43c4d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16226,7 +16226,7 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ AnyTorchTensorType:$key, AnyTorchTensorType:$value, AnyTorchOptionalFloatType:$scale, - Torch_BoolType:$enable_gqa + Torch_BoolType:$enable_gqa, Torch_BoolType:$return_lse, OptionalAttr:$score_mod_fn, OptionalAttr:$mask_mod_fn From c78d69944bea2598d3d4f4398988b5fdb230990f Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 6 Nov 2025 05:10:15 -0800 Subject: [PATCH 19/38] return_lse is part of the kernel options Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index f531ec2f4b42..862839df3ff2 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1916,10 +1916,11 @@ def _import_hop_flex_attention( - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors - scale: Optional float for attention score scaling - enable_gqa: Boolean for grouped query attention support - - kernel_options: Dict of performance tuning options (TODO: NYI) + - kernel_options: Dict of performance tuning options: + - return_lse: Boolean for whether to return the log-sum-exp tensor This creates a call to aten.flex_attention with function symbol references for - score_mod and mask_mod. + score_mod and mask_mod. The return_lse flag is extracted from kernel_options. """ # flex_attention HOP args from PyTorch: # (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, return_lse_tuple, ...) 
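
For context, a hedged sketch (public torch.nn.attention.flex_attention API;
tensor shapes are illustrative) of the user-level call whose return_lse
request reaches this importer inside the HOP's kernel_options, as the
docstring above describes:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    q = k = v = torch.randn(1, 2, 64, 16)
    out, lse = flex_attention(q, k, v, return_lse=True)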
@@ -2010,23 +2011,15 @@ def _import_hop_flex_attention( self._cc.torch_bool_type, ).result + # Extract return_lse from kernel_options + return_lse_value = False + if isinstance(kernel_options, dict): + return_lse_value = kernel_options.get("return_lse", False) + with loc: return_lse = _make_constant_op( "torch.constant.bool", - self._cc.integer_attr( - ( - 1 - if ( - getattr(node_val, "return_lse", False) - or ( - isinstance(node_val, (list, tuple)) - and len(node_val) >= 2 - ) - ) - else 0 - ), - 1, - ), + self._cc.integer_attr(1 if return_lse_value else 0, 1), self._cc.torch_bool_type, ).result From da23ec98dcd065aef8a131ce9d673bb299e1bd57 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 6 Nov 2025 19:39:22 -0800 Subject: [PATCH 20/38] Moved op definition to TorchOps.td Signed-off-by: Keshav Vinayak Jha --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 54 ------------------ .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 55 +++++++++++++++++++ 2 files changed, 55 insertions(+), 54 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 750226c43c4d..4ad03f54313f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16194,60 +16194,6 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ let hasFolder = 1; } -def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::flex_attention`"; - let description = [{ - FlexAttention operation with flexible block-sparse attention patterns. - - Args: - query: Query tensor [B, H, M, K] - key: Key tensor [B, H, N, K] - value: Value tensor [B, H, N, Ev] - scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim)) - return_lse: Bool to return log-sum-exp values - - Attributes: - score_mod_fn: Optional function symbol reference for score modification - mask_mod_fn: Optional function symbol reference for mask modification - - # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.) 
- - Returns: - output: Result tensor [B, H, M, Ev] - logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True) - }]; - - let arguments = (ins - AnyTorchTensorType:$query, - AnyTorchTensorType:$key, - AnyTorchTensorType:$value, - AnyTorchOptionalFloatType:$scale, - Torch_BoolType:$enable_gqa, - Torch_BoolType:$return_lse, - OptionalAttr:$score_mod_fn, - OptionalAttr:$mask_mod_fn - ); - - let results = (outs - AnyTorchTensorType:$output, - AnyTorchOptionalTensorType:$logsumexp - ); - - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); - } - void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); - } - }]; -} - def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 1595bf58e410..175efa0c987f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1442,4 +1442,59 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [ let hasCustomAssemblyFormat = 1; } + +def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::flex_attention`"; + let description = [{ + FlexAttention operation with flexible block-sparse attention patterns. + + Args: + query: Query tensor [B, H, M, K] + key: Key tensor [B, H, N, K] + value: Value tensor [B, H, N, Ev] + scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim)) + return_lse: Bool to return log-sum-exp values + + Attributes: + score_mod_fn: Optional function symbol reference for score modification + mask_mod_fn: Optional function symbol reference for mask modification + + # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.) 
+ + Returns: + output: Result tensor [B, H, M, Ev] + logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True) + }]; + + let arguments = (ins + AnyTorchTensorType:$query, + AnyTorchTensorType:$key, + AnyTorchTensorType:$value, + AnyTorchOptionalFloatType:$scale, + Torch_BoolType:$enable_gqa, + Torch_BoolType:$return_lse, + OptionalAttr:$score_mod_fn, + OptionalAttr:$mask_mod_fn + ); + + let results = (outs + AnyTorchTensorType:$output, + AnyTorchOptionalTensorType:$logsumexp + ); + + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + #endif // TORCH_OPS From af594134acac4c583a3de861003069ee6d1f7129 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 6 Nov 2025 19:41:34 -0800 Subject: [PATCH 21/38] Formatting TorchOps Signed-off-by: Keshav Vinayak Jha --- include/torch-mlir/Dialect/Torch/IR/TorchOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 175efa0c987f..224db711422a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1454,7 +1454,7 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ Args: query: Query tensor [B, H, M, K] - key: Key tensor [B, H, N, K] + key: Key tensor [B, H, N, K] value: Value tensor [B, H, N, Ev] scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim)) return_lse: Bool to return log-sum-exp values @@ -1462,7 +1462,7 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ Attributes: score_mod_fn: Optional function symbol reference for score modification mask_mod_fn: Optional function symbol reference for mask modification - + # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.) Returns: From 0103163732893c2193d825b0c5656c863fe00033 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 6 Nov 2025 20:19:21 -0800 Subject: [PATCH 22/38] Added lit-test; Docs for FlexAttention --- .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 15 ++++++-- test/Dialect/Torch/ops.mlir | 34 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 224db711422a..58b3290a76c3 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1442,7 +1442,16 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [ let hasCustomAssemblyFormat = 1; } - +//===----------------------------------------------------------------------===// +// FlexAttention operation + +// NOTE: This op is manually defined because `aten::flex_attention` exists in +// PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet +// registered in PyTorch's JIT operator registry. The update_torch_ods.sh script +// validates against the JIT registry, so it cannot auto-generate this op. +// Once PyTorch adds flex_attention to the JIT registry, this can be moved to +// the auto-generated section. 
+//===----------------------------------------------------------------------===// def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ AllowsTypeRefinement, HasValueSemantics, @@ -1489,10 +1498,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + return parseDefaultTorchOp(parser, result, 6, 2); } void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + printDefaultTorchOp(printer, *this, 6, 2); } }]; } diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index a47cbf83a318..b02742574485 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -205,3 +205,37 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to %1 = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %arg0, %arg1, %arg2, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32> return %1 : !torch.vtensor<[3,3],f32> } + +// CHECK-LABEL: func.func @torch.aten.flex_attention +func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>) { + %float1.0 = torch.constant.float 1.000000e+00 + %false_0 = torch.constant.bool false + // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] + // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} + // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool + // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + return %output, %logsumexp : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> +} + +func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.sub.Tensor %arg3, %arg4, %int1 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.int -> !torch.vtensor<[],si32> + %float1.000000e-01 = torch.constant.float 1.000000e-01 + %1 = torch.aten.mul.Scalar %arg2, %float1.000000e-01 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32> + %float1.000000e-02 = torch.constant.float 1.000000e-02 + %2 = torch.aten.mul.Scalar %0, %float1.000000e-02 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32> + %int1_0 = torch.constant.int 1 + %3 = torch.aten.add.Tensor %arg0, %2, %int1_0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> + %int1_1 = torch.constant.int 1 + %4 = torch.aten.add.Tensor %3, 
%1, %int1_1 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> + %5 = torch.aten.tanh %4 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %5 : !torch.vtensor<[],f32> +} + +func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} From 48f12bc459f231a64c94890c736e737b66965918 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 6 Nov 2025 20:22:07 -0800 Subject: [PATCH 23/38] Formatting Signed-off-by: Keshav Vinayak Jha --- include/torch-mlir/Dialect/Torch/IR/TorchOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 58b3290a76c3..3c58bb55af75 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1443,7 +1443,7 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [ } //===----------------------------------------------------------------------===// -// FlexAttention operation +// FlexAttention operation // NOTE: This op is manually defined because `aten::flex_attention` exists in // PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet From ec3e5f87b5c5755abd485e39c0842d826817d680 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 10 Nov 2025 03:34:28 -0800 Subject: [PATCH 24/38] Modified arg extraction Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 49 ++++++++++++++++--------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 862839df3ff2..f23ab3c4644c 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1910,35 +1910,50 @@ def _import_hop_flex_attention( ): """Imports the torch._higher_order_ops.flex_attention HOP. - Args format: (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, ...) + Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...) - query, key, value: Attention input tensors - score_mod: Optional submodule/callable for score modification (imported as function) - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors - scale: Optional float for attention score scaling - - enable_gqa: Boolean for grouped query attention support - - kernel_options: Dict of performance tuning options: + - kernel_options: Optional Dict of performance tuning options: - return_lse: Boolean for whether to return the log-sum-exp tensor This creates a call to aten.flex_attention with function symbol references for - score_mod and mask_mod. The return_lse flag is extracted from kernel_options. + score_mod and mask_mod. """ # flex_attention HOP args from PyTorch: - # (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, return_lse_tuple, ...) - if len(node.args) < 6: + # (query, key, value, score_mod, block_mask, scale, kernel_options, ...) 
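+        # Illustrative example only (node names are hypothetical): for a module
+        # like the FlexAttention test added later in this series, node.args is
+        # roughly
+        #   (q, k, v, sdpa_score0, (kv_num_blocks, kv_indices, ..., sdpa_mask0),
+        #    scale, kernel_options)
+        # where sdpa_score0 / sdpa_mask0 are get_attr nodes referencing the
+        # captured score_mod and mask_mod submodules.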
+ if len(node.args) < 3: raise ValueError( - f"flex_attention expects at least 6 arguments, got {len(node.args)}" + f"flex_attention expects at least 3 arguments, got {len(node.args)}" ) - query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = ( - node.args[:6] - ) + # Required args + query_arg, key_arg, value_arg = node.args[:3] + + # Optional args (parse from remaining positionals; + score_mod_arg = None + block_mask_arg = None + scale_arg = None + kernel_options = {} + remaining = list(node.args[3:]) + + # score_mod (get_attr) if present + if remaining and isinstance(remaining[0], torch_fx.Node): + score_mod_arg = remaining.pop(0) + + # block_mask (tuple ending with mask_mod get_attr) if present + if remaining and isinstance(remaining[0], tuple): + block_mask_arg = remaining.pop(0) - # This is a boolean flag that enables GQA optimization - enable_gqa = node.args[6] if len(node.args) > 6 else False + if remaining and not isinstance(remaining[0], dict): + scale_arg = remaining.pop(0) - # TODO: Add support for kernel_options (performance tuning parameters) - # This is a dict containing options like block sizes, num_warps, etc. - kernel_options = node.args[7] if len(node.args) > 7 else {} + if remaining and isinstance(remaining[0], dict): + kernel_options = remaining.pop(0) + + # We don't support GQA yet. + enable_gqa = False # Import Q, K, V tensors query = self._import_argument(loc, query_arg, None) @@ -2015,7 +2030,6 @@ def _import_hop_flex_attention( return_lse_value = False if isinstance(kernel_options, dict): return_lse_value = kernel_options.get("return_lse", False) - with loc: return_lse = _make_constant_op( "torch.constant.bool", @@ -2052,7 +2066,6 @@ def _import_hop_flex_attention( attributes=attributes if attributes else None, loc=loc, ) - # Bind results if len(result_types) > 1: self._multi_result_nodes.add(node) @@ -2088,7 +2101,7 @@ def _import_torch_op_overload( # torch dynamo where it emits the Tensor variant of ops even when processing # scalar arguments, therefore we retrieve the schema as well so that we # consume the correct typing information when subsequently importing the - # function arguments and result types + # function arguments and result types. # i.e. 
the code below is basically doing `schema = torch.ops.aten.my_op.Scalar._schema` op_attrs = mlir_op_name.split(".") op_overload = getattr(torch, "ops") From fa5aba2ad09d331af15f29c8279d484ae47b7788 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 12 Nov 2025 09:22:14 -0800 Subject: [PATCH 25/38] Removed enable_gqa from flex_attention; HOP does not accept that argument Signed-off-by: Keshav Vinayak Jha --- include/torch-mlir/Dialect/Torch/IR/TorchOps.td | 5 ++--- python/torch_mlir/extras/fx_importer.py | 16 +--------------- test/Dialect/Torch/ops.mlir | 6 +++--- 3 files changed, 6 insertions(+), 21 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 3c58bb55af75..bada19be5707 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1484,7 +1484,6 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ AnyTorchTensorType:$key, AnyTorchTensorType:$value, AnyTorchOptionalFloatType:$scale, - Torch_BoolType:$enable_gqa, Torch_BoolType:$return_lse, OptionalAttr:$score_mod_fn, OptionalAttr:$mask_mod_fn @@ -1498,10 +1497,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 2); + return parseDefaultTorchOp(parser, result, 5, 2); } void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 2); + printDefaultTorchOp(printer, *this, 5, 2); } }]; } diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index f23ab3c4644c..f251cb8ba5ce 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1952,9 +1952,6 @@ def _import_hop_flex_attention( if remaining and isinstance(remaining[0], dict): kernel_options = remaining.pop(0) - # We don't support GQA yet. 
- enable_gqa = False - # Import Q, K, V tensors query = self._import_argument(loc, query_arg, None) key = self._import_argument(loc, key_arg, None) @@ -2019,21 +2016,11 @@ def _import_hop_flex_attention( # Single output result_types = [self._cc.node_val_to_type(node)] - with loc: - enable_gqa_value = _make_constant_op( - "torch.constant.bool", - self._cc.integer_attr(1 if enable_gqa else 0, 1), - self._cc.torch_bool_type, - ).result - # Extract return_lse from kernel_options - return_lse_value = False - if isinstance(kernel_options, dict): - return_lse_value = kernel_options.get("return_lse", False) with loc: return_lse = _make_constant_op( "torch.constant.bool", - self._cc.integer_attr(1 if return_lse_value else 0, 1), + self._cc.integer_attr(kernel_options.get("return_lse", False)), self._cc.torch_bool_type, ).result @@ -2047,7 +2034,6 @@ def _import_hop_flex_attention( key, value, scale, - enable_gqa_value, return_lse, ] diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index b02742574485..d409fcaba149 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -212,11 +212,11 @@ func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %ar %false_0 = torch.constant.bool false // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] + // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]] // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} - // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool + // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> return %output, %logsumexp : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> } From 2b0637cfa4a46753d57ec1c8b0f92e8dce4ae369 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 12 Nov 2025 10:16:34 -0800 Subject: [PATCH 26/38] Typo Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index f251cb8ba5ce..b4e5e97de201 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -2020,7 +2020,7 @@ def _import_hop_flex_attention( with loc: return_lse = _make_constant_op( "torch.constant.bool", - self._cc.integer_attr(kernel_options.get("return_lse", False)), + self._cc.integer_attr(kernel_options.get("return_lse", 0), 1), 
self._cc.torch_bool_type, ).result From e7da0a7feab7ecb69c05556dfd3dfd500f940dc2 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 12 Nov 2025 22:52:29 -0800 Subject: [PATCH 27/38] Simplified arg extract logic Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 37 ++++++------------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index b4e5e97de201..40124f081204 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1923,34 +1923,15 @@ def _import_hop_flex_attention( """ # flex_attention HOP args from PyTorch: # (query, key, value, score_mod, block_mask, scale, kernel_options, ...) - if len(node.args) < 3: - raise ValueError( - f"flex_attention expects at least 3 arguments, got {len(node.args)}" - ) - - # Required args - query_arg, key_arg, value_arg = node.args[:3] - - # Optional args (parse from remaining positionals; - score_mod_arg = None - block_mask_arg = None - scale_arg = None - kernel_options = {} - remaining = list(node.args[3:]) - - # score_mod (get_attr) if present - if remaining and isinstance(remaining[0], torch_fx.Node): - score_mod_arg = remaining.pop(0) - - # block_mask (tuple ending with mask_mod get_attr) if present - if remaining and isinstance(remaining[0], tuple): - block_mask_arg = remaining.pop(0) - - if remaining and not isinstance(remaining[0], dict): - scale_arg = remaining.pop(0) - - if remaining and isinstance(remaining[0], dict): - kernel_options = remaining.pop(0) + ( + query_arg, + key_arg, + value_arg, + score_mod_arg, + block_mask_arg, + scale_arg, + kernel_options, + ) = node.args[:7] # Import Q, K, V tensors query = self._import_argument(loc, query_arg, None) From 53dd19a0416c1dd660324af2b612b872ead41079 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 12 Nov 2025 22:59:47 -0800 Subject: [PATCH 28/38] return_lse should be booltype not i1 Signed-off-by: Keshav Vinayak Jha --- python/torch_mlir/extras/fx_importer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 40124f081204..4ad96d35deb3 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -2001,7 +2001,7 @@ def _import_hop_flex_attention( with loc: return_lse = _make_constant_op( "torch.constant.bool", - self._cc.integer_attr(kernel_options.get("return_lse", 0), 1), + self._cc.integer_attr(bool(kernel_options.get("return_lse", 0)), 1), self._cc.torch_bool_type, ).result From de91ca237cac350cad3eb22dbe09523391362861 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 13 Nov 2025 20:26:24 -0800 Subject: [PATCH 29/38] Added basic_test for flex_attention Signed-off-by: Keshav Vinayak Jha --- test/python/fx_importer/basic_test.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index e490a8d3636c..1b2a2987fccb 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -251,6 +251,59 @@ def body(i, x): print(m) +@run +# CHECK-LABEL: test_flex_attention +# CHECK: func.func @test_flex_attention +def test_flex_attention(): + from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop + from torch.nn.attention.flex_attention import BlockMask, _LARGE_SPARSE_BLOCK_SIZE, 
create_block_mask, flex_attention + from torch import Tensor + def _create_empty_block_mask(query: Tensor, key: Tensor): + # Default block mask for flex attention. + device = query.device + return BlockMask.from_kv_blocks( + kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device), + kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device), + BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE, + seq_lengths=(1, 1), + ).as_tuple() + + def relative_position_bias( + score: Tensor, + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, + ) -> Tensor: + # Simple score mod function. + return torch.tanh(score) + + class FlexAttention(torch.nn.Module): + def __init__(self, block_mask): + super().__init__() + self.block_mask=block_mask + + def forward(self, q, k, v): + output, logsumexp = flex_attention_hop( + q, k, v, + score_mod=relative_position_bias, + block_mask=self.block_mask, + scale=1.0, + kernel_options={"return_lse": 0}, + ) + return output, logsumexp + + # Export -> import to Torch-MLIR + B, Hq, Hkv, L, S, E, Ev = 4, 8, 8, 1024, 1024, 64, 64 + q = torch.ones(B, Hq, L, E) + k = torch.ones(B, Hkv, S, E) + v = torch.ones(B, Hkv, S, Ev) + m = fx.export_and_import( + FlexAttention(_create_empty_block_mask(q,k)), q,k,v, func_name="test_flex_attention" + ) + print(m) + + @run # CHECK-LABEL: test_stack_trace # CHECK: #loc[[LOC1:.+]] = loc( From 47803e304c87afcc764d96b3c4a53e385444285d Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Sun, 16 Nov 2025 08:25:26 -0800 Subject: [PATCH 30/38] Formatting and allowed unused unpacked vals Signed-off-by: Keshav Vinayak Jha --- test/python/fx_importer/basic_test.py | 28 ++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 1b2a2987fccb..2132f025820d 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -255,9 +255,17 @@ def body(i, x): # CHECK-LABEL: test_flex_attention # CHECK: func.func @test_flex_attention def test_flex_attention(): - from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop - from torch.nn.attention.flex_attention import BlockMask, _LARGE_SPARSE_BLOCK_SIZE, create_block_mask, flex_attention + from torch._higher_order_ops.flex_attention import ( + flex_attention as flex_attention_hop, + ) + from torch.nn.attention.flex_attention import ( + BlockMask, + _LARGE_SPARSE_BLOCK_SIZE, + create_block_mask, + flex_attention, + ) from torch import Tensor + def _create_empty_block_mask(query: Tensor, key: Tensor): # Default block mask for flex attention. 
device = query.device @@ -281,11 +289,13 @@ def relative_position_bias( class FlexAttention(torch.nn.Module): def __init__(self, block_mask): super().__init__() - self.block_mask=block_mask - + self.block_mask = block_mask + def forward(self, q, k, v): - output, logsumexp = flex_attention_hop( - q, k, v, + output, logsumexp, *_ = flex_attention_hop( + q, + k, + v, score_mod=relative_position_bias, block_mask=self.block_mask, scale=1.0, @@ -299,7 +309,11 @@ def forward(self, q, k, v): k = torch.ones(B, Hkv, S, E) v = torch.ones(B, Hkv, S, Ev) m = fx.export_and_import( - FlexAttention(_create_empty_block_mask(q,k)), q,k,v, func_name="test_flex_attention" + FlexAttention(_create_empty_block_mask(q, k)), + q, + k, + v, + func_name="test_flex_attention", ) print(m) From 207621c107d860dc8adc3a2c4aba36ff59e8c10b Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 19 Nov 2025 04:17:40 -0800 Subject: [PATCH 31/38] Added max_scores; changes to match pytorch naming conventions; Added working lit test Signed-off-by: Keshav Vinayak Jha --- .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 9 +++++--- python/torch_mlir/extras/fx_importer.py | 14 ++++++++--- test/python/fx_importer/basic_test.py | 23 +++++++++++++++---- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index bada19be5707..e6c84b286556 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1477,6 +1477,7 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ Returns: output: Result tensor [B, H, M, Ev] logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True) + max_scores: Optional max-scores tensor [B, H, M] (if return_max_scores=True) }]; let arguments = (ins @@ -1485,22 +1486,24 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ AnyTorchTensorType:$value, AnyTorchOptionalFloatType:$scale, Torch_BoolType:$return_lse, + Torch_BoolType:$return_max_scores, OptionalAttr:$score_mod_fn, OptionalAttr:$mask_mod_fn ); let results = (outs AnyTorchTensorType:$output, - AnyTorchOptionalTensorType:$logsumexp + AnyTorchOptionalTensorType:$logsumexp, + AnyTorchOptionalTensorType:$max_scores ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 2); + return parseDefaultTorchOp(parser, result, 6, 3); } void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 2); + printDefaultTorchOp(printer, *this, 6, 3); } }]; } diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 4ad96d35deb3..09768594579e 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1997,16 +1997,23 @@ def _import_hop_flex_attention( # Single output result_types = [self._cc.node_val_to_type(node)] - # Extract return_lse from kernel_options + # Extract OUTPUT_LOGSUMEXP and OUTPUT_MAX from kernel_options with loc: return_lse = _make_constant_op( "torch.constant.bool", - self._cc.integer_attr(bool(kernel_options.get("return_lse", 0)), 1), + self._cc.integer_attr( + bool(kernel_options.get("OUTPUT_LOGSUMEXP", 0)), 1 + ), + self._cc.torch_bool_type, + ).result + return_max_scores = _make_constant_op( + "torch.constant.bool", + self._cc.integer_attr(bool(kernel_options.get("OUTPUT_MAX", 0)), 1), 
self._cc.torch_bool_type, ).result # Build operands for aten.flex_attention. - # Op expects exactly 5 operands: query, key, value, scale, return_lse. + # Op expects exactly 6 operands: query, key, value, scale, return_lse, return_max_scores. # Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands. # Note: block_mask tensors are handled by mask_mod_fn, not passed as operands. @@ -2016,6 +2023,7 @@ def _import_hop_flex_attention( value, scale, return_lse, + return_max_scores, ] # Build attributes with function references diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 2132f025820d..9702d5c3b58e 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -253,7 +253,22 @@ def body(i, x): @run # CHECK-LABEL: test_flex_attention -# CHECK: func.func @test_flex_attention +# Check that helper functions are emitted first +# CHECK: func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> +# CHECK: torch.aten.tanh +# CHECK: func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> +# CHECK: torch.aten.new_ones +# Then check the main function +# CHECK: func.func @test_flex_attention(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) +# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32> +# Validate flex_attention op with 3 results and 6 operands: +# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 +# CHECK: %[[RETURN_LSE:.*]] = torch.constant.bool false +# CHECK: %[[RETURN_MAX:.*]] = torch.constant.bool false +# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.aten.flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} +# CHECK-SAME: : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool +# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> +# CHECK: return %[[OUTPUT]] def test_flex_attention(): from torch._higher_order_ops.flex_attention import ( flex_attention as flex_attention_hop, @@ -292,16 +307,16 @@ def __init__(self, block_mask): self.block_mask = block_mask def forward(self, q, k, v): - output, logsumexp, *_ = flex_attention_hop( + output, lse, max_scores = flex_attention_hop( q, k, v, score_mod=relative_position_bias, block_mask=self.block_mask, scale=1.0, - kernel_options={"return_lse": 0}, + kernel_options={}, ) - return output, logsumexp + return output # Export -> import to Torch-MLIR B, Hq, Hkv, L, S, E, Ev = 4, 8, 8, 1024, 1024, 64, 64 From acc3ade4be57a94972200816e0ddf1c025ccc3b6 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 19 Nov 2025 05:28:06 -0800 Subject: [PATCH 32/38] Corrected lit test Signed-off-by: Keshav Vinayak Jha --- test/Dialect/Torch/ops.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index d409fcaba149..58c810580cdf 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -207,17 +207,17 @@ func.func 
@torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to } // CHECK-LABEL: func.func @torch.aten.flex_attention -func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>) { +func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { %float1.0 = torch.constant.float 1.000000e+00 %false_0 = torch.constant.bool false // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]] + // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - %output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - return %output, %logsumexp : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp, %maxscore = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> + return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> } func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { From 16fc70c5801d9c60a51677d4382a29646d73b6ff Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 20 Nov 2025 07:09:38 -0800 Subject: [PATCH 33/38] Renamed aten.flex_attention -> hop_flex_attention; Added more lit tests Signed-off-by: Keshav Vinayak Jha --- .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 10 +-- python/torch_mlir/extras/fx_importer.py | 4 +- test/Dialect/Torch/ops.mlir | 76 ++++++++++++++----- test/python/fx_importer/basic_test.py | 2 +- 4 files changed, 64 insertions(+), 28 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index e6c84b286556..0841398e5dd4 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1445,19 +1445,19 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [ //===----------------------------------------------------------------------===// // FlexAttention operation -// NOTE: This op is manually defined because `aten::flex_attention` exists in +// NOTE: This op is manually defined because 
flex_attention exists in // PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet // registered in PyTorch's JIT operator registry. The update_torch_ods.sh script // validates against the JIT registry, so it cannot auto-generate this op. // Once PyTorch adds flex_attention to the JIT registry, this can be moved to // the auto-generated section. //===----------------------------------------------------------------------===// -def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ +def Torch_HigherOrderFlexAttentionOp : Torch_Op<"hop_flex_attention", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::flex_attention`"; + let summary = "Computes the flex_attention operation (1-1 with torch._higher_order_ops.flex_attention)"; let description = [{ FlexAttention operation with flexible block-sparse attention patterns. @@ -1499,10 +1499,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [ let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult HigherOrderFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 6, 3); } - void AtenFlexAttentionOp::print(OpAsmPrinter &printer) { + void HigherOrderFlexAttentionOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 6, 3); } }]; diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 09768594579e..d79ba099d8ec 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1918,7 +1918,7 @@ def _import_hop_flex_attention( - kernel_options: Optional Dict of performance tuning options: - return_lse: Boolean for whether to return the log-sum-exp tensor - This creates a call to aten.flex_attention with function symbol references for + This creates a call to hop_flex_attention with function symbol references for score_mod and mask_mod. 
""" # flex_attention HOP args from PyTorch: @@ -2035,7 +2035,7 @@ def _import_hop_flex_attention( attributes["mask_mod_fn"] = mask_mod_ref operation = Operation.create( - "torch.aten.flex_attention", + "torch.hop_flex_attention", results=result_types, operands=flat_operands, attributes=attributes if attributes else None, diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 58c810580cdf..169b1094c8b5 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -206,36 +206,72 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to return %1 : !torch.vtensor<[3,3],f32> } -// CHECK-LABEL: func.func @torch.aten.flex_attention -func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { + +//===----------------------------------------------------------------------===// +// FlexAttention variant tests +//===----------------------------------------------------------------------===// + +func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %5 : !torch.vtensor<[],f32> +} + +func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// CHECK-LABEL: func.func @torch.hop_flex_attention +func.func @torch.hop_flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { %float1.0 = torch.constant.float 1.000000e+00 %false_0 = torch.constant.bool false // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] + // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - %output, %logsumexp, %maxscore = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, 
!torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> } -func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { - %int1 = torch.constant.int 1 - %0 = torch.aten.sub.Tensor %arg3, %arg4, %int1 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.int -> !torch.vtensor<[],si32> - %float1.000000e-01 = torch.constant.float 1.000000e-01 - %1 = torch.aten.mul.Scalar %arg2, %float1.000000e-01 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32> - %float1.000000e-02 = torch.constant.float 1.000000e-02 - %2 = torch.aten.mul.Scalar %0, %float1.000000e-02 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32> - %int1_0 = torch.constant.int 1 - %3 = torch.aten.add.Tensor %arg0, %2, %int1_0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> - %int1_1 = torch.constant.int 1 - %4 = torch.aten.add.Tensor %3, %1, %int1_1 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> - %5 = torch.aten.tanh %4 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> - return %5 : !torch.vtensor<[],f32> +// CHECK-LABEL: func.func @torch.hop_flex_attention_nomask +func.func @torch.hop_flex_attention_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { + %float1.0 = torch.constant.float 1.000000e+00 + %false_0 = torch.constant.bool false + // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] + // CHECK-SAME: {score_mod_fn = @sdpa_score0} + // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool + // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> + return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> } -func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { - %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> - return %0 : !torch.vtensor<[],i1> +// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore +func.func @torch.hop_flex_attention_noscore (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { + %float1.0 = torch.constant.float 1.000000e+00 + %false_0 = torch.constant.bool false + // CHECK: %[[FLOAT:.*]] = 
torch.constant.float 1.000000e+00 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] + // CHECK-SAME: {mask_mod_fn = @sdpa_mask0} + // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool + // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> + return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> +} + +// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore_nomask +func.func @torch.hop_flex_attention_noscore_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { + %float1.0 = torch.constant.float 1.000000e+00 + %false_0 = torch.constant.bool false + // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] + // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool + // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> + %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> + return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> } diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 9702d5c3b58e..bb424f0489b0 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -265,7 +265,7 @@ def body(i, x): # CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 # CHECK: %[[RETURN_LSE:.*]] = torch.constant.bool false # CHECK: %[[RETURN_MAX:.*]] = torch.constant.bool false -# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.aten.flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} +# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.hop_flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} # CHECK-SAME: : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool # CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> # CHECK: return %[[OUTPUT]] From 9334c1a8618bd861b3c730785f8eb0e1c02d6a9b Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 24 Nov 2025 
02:49:25 -0800 Subject: [PATCH 34/38] Using direct calls to flex_attention for basic_tests; removed useless RT tests Signed-off-by: Keshav Vinayak Jha --- test/Dialect/Torch/ops.mlir | 47 +-------------- test/python/fx_importer/basic_test.py | 84 +++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 52 deletions(-) diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 169b1094c8b5..519b9992dacb 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -206,11 +206,7 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to return %1 : !torch.vtensor<[3,3],f32> } - -//===----------------------------------------------------------------------===// -// FlexAttention variant tests -//===----------------------------------------------------------------------===// - +// Round trip test for flex_attention. func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { %5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> return %5 : !torch.vtensor<[],f32> @@ -234,44 +230,3 @@ func.func @torch.hop_flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> } - -// CHECK-LABEL: func.func @torch.hop_flex_attention_nomask -func.func @torch.hop_flex_attention_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { - %float1.0 = torch.constant.float 1.000000e+00 - %false_0 = torch.constant.bool false - // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] - // CHECK-SAME: {score_mod_fn = @sdpa_score0} - // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool - // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> - return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> -} - -// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore -func.func @torch.hop_flex_attention_noscore (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) 
{ - %float1.0 = torch.constant.float 1.000000e+00 - %false_0 = torch.constant.bool false - // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] - // CHECK-SAME: {mask_mod_fn = @sdpa_mask0} - // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool - // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> - return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> -} - -// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore_nomask -func.func @torch.hop_flex_attention_noscore_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) { - %float1.0 = torch.constant.float 1.000000e+00 - %false_0 = torch.constant.bool false - // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00 - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]] - // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool - // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32> - %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> - return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32> -} diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index bb424f0489b0..45617a85dc45 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -270,13 +270,10 @@ def body(i, x): # CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> # CHECK: return %[[OUTPUT]] def test_flex_attention(): - from torch._higher_order_ops.flex_attention import ( - flex_attention as flex_attention_hop, - ) + from torch._subclasses.fake_tensor import FakeTensor from torch.nn.attention.flex_attention import ( BlockMask, _LARGE_SPARSE_BLOCK_SIZE, - create_block_mask, flex_attention, ) from torch import Tensor @@ -289,7 +286,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor): kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device), BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE, seq_lengths=(1, 1), - ).as_tuple() + ) def relative_position_bias( score: Tensor, @@ -307,7 +304,7 @@ def __init__(self, block_mask): self.block_mask = block_mask def forward(self, q, k, v): - output, lse, max_scores = flex_attention_hop( + 
output = flex_attention( q, k, v, @@ -316,6 +313,8 @@ def forward(self, q, k, v): scale=1.0, kernel_options={}, ) + # flex_attention returns a single output tensor. + assert isinstance(output, FakeTensor) return output # Export -> import to Torch-MLIR @@ -330,6 +329,79 @@ def forward(self, q, k, v): v, func_name="test_flex_attention", ) + m.operation.verify() + print(m) + + +@run +# CHECK-LABEL: test_flex_attention_noblock_return_lse +# Check that helper functions are emitted first +# CHECK: func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> +# CHECK: torch.aten.tanh +# Note how the mask function is automaticalluy generated and not provided. +# CHECK: func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> +# CHECK: torch.aten.new_ones +# Then check the main function +# CHECK: func.func @test_flex_attention(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) +# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32> +# Validate flex_attention op with 3 results and 6 operands: +# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 +# CHECK: %[[RETURN_LSE:.*]] = torch.constant.bool true +# CHECK: %[[RETURN_MAX:.*]] = torch.constant.bool false +# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.hop_flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} +# CHECK-SAME: : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool +# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> +# CHECK: return %[[OUTPUT]] +def test_flex_attention_noblock_return_lse(): + # from torch._higher_order_ops.flex_attention import ( + # flex_attention as flex_attention_hop, + # ) + from torch.nn.attention.flex_attention import flex_attention, AuxRequest, AuxOutput + from torch import Tensor + + def relative_position_bias( + score: Tensor, + batch: Tensor, + head: Tensor, + token_q: Tensor, + token_kv: Tensor, + ) -> Tensor: + # Simple score mod function. + return torch.tanh(score) + + class FlexAttention(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + outputs = flex_attention( + q, + k, + v, + score_mod=relative_position_bias, + block_mask=None, + scale=1.0, + return_aux=AuxRequest(lse=True), + kernel_options={}, + ) + # Note: Returning max scores is not supported on CPU, and will raise a + # NotImplementedError if max_scores is specified in the AuxRequest input. 
+ assert isinstance(outputs[1], AuxOutput) and outputs[1].max_scores == None + return outputs[0] + + # Export -> import to Torch-MLIR + B, Hq, Hkv, L, S, E, Ev = 4, 8, 8, 1024, 1024, 64, 64 + q = torch.ones(B, Hq, L, E) + k = torch.ones(B, Hkv, S, E) + v = torch.ones(B, Hkv, S, Ev) + m = fx.export_and_import( + FlexAttention(), + q, + k, + v, + func_name="test_flex_attention", + ) + print(m) From cb1fbcd4a00f1a0f65bbfff562709233ed99fc40 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Mon, 24 Nov 2025 02:52:18 -0800 Subject: [PATCH 35/38] Formatting Signed-off-by: Keshav Vinayak Jha --- test/python/fx_importer/basic_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 45617a85dc45..a3d5f30e961d 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -384,7 +384,7 @@ def forward(self, q, k, v): return_aux=AuxRequest(lse=True), kernel_options={}, ) - # Note: Returning max scores is not supported on CPU, and will raise a + # Note: Returning max scores is not supported on CPU, and will raise a # NotImplementedError if max_scores is specified in the AuxRequest input. assert isinstance(outputs[1], AuxOutput) and outputs[1].max_scores == None return outputs[0] From f145056dd38418145ce7bdd0d46efd95dd815ebd Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:25:00 +0530 Subject: [PATCH 36/38] Fix typos in comments for basic_test.py --- test/python/fx_importer/basic_test.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index a3d5f30e961d..5ee88fe55a26 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -253,12 +253,12 @@ def body(i, x): @run # CHECK-LABEL: test_flex_attention -# Check that helper functions are emitted first +# Check that helper functions are emitted first. # CHECK: func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> # CHECK: torch.aten.tanh # CHECK: func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> # CHECK: torch.aten.new_ones -# Then check the main function +# Then check the main function. # CHECK: func.func @test_flex_attention(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) # CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32> # Validate flex_attention op with 3 results and 6 operands: @@ -329,19 +329,18 @@ def forward(self, q, k, v): v, func_name="test_flex_attention", ) - m.operation.verify() print(m) @run # CHECK-LABEL: test_flex_attention_noblock_return_lse -# Check that helper functions are emitted first +# Check that helper functions are emitted first. # CHECK: func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> # CHECK: torch.aten.tanh -# Note how the mask function is automaticalluy generated and not provided. 
+# Note how the mask function is automatically generated and not provided.
 # CHECK: func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1>
 # CHECK: torch.aten.new_ones
-# Then check the main function
+# Then check the main function.
 # CHECK: func.func @test_flex_attention(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>)
 # CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>
 # Validate flex_attention op with 3 results and 6 operands:

From 4ba9d8d10861d9142dc81d145d7cc365f9b01354 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha
Date: Tue, 25 Nov 2025 01:03:12 -0800
Subject: [PATCH 37/38] Added Verifier to HigherOrderFlexAttention operation

Signed-off-by: Keshav Vinayak Jha
---
 .../torch-mlir/Dialect/Torch/IR/TorchOps.td |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp           | 45 +++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 0841398e5dd4..31bc717914fd 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -1506,6 +1506,7 @@ def Torch_HigherOrderFlexAttentionOp : Torch_Op<"hop_flex_attention", [
     printDefaultTorchOp(printer, *this, 6, 3);
   }
 }];
+  let hasVerifier = 1;
 }

 #endif // TORCH_OPS

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a4888a218fae..acbe75878637 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -234,6 +234,51 @@ static Value getScalarFloatValue(Value input, Location loc,
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// HigherOrderFlexAttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult HigherOrderFlexAttentionOp::verify() {
+  static constexpr int kAttentionRank = 4;
+  Value query = getQuery();
+  Value key = getKey();
+  Value value = getValue();
+
+  if (!isa<Torch::BoolType>(getReturnLse().getType())) {
+    return emitError() << "expected return_lse to be a bool type";
+  }
+  if (!isa<Torch::BoolType>(getReturnMaxScores().getType())) {
+    return emitError() << "expected return_max_scores to be a bool type";
+  }
+
+  auto queryType = dyn_cast<BaseTensorType>(query.getType());
+  auto keyType = dyn_cast<BaseTensorType>(key.getType());
+  auto valueType = dyn_cast<BaseTensorType>(value.getType());
+
+  if (!queryType || !keyType || !valueType || !queryType.hasSizes() ||
+      !keyType.hasSizes() || !valueType.hasSizes()) {
+    return emitError() << "expected input(s) types having sizes";
+  }
+
+  ArrayRef<int64_t> queryShape = queryType.getSizes();
+
+  // Query shape: [B, H, M, E].
+  if (queryShape.size() != kAttentionRank) {
+    return emitError() << "expected 4D query tensor";
+  }
+  // Dynamic head dim is not supported.
+  if (queryShape[3] == kUnknownSize) {
+    return emitError() << "NYI: dynamic head dimension";
+  }
+
+  // Check if the element type is a float.
+  if (!isa<mlir::FloatType>(queryType.getDtype())) {
+    return emitError() << "expected float element type";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MethodOp
 //===----------------------------------------------------------------------===//

From 93330286844ff78a631523a24339c78f7e643e0d Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha
Date: Wed, 26 Nov 2025 20:56:59 -0800
Subject: [PATCH 38/38] Removed Dynamic head check (backend specific)

Signed-off-by: Keshav Vinayak Jha
---
 lib/Dialect/Torch/IR/TorchOps.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index acbe75878637..51f3b79cc3a1 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -266,11 +266,6 @@ LogicalResult HigherOrderFlexAttentionOp::verify() {
   if (queryShape.size() != kAttentionRank) {
     return emitError() << "expected 4D query tensor";
   }
-  // Dynamic head dim is not supported.
-  if (queryShape[3] == kUnknownSize) {
-    return emitError() << "NYI: dynamic head dimension";
-  }
-
   // Check if the element type is a float.
   if (!isa<mlir::FloatType>(queryType.getDtype())) {
     return emitError() << "expected float element type";
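
For reference, the end-to-end flow this series enables can be summarized in a short sketch that is not part of any patch above: export a module calling flex_attention, import it through the FX importer, and run the MLIR verifier so the new HigherOrderFlexAttentionOp checks fire. It assumes a torch-mlir build with the Python bindings and the `fx.export_and_import` helper used in test/python/fx_importer/basic_test.py; the module and tensor shapes below are illustrative only.

```python
# Minimal sketch (assumption: torch-mlir built with the patched fx importer).
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch_mlir import fx  # same helper the lit tests use


def score_mod(score, batch, head, token_q, token_kv):
    # Same scalar score_mod contract as @sdpa_score0 in the lit tests.
    return torch.tanh(score)


class TinyFlexAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=score_mod, scale=1.0)


q = torch.ones(4, 8, 1024, 64)
k = torch.ones(4, 8, 1024, 64)
v = torch.ones(4, 8, 1024, 64)
m = fx.export_and_import(TinyFlexAttention(), q, k, v, func_name="tiny_flex_attention")
# With hasVerifier = 1, malformed hop_flex_attention ops (non-4D or non-float
# inputs) are rejected; a well-formed import verifies and carries the op.
assert m.operation.verify()
assert "torch.hop_flex_attention" in str(m)
```

The sketch mirrors the lit tests: score_mod becomes a private `@sdpa_score0`-style helper, and the main function carries a single `torch.hop_flex_attention` op with six operands and three results.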