Merged
72 changes: 35 additions & 37 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -2603,24 +2603,6 @@ def _ref_type_to_transforms(ref_type: RefType) -> ir.ArrayAttribute:
return ir.ArrayAttr.get(transform_attrs)


def _shape_dtype_struct_to_type_and_layout(
shape_dtype_struct: ShapeDtypeStruct,
) -> tuple[ir.Type, ir.Attribute | None]:
"""Returns the type and Mosaic GPU layout for the given ShapeDtypeStruct.

Unless the input indicates a scalar, the returned type will be a vector type
and the returned layout will not be None. If the input is a scalar, the
returned type will be the type of the scalar and the returned layout will be
None.
"""
el_type = mgpu_utils.dtype_to_ir_type(shape_dtype_struct.dtype)
if not shape_dtype_struct.shape:
return el_type, None
vector_type = ir.VectorType.get(shape_dtype_struct.shape, el_type)
layout = mgpu_layouts.to_layout_attr(shape_dtype_struct.layout.to_mgpu())
return vector_type, layout


def _replace_uses_in_block(old: ir.Value, new: ir.Value, block: ir.Block):
"""Replaces all uses of the `old` value with the `new` value in `block`."""

@@ -2723,18 +2705,22 @@ def _custom_primitive_in_specs(

def _custom_primitive_op_results(flat_ret_ty) -> tuple[
Sequence[ir.Type],
Sequence[ir.Attribute],
Sequence[ir.Attribute | None],
]:
"""Returns a tuple containing the list of output MLIR types, and layouts for
the given JAX return types."""
results_ty = []
out_layouts = []
results_ty: list[ir.Type] = []
out_layouts: list[ir.Attribute | None] = []
for r in flat_ret_ty:
if not isinstance(r, ShapeDtypeStruct):
raise NotImplementedError(f"Expected a ShapeDtypeStruct, but got: {r}")
ty, layout = _shape_dtype_struct_to_type_and_layout(r)
results_ty.append(ty)
if layout is not None:
el_type = mgpu_utils.dtype_to_ir_type(r.dtype)
if not r.shape: # scalar case.
results_ty.append(el_type)
out_layouts.append(None)
else:
results_ty.append(ir.VectorType.get(r.shape, el_type))
layout = mgpu_layouts.to_layout_attr(r.layout.to_mgpu())
out_layouts.append(layout)
return results_ty, out_layouts

@@ -2744,10 +2730,10 @@ def _populate_custom_primitive_op_block(
block: ir.Block,
mgpu_fn: Callable[..., Any],
pytree_args,
in_layouts : Sequence[ir.Attribute],
in_layouts: Sequence[ir.Attribute],
in_transforms: ir.ArrayAttr,
results_ty: Sequence[ir.Type],
out_layouts: Sequence[ir.Attribute],
out_layouts: Sequence[ir.Attribute | None],
):
"""Calls the given mgpu_fn to populate the block, handling inputs and outputs.

@@ -2826,17 +2812,29 @@ def _populate_custom_primitive_op_block(
for fa, result_ty, out_layout in zip(
inner_ret, results_ty, out_layouts, strict=True
):
if not ir.VectorType.isinstance(result_ty):
raise NotImplementedError(
"Only vector return types from the inline mgpu_fn are supported,"
f" but got: {result_ty}"
)
if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
raise ValueError(
f"Output layout {out_layout} does not match the layout of the"
f" returned fragmented array {fa.layout}."
if not isinstance(fa, mgpu.FragmentedArray):
raise ValueError(f"Expected a FragmentedArray, but got: {fa}")
if ir.VectorType.isinstance(result_ty):
result_shape = ir.VectorType(result_ty).shape
if fa.shape != tuple(result_shape):
raise ValueError(f"Expected {result_shape} but got {fa.shape}")
if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
raise ValueError(
f"Output layout {out_layout} does not match the layout of the"
f" returned fragmented array {fa.layout}."
)
ir_ret.append(
mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty)
)
ir_ret.append(mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty))
else: # scalar case.
assert out_layout is None
if fa.shape:
raise ValueError(f"Expected 0D shape, but got {fa.shape}")
if not isinstance(fa.layout, mgpu.WGSplatFragLayout):
raise ValueError(f"Expected WGSplatFragLayout, but got {fa.layout}")
value = fa.registers.item()
ir_ret.append(value)

mgpu.dialect.ReturnOp(operands_=ir_ret)


@@ -2893,7 +2891,7 @@ def _inline_mgpu_lowering_rule_wg_semantics(
operands_=flat_transformed_args,
in_layouts=in_layouts,
in_transforms=in_transforms,
out_layouts=out_layouts,
out_layouts=[l for l in out_layouts if l is not None],
)
block : ir.Block = custom_op.body.blocks.append(*in_types)
_populate_custom_primitive_op_block(
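The scalar path added in _populate_custom_primitive_op_block above relies on Mosaic GPU's splat layout: a 0-d FragmentedArray with WGSplatFragLayout carries a single value replicated across the warpgroup, so its one register can be emitted directly as a scalar IR value. A minimal standalone sketch of that extraction step (not part of the diff; it assumes FragmentedArray and WGSplatFragLayout are importable from jax.experimental.mosaic.gpu, matching how the diff uses them via mgpu):

from jax.experimental.mosaic import gpu as mgpu


def scalar_from_fragmented_array(fa: mgpu.FragmentedArray):
  """Returns the single IR value backing a 0-d, splat-layout fragmented array."""
  if fa.shape:
    raise ValueError(f"Expected 0D shape, but got {fa.shape}")
  if not isinstance(fa.layout, mgpu.WGSplatFragLayout):
    raise ValueError(f"Expected WGSplatFragLayout, but got {fa.layout}")
  # A splat layout stores one replicated value, so the registers array holds
  # exactly one element; .item() unwraps it into a plain scalar ir.Value.
  return fa.registers.item()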
15 changes: 13 additions & 2 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
@@ -524,8 +524,19 @@ llvm::LogicalResult CustomPrimitiveOp::verify() {
"smem.");
}

if (getResults().size() != getOutLayouts().size()) {
return emitOpError("Custom primitive must have a layout for each result.");
int num_vector_results = 0;
for (auto result : getResults()) {
if (mlir::isa<mlir::VectorType>(result.getType())) {
++num_vector_results;
} else if (mlir::isa<mlir::ShapedType>(result.getType())) {
return emitOpError(
"Custom primitive can only return scalars or vectors.");
}
}

if (num_vector_results != getOutLayouts().size()) {
return emitOpError(
"Custom primitive must have a layout for each vector result.");
}

return llvm::success();
4 changes: 2 additions & 2 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -618,7 +618,7 @@ def MosaicGPU_ReturnOp : Op<MosaicGPU_Dialect, "return",
}];

// The operand's type must match the parent CustomPrimitiveOp's result type.
let arguments = (ins Variadic<AnyVectorOfAnyRank>:$operands);
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? }];
let hasVerifier = 1;
}
@@ -645,7 +645,7 @@ def MosaicGPU_CustomPrimitiveOp : Op<MosaicGPU_Dialect, "custom_primitive",
ArrayAttr:$out_layouts
);

let results = (outs Variadic<AnyVectorOfAnyRank>);
let results = (outs Variadic<AnyType>);
let regions = (region SizedRegion<1>:$body);

let hasVerifier = 1;
23 changes: 22 additions & 1 deletion tests/mosaic/gpu_dialect_test.py
@@ -852,6 +852,27 @@ def test_custom_primitive_op_args_must_match_args_of_terminator(self):
):
self.module.operation.verify()

def test_custom_primitive_op_results_must_be_scalar_or_vector(self):
with ir.InsertionPoint(self.module.body):
ref_ty = ir.MemRefType.get((128, 128), ir.F32Type.get())
op = mgpu.dialect.CustomPrimitiveOp(
result=[ref_ty],
operands_=[],
in_layouts=[],
in_transforms=[],
out_layouts=[],
)
block = op.body.blocks.append()
with ir.InsertionPoint(block):
[ref] = undefs(ref_ty)
mgpu.dialect.ReturnOp(operands_=[ref])

with self.assertRaisesRegex(
ir.MLIRError,
r"Custom primitive can only return scalars or vectors.",
):
self.module.operation.verify()

def test_tmem_alloc_op_must_have_smem_ref_input(self):
with ir.InsertionPoint(self.module.body):
(smem_ptr,) = undefs(
@@ -1431,7 +1452,7 @@ def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_a
error = "transforms for each memref operand in smem"
else:
assert omit_out_layouts
error = "layout for each result"
error = "layout for each vector result"

with self.assertRaisesRegex(ir.MLIRError, error):
self.module.operation.verify()
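The new test_custom_primitive_op_results_must_be_scalar_or_vector above only exercises the rejected memref case. A complementary positive sketch (not part of this PR; it reuses the same harness helpers such as undefs and self.module) would build a custom primitive with a plain f32 result and an empty out_layouts, which the relaxed verifier now accepts:

  def test_custom_primitive_op_may_return_scalar_without_layout(self):
    with ir.InsertionPoint(self.module.body):
      f32 = ir.F32Type.get()
      op = mgpu.dialect.CustomPrimitiveOp(
          result=[f32],
          operands_=[],
          in_layouts=[],
          in_transforms=[],
          out_layouts=[],  # Scalar results need no layout entry.
      )
      block = op.body.blocks.append()
      with ir.InsertionPoint(block):
        [scalar] = undefs(f32)
        mgpu.dialect.ReturnOp(operands_=[scalar])

    self.assertTrue(self.module.operation.verify())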
3 changes: 0 additions & 3 deletions tests/pallas/mosaic_gpu_test.py
@@ -1442,9 +1442,6 @@ def kernel(o_ref):
((2, 3, 4, 5), ("a", "b", "c", "d"), (2,), ("x",)),
)
def test_axis_indices_in_grid(self, grid, grid_names, cluster, cluster_names):
# Skipping because `inline_mpgpu` isn't supported in WG semantics.
self.skip_if_wg_semantics()

@functools.partial(
self.kernel,
out_shape=[