Commit b342917

allanrenucci authored and Google-ML-Automation committed
[Pallas:MGPU] Fix plgpu.inline_mgpu support for scalar results.
We extend MGPU `CustomPrimitiveOp` to support scalar results accordingly.

PiperOrigin-RevId: 836249550
1 parent 0adb584 commit b342917
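
For context, a rough sketch of what this change enables: declaring a rank-0 result for an inline Mosaic GPU function. The decorator and parameter names below are assumptions inferred from this diff rather than verified API usage, and the body is schematic.

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu

# Hypothetical usage sketch: an empty shape declares a scalar result, so no
# layout is attached. The body must return a 0-D splat FragmentedArray, which
# the lowering change below unwraps into a plain scalar value.
@plgpu.inline_mgpu(
    arg_types=(plgpu.RefType(),),
    return_type=jax.ShapeDtypeStruct((), jnp.float32),
)
def scalar_fn(ctx, smem_ref):
  ...  # return a 0-D mgpu.FragmentedArray with WGSplatFragLayout
```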

File tree

5 files changed (+72, -45 lines)


jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 35 additions & 37 deletions
```diff
@@ -2603,24 +2603,6 @@ def _ref_type_to_transforms(ref_type: RefType) -> ir.ArrayAttribute:
   return ir.ArrayAttr.get(transform_attrs)
 
 
-def _shape_dtype_struct_to_type_and_layout(
-    shape_dtype_struct: ShapeDtypeStruct,
-) -> tuple[ir.Type, ir.Attribute | None]:
-  """Returns the type and Mosaic GPU layout for the given ShapeDtypeStruct.
-
-  Unless the input indicates a scalar, the returned type will be a vector type
-  and the returned layout will not be None. If the input is a scalar, the
-  returned type will be the type of the scalar and the returned layout will be
-  None.
-  """
-  el_type = mgpu_utils.dtype_to_ir_type(shape_dtype_struct.dtype)
-  if not shape_dtype_struct.shape:
-    return el_type, None
-  vector_type = ir.VectorType.get(shape_dtype_struct.shape, el_type)
-  layout = mgpu_layouts.to_layout_attr(shape_dtype_struct.layout.to_mgpu())
-  return vector_type, layout
-
-
 def _replace_uses_in_block(old: ir.Value, new: ir.Value, block: ir.Block):
   """Replaces all uses of the `old` value with the `new` value in `block`."""
 
```

```diff
@@ -2723,18 +2705,22 @@ def _custom_primitive_in_specs(
 
 def _custom_primitive_op_results(flat_ret_ty) -> tuple[
     Sequence[ir.Type],
-    Sequence[ir.Attribute],
+    Sequence[ir.Attribute | None],
 ]:
   """Returns a tuple containing the list of output MLIR types, and layouts for
   the given JAX return types."""
-  results_ty = []
-  out_layouts = []
+  results_ty: list[ir.Type] = []
+  out_layouts: list[ir.Attribute | None] = []
   for r in flat_ret_ty:
     if not isinstance(r, ShapeDtypeStruct):
       raise NotImplementedError(f"Expected a ShapeDtypeStruct, but got: {r}")
-    ty, layout = _shape_dtype_struct_to_type_and_layout(r)
-    results_ty.append(ty)
-    if layout is not None:
+    el_type = mgpu_utils.dtype_to_ir_type(r.dtype)
+    if not r.shape:  # scalar case.
+      results_ty.append(el_type)
+      out_layouts.append(None)
+    else:
+      results_ty.append(ir.VectorType.get(r.shape, el_type))
+      layout = mgpu_layouts.to_layout_attr(r.layout.to_mgpu())
       out_layouts.append(layout)
   return results_ty, out_layouts
 
```
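
The new branch maps an empty shape to a bare element type with a `None` layout entry, while non-empty shapes keep the vector-type-plus-layout pairing. A self-contained sketch of that rule, with strings standing in for the MLIR types and layout attributes (illustrative stand-ins, not the real `ir` API):

```python
# Mirrors the scalar/vector split in _custom_primitive_op_results.
def classify_results(shaped):
  results_ty, out_layouts = [], []
  for shape, dtype in shaped:
    if not shape:  # scalar case: bare element type, no layout entry.
      results_ty.append(dtype)
      out_layouts.append(None)
    else:  # vector case: vector type plus a layout attribute.
      results_ty.append(f"vector<{'x'.join(map(str, shape))}x{dtype}>")
      out_layouts.append("<layout attr>")
  return results_ty, out_layouts

print(classify_results([((), "f32"), ((128,), "f32")]))
# (['f32', 'vector<128xf32>'], [None, '<layout attr>'])
```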

```diff
@@ -2744,10 +2730,10 @@ def _populate_custom_primitive_op_block(
     block: ir.Block,
     mgpu_fn: Callable[..., Any],
     pytree_args,
-    in_layouts : Sequence[ir.Attribute],
+    in_layouts: Sequence[ir.Attribute],
     in_transforms: ir.ArrayAttr,
     results_ty: Sequence[ir.Type],
-    out_layouts: Sequence[ir.Attribute],
+    out_layouts: Sequence[ir.Attribute | None],
 ):
   """Calls the given mgpu_fn to populate the block, handling inputs and outputs.
```
27532739
```diff
@@ -2826,17 +2812,29 @@ def _populate_custom_primitive_op_block(
   for fa, result_ty, out_layout in zip(
       inner_ret, results_ty, out_layouts, strict=True
   ):
-    if not ir.VectorType.isinstance(result_ty):
-      raise NotImplementedError(
-          "Only vector return types from the inline mgpu_fn are supported,"
-          f" but got: {result_ty}"
-      )
-    if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
-      raise ValueError(
-          f"Output layout {out_layout} does not match the layout of the"
-          f" returned fragmented array {fa.layout}."
+    if not isinstance(fa, mgpu.FragmentedArray):
+      raise ValueError(f"Expected a FragmentedArray, but got: {fa}")
+    if ir.VectorType.isinstance(result_ty):
+      result_shape = ir.VectorType(result_ty).shape
+      if fa.shape != tuple(result_shape):
+        raise ValueError(f"Expected {result_shape} but got {fa.shape}")
+      if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
+        raise ValueError(
+            f"Output layout {out_layout} does not match the layout of the"
+            f" returned fragmented array {fa.layout}."
+        )
+      ir_ret.append(
+          mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty)
       )
-    ir_ret.append(mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty))
+    else:  # scalar case.
+      assert out_layout is None
+      if fa.shape:
+        raise ValueError(f"Expected 0D shape, but got {fa.shape}")
+      if not isinstance(fa.layout, mgpu.WGSplatFragLayout):
+        raise ValueError(f"Expected WGSplatFragLayout, but got {fa.layout}")
+      value = fa.registers.item()
+      ir_ret.append(value)
+
   mgpu.dialect.ReturnOp(operands_=ir_ret)
 
 
```
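
In the scalar branch, the returned fragmented array must be 0-D with a splat layout, so its `registers` array holds exactly one entry and `.item()` unwraps it into the single value handed to `ReturnOp`. A numpy stand-in for that unwrapping (illustrative only; the real `FragmentedArray.registers` holds IR values, not floats):

```python
import numpy as np

# A 0-D splat layout keeps one register, replicated across the warpgroup.
registers = np.full((), 3.14)
assert registers.shape == ()  # 0-D: exactly one element
value = registers.item()      # unwrap the lone scalar
print(value)                  # 3.14
```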

```diff
@@ -2893,7 +2891,7 @@ def _inline_mgpu_lowering_rule_wg_semantics(
       operands_=flat_transformed_args,
       in_layouts=in_layouts,
       in_transforms=in_transforms,
-      out_layouts=out_layouts,
+      out_layouts=[l for l in out_layouts if l is not None],
   )
   block : ir.Block = custom_op.body.blocks.append(*in_types)
   _populate_custom_primitive_op_block(
```

jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc

Lines changed: 13 additions & 2 deletions
```diff
@@ -524,8 +524,19 @@ llvm::LogicalResult CustomPrimitiveOp::verify() {
         "smem.");
   }
 
-  if (getResults().size() != getOutLayouts().size()) {
-    return emitOpError("Custom primitive must have a layout for each result.");
+  int num_vector_results = 0;
+  for (auto result : getResults()) {
+    if (mlir::isa<mlir::VectorType>(result.getType())) {
+      ++num_vector_results;
+    } else if (mlir::isa<mlir::ShapedType>(result.getType())) {
+      return emitOpError(
+          "Custom primitive can only return scalars or vectors.");
+    }
+  }
+
+  if (num_vector_results != getOutLayouts().size()) {
+    return emitOpError(
+        "Custom primitive must have a layout for each vector result.");
   }
 
   return llvm::success();
```

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

Lines changed: 2 additions & 2 deletions
```diff
@@ -618,7 +618,7 @@ def MosaicGPU_ReturnOp : Op<MosaicGPU_Dialect, "return",
   }];
 
   // The operand's type must match the parent CustomPrimitiveOp's result type.
-  let arguments = (ins Variadic<AnyVectorOfAnyRank>:$operands);
+  let arguments = (ins Variadic<AnyType>:$operands);
   let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? }];
   let hasVerifier = 1;
 }
@@ -645,7 +645,7 @@ def MosaicGPU_CustomPrimitiveOp : Op<MosaicGPU_Dialect, "custom_primitive",
     ArrayAttr:$out_layouts
   );
 
-  let results = (outs Variadic<AnyVectorOfAnyRank>);
+  let results = (outs Variadic<AnyType>);
   let regions = (region SizedRegion<1>:$body);
 
   let hasVerifier = 1;
```

tests/mosaic/gpu_dialect_test.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -852,6 +852,27 @@ def test_custom_primitive_op_args_must_match_args_of_terminator(self):
     ):
       self.module.operation.verify()
 
+  def test_custom_primitive_op_results_must_be_scalar_or_vector(self):
+    with ir.InsertionPoint(self.module.body):
+      ref_ty = ir.MemRefType.get((128, 128), ir.F32Type.get())
+      op = mgpu.dialect.CustomPrimitiveOp(
+          result=[ref_ty],
+          operands_=[],
+          in_layouts=[],
+          in_transforms=[],
+          out_layouts=[],
+      )
+      block = op.body.blocks.append()
+      with ir.InsertionPoint(block):
+        [ref] = undefs(ref_ty)
+        mgpu.dialect.ReturnOp(operands_=[ref])
+
+    with self.assertRaisesRegex(
+        ir.MLIRError,
+        r"Custom primitive can only return scalars or vectors.",
+    ):
+      self.module.operation.verify()
+
   def test_tmem_alloc_op_must_have_smem_ref_input(self):
     with ir.InsertionPoint(self.module.body):
       (smem_ptr,) = undefs(
@@ -1431,7 +1452,7 @@ def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_a
       error = "transforms for each memref operand in smem"
     else:
       assert omit_out_layouts
-      error = "layout for each result"
+      error = "layout for each vector result"
 
     with self.assertRaisesRegex(ir.MLIRError, error):
       self.module.operation.verify()
```

tests/pallas/mosaic_gpu_test.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -1442,9 +1442,6 @@ def kernel(o_ref):
       ((2, 3, 4, 5), ("a", "b", "c", "d"), (2,), ("x",)),
   )
   def test_axis_indices_in_grid(self, grid, grid_names, cluster, cluster_names):
-    # Skipping because `inline_mpgpu` isn't supported in WG semantics.
-    self.skip_if_wg_semantics()
-
     @functools.partial(
         self.kernel,
         out_shape=[
```
