From b342917331ddb04662b8647ebcf2c6191062ae67 Mon Sep 17 00:00:00 2001
From: Allan Renucci
Date: Mon, 24 Nov 2025 09:10:38 -0800
Subject: [PATCH] [Pallas:MGPU] Fix `plgpu.inline_mgpu` support for scalar
 results.

We extend MGPU `CustomPrimitiveOp` to support scalar results accordingly.

PiperOrigin-RevId: 836249550
---
 jax/_src/pallas/mosaic_gpu/primitives.py | 72 ++++++++++++------------
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc  | 15 ++++-
 jaxlib/mosaic/dialect/gpu/mosaic_gpu.td  |  4 +-
 tests/mosaic/gpu_dialect_test.py         | 23 +++++++-
 tests/pallas/mosaic_gpu_test.py          |  3 -
 5 files changed, 72 insertions(+), 45 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 3342114b5fea..e072c92c2646 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -2603,24 +2603,6 @@ def _ref_type_to_transforms(ref_type: RefType) -> ir.ArrayAttribute:
   return ir.ArrayAttr.get(transform_attrs)
 
 
-def _shape_dtype_struct_to_type_and_layout(
-    shape_dtype_struct: ShapeDtypeStruct,
-) -> tuple[ir.Type, ir.Attribute | None]:
-  """Returns the type and Mosaic GPU layout for the given ShapeDtypeStruct.
-
-  Unless the input indicates a scalar, the returned type will be a vector type
-  and the returned layout will not be None. If the input is a scalar, the
-  returned type will be the type of the scalar and the returned layout will be
-  None.
-  """
-  el_type = mgpu_utils.dtype_to_ir_type(shape_dtype_struct.dtype)
-  if not shape_dtype_struct.shape:
-    return el_type, None
-  vector_type = ir.VectorType.get(shape_dtype_struct.shape, el_type)
-  layout = mgpu_layouts.to_layout_attr(shape_dtype_struct.layout.to_mgpu())
-  return vector_type, layout
-
-
 def _replace_uses_in_block(old: ir.Value, new: ir.Value, block: ir.Block):
   """Replaces all uses of the `old` value with the `new` value in `block`."""
 
@@ -2723,18 +2705,22 @@ def _custom_primitive_in_specs(
 
 def _custom_primitive_op_results(flat_ret_ty) -> tuple[
     Sequence[ir.Type],
-    Sequence[ir.Attribute],
+    Sequence[ir.Attribute | None],
 ]:
   """Returns a tuple containing the list of output MLIR types, and layouts for
   the given JAX return types."""
-  results_ty = []
-  out_layouts = []
+  results_ty: list[ir.Type] = []
+  out_layouts: list[ir.Attribute | None] = []
   for r in flat_ret_ty:
     if not isinstance(r, ShapeDtypeStruct):
       raise NotImplementedError(f"Expected a ShapeDtypeStruct, but got: {r}")
-    ty, layout = _shape_dtype_struct_to_type_and_layout(r)
-    results_ty.append(ty)
-    if layout is not None:
+    el_type = mgpu_utils.dtype_to_ir_type(r.dtype)
+    if not r.shape:  # scalar case.
+      results_ty.append(el_type)
+      out_layouts.append(None)
+    else:
+      results_ty.append(ir.VectorType.get(r.shape, el_type))
+      layout = mgpu_layouts.to_layout_attr(r.layout.to_mgpu())
       out_layouts.append(layout)
   return results_ty, out_layouts
 
@@ -2744,10 +2730,10 @@ def _populate_custom_primitive_op_block(
     block: ir.Block,
     mgpu_fn: Callable[..., Any],
     pytree_args,
-    in_layouts : Sequence[ir.Attribute],
+    in_layouts: Sequence[ir.Attribute],
     in_transforms: ir.ArrayAttr,
     results_ty: Sequence[ir.Type],
-    out_layouts: Sequence[ir.Attribute],
+    out_layouts: Sequence[ir.Attribute | None],
 ):
   """Calls the given mgpu_fn to populate the block, handling inputs and outputs.
 
@@ -2826,17 +2812,29 @@ def _populate_custom_primitive_op_block(
   for fa, result_ty, out_layout in zip(
       inner_ret, results_ty, out_layouts, strict=True
   ):
-    if not ir.VectorType.isinstance(result_ty):
-      raise NotImplementedError(
-          "Only vector return types from the inline mgpu_fn are supported,"
-          f" but got: {result_ty}"
-      )
-    if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
-      raise ValueError(
-          f"Output layout {out_layout} does not match the layout of the"
-          f" returned fragmented array {fa.layout}."
+    if not isinstance(fa, mgpu.FragmentedArray):
+      raise ValueError(f"Expected a FragmentedArray, but got: {fa}")
+    if ir.VectorType.isinstance(result_ty):
+      result_shape = ir.VectorType(result_ty).shape
+      if fa.shape != tuple(result_shape):
+        raise ValueError(f"Expected {result_shape} but got {fa.shape}")
+      if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
+        raise ValueError(
+            f"Output layout {out_layout} does not match the layout of the"
+            f" returned fragmented array {fa.layout}."
+        )
+      ir_ret.append(
+          mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty)
       )
-    ir_ret.append(mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty))
+    else:  # scalar case.
+      assert out_layout is None
+      if fa.shape:
+        raise ValueError(f"Expected 0D shape, but got {fa.shape}")
+      if not isinstance(fa.layout, mgpu.WGSplatFragLayout):
+        raise ValueError(f"Expected WGSplatFragLayout, but got {fa.layout}")
+      value = fa.registers.item()
+      ir_ret.append(value)
+
   mgpu.dialect.ReturnOp(operands_=ir_ret)
 
 
@@ -2893,7 +2891,7 @@ def _inline_mgpu_lowering_rule_wg_semantics(
       operands_=flat_transformed_args,
       in_layouts=in_layouts,
       in_transforms=in_transforms,
-      out_layouts=out_layouts,
+      out_layouts=[l for l in out_layouts if l is not None],
   )
   block : ir.Block = custom_op.body.blocks.append(*in_types)
   _populate_custom_primitive_op_block(
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
index dca6ecda54d0..c7625e3b4336 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -524,8 +524,19 @@ llvm::LogicalResult CustomPrimitiveOp::verify() {
         "smem.");
   }
 
-  if (getResults().size() != getOutLayouts().size()) {
-    return emitOpError("Custom primitive must have a layout for each result.");
+  int num_vector_results = 0;
+  for (auto result : getResults()) {
+    if (mlir::isa<mlir::VectorType>(result.getType())) {
+      ++num_vector_results;
+    } else if (mlir::isa<mlir::ShapedType>(result.getType())) {
+      return emitOpError(
+          "Custom primitive can only return scalars or vectors.");
+    }
+  }
+
+  if (num_vector_results != getOutLayouts().size()) {
+    return emitOpError(
+        "Custom primitive must have a layout for each vector result.");
   }
 
   return llvm::success();
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 9309738d79df..b4de05896c81 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -618,7 +618,7 @@ def MosaicGPU_ReturnOp : Op<
-  let arguments = (ins Variadic:$operands);
+  let arguments = (ins Variadic:$operands);
   let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? }];
   let hasVerifier = 1;
 }
@@ -645,7 +645,7 @@ def MosaicGPU_CustomPrimitiveOp : Op<
-  let results = (outs Variadic);
+  let results = (outs Variadic);
   let regions = (region SizedRegion<1>:$body);
 
   let hasVerifier = 1;
diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py
index ea105c29e3f4..188231b795d8 100644
--- a/tests/mosaic/gpu_dialect_test.py
+++ b/tests/mosaic/gpu_dialect_test.py
@@ -852,6 +852,27 @@ def test_custom_primitive_op_args_must_match_args_of_terminator(self):
     ):
       self.module.operation.verify()
 
+  def test_custom_primitive_op_results_must_be_scalar_or_vector(self):
+    with ir.InsertionPoint(self.module.body):
+      ref_ty = ir.MemRefType.get((128, 128), ir.F32Type.get())
+      op = mgpu.dialect.CustomPrimitiveOp(
+          result=[ref_ty],
+          operands_=[],
+          in_layouts=[],
+          in_transforms=[],
+          out_layouts=[],
+      )
+      block = op.body.blocks.append()
+      with ir.InsertionPoint(block):
+        [ref] = undefs(ref_ty)
+        mgpu.dialect.ReturnOp(operands_=[ref])
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"Custom primitive can only return scalars or vectors.",
+    ):
+      self.module.operation.verify()
+
   def test_tmem_alloc_op_must_have_smem_ref_input(self):
     with ir.InsertionPoint(self.module.body):
       (smem_ptr,) = undefs(
@@ -1431,7 +1452,7 @@ def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_a
       error = "transforms for each memref operand in smem"
     else:
       assert omit_out_layouts
-      error = "layout for each result"
+      error = "layout for each vector result"
 
     with self.assertRaisesRegex(ir.MLIRError, error):
       self.module.operation.verify()
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index d60da4584b35..e235f56a9f98 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -1442,9 +1442,6 @@ def kernel(o_ref):
       ((2, 3, 4, 5), ("a", "b", "c", "d"), (2,), ("x",)),
   )
   def test_axis_indices_in_grid(self, grid, grid_names, cluster, cluster_names):
-    # Skipping because `inline_mpgpu` isn't supported in WG semantics.
-    self.skip_if_wg_semantics()
-
     @functools.partial(
         self.kernel,
         out_shape=[
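
For completeness, a sketch of the accepted counterpart to the new negative test above: a custom primitive with a single scalar (f32) result, an empty `out_layouts`, and a scalar `return`. This is not part of the patch; it assumes the same fixtures the existing dialect tests use (`self.module`, `undefs`) and that `verify()` is truthy on success.

  def test_custom_primitive_op_scalar_result_verifies(self):
    with ir.InsertionPoint(self.module.body):
      f32 = ir.F32Type.get()
      op = mgpu.dialect.CustomPrimitiveOp(
          result=[f32],  # A scalar result contributes no entry to out_layouts.
          operands_=[],
          in_layouts=[],
          in_transforms=[],
          out_layouts=[],  # Layouts are only required for vector results.
      )
      block = op.body.blocks.append()
      with ir.InsertionPoint(block):
        [scalar] = undefs(f32)
        mgpu.dialect.ReturnOp(operands_=[scalar])

    self.assertTrue(self.module.operation.verify())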