@@ -2603,24 +2603,6 @@ def _ref_type_to_transforms(ref_type: RefType) -> ir.ArrayAttribute:
   return ir.ArrayAttr.get(transform_attrs)
 
 
-def _shape_dtype_struct_to_type_and_layout(
-    shape_dtype_struct: ShapeDtypeStruct,
-) -> tuple[ir.Type, ir.Attribute | None]:
-  """Returns the type and Mosaic GPU layout for the given ShapeDtypeStruct.
-
-  Unless the input indicates a scalar, the returned type will be a vector type
-  and the returned layout will not be None. If the input is a scalar, the
-  returned type will be the type of the scalar and the returned layout will be
-  None.
-  """
-  el_type = mgpu_utils.dtype_to_ir_type(shape_dtype_struct.dtype)
-  if not shape_dtype_struct.shape:
-    return el_type, None
-  vector_type = ir.VectorType.get(shape_dtype_struct.shape, el_type)
-  layout = mgpu_layouts.to_layout_attr(shape_dtype_struct.layout.to_mgpu())
-  return vector_type, layout
-
-
 def _replace_uses_in_block(old: ir.Value, new: ir.Value, block: ir.Block):
   """Replaces all uses of the `old` value with the `new` value in `block`."""
 
@@ -2723,18 +2705,22 @@ def _custom_primitive_in_specs(
 
 def _custom_primitive_op_results(flat_ret_ty) -> tuple[
     Sequence[ir.Type],
-    Sequence[ir.Attribute],
+    Sequence[ir.Attribute | None],
 ]:
   """Returns a tuple containing the list of output MLIR types, and layouts for
   the given JAX return types."""
-  results_ty = []
-  out_layouts = []
+  results_ty: list[ir.Type] = []
+  out_layouts: list[ir.Attribute | None] = []
   for r in flat_ret_ty:
     if not isinstance(r, ShapeDtypeStruct):
       raise NotImplementedError(f"Expected a ShapeDtypeStruct, but got: {r}")
-    ty, layout = _shape_dtype_struct_to_type_and_layout(r)
-    results_ty.append(ty)
-    if layout is not None:
+    el_type = mgpu_utils.dtype_to_ir_type(r.dtype)
+    if not r.shape:  # scalar case.
+      results_ty.append(el_type)
+      out_layouts.append(None)
+    else:
+      results_ty.append(ir.VectorType.get(r.shape, el_type))
+      layout = mgpu_layouts.to_layout_attr(r.layout.to_mgpu())
       out_layouts.append(layout)
   return results_ty, out_layouts
 
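To make the scalar/vector split above concrete, here is a minimal, self-contained sketch using plain-Python stand-ins for the MLIR types and layout attributes (the FakeResult class, the string type names, and the "layout_attr" placeholder are illustrative, not part of the patch):

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeResult:
  shape: tuple[int, ...]
  dtype: str

def split_results(flat_ret_ty):
  results_ty, out_layouts = [], []
  for r in flat_ret_ty:
    if not r.shape:  # scalar: plain element type, no layout.
      results_ty.append(r.dtype)
      out_layouts.append(None)
    else:  # vector: vector type plus a layout attribute.
      results_ty.append(f"vector<{'x'.join(map(str, r.shape))}x{r.dtype}>")
      out_layouts.append("layout_attr")  # stands in for the real ir.Attribute.
  return results_ty, out_layouts

assert split_results(
    [FakeResult((), "f32"), FakeResult((128, 64), "f32")]
) == (["f32", "vector<128x64xf32>"], [None, "layout_attr"])

Note that the two returned lists stay positionally aligned: every scalar result contributes a None layout rather than being dropped here.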
@@ -2744,10 +2730,10 @@ def _populate_custom_primitive_op_block(
     block: ir.Block,
     mgpu_fn: Callable[..., Any],
     pytree_args,
-    in_layouts: Sequence[ir.Attribute],
+    in_layouts: Sequence[ir.Attribute],
     in_transforms: ir.ArrayAttr,
     results_ty: Sequence[ir.Type],
-    out_layouts: Sequence[ir.Attribute],
+    out_layouts: Sequence[ir.Attribute | None],
 ):
   """Calls the given mgpu_fn to populate the block, handling inputs and outputs.
 
@@ -2826,17 +2812,29 @@ def _populate_custom_primitive_op_block(
   for fa, result_ty, out_layout in zip(
       inner_ret, results_ty, out_layouts, strict=True
   ):
-    if not ir.VectorType.isinstance(result_ty):
-      raise NotImplementedError(
-          "Only vector return types from the inline mgpu_fn are supported,"
-          f" but got: {result_ty}"
-      )
-    if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
-      raise ValueError(
-          f"Output layout {out_layout} does not match the layout of the"
-          f" returned fragmented array {fa.layout}."
+    if not isinstance(fa, mgpu.FragmentedArray):
+      raise ValueError(f"Expected a FragmentedArray, but got: {fa}")
+    if ir.VectorType.isinstance(result_ty):
+      result_shape = ir.VectorType(result_ty).shape
+      if fa.shape != tuple(result_shape):
+        raise ValueError(f"Expected {result_shape} but got {fa.shape}")
+      if out_layout != mgpu.layouts.to_layout_attr(fa.layout):
+        raise ValueError(
+            f"Output layout {out_layout} does not match the layout of the"
+            f" returned fragmented array {fa.layout}."
+        )
+      ir_ret.append(
+          mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty)
       )
-    ir_ret.append(mgpu.dialect_lowering.fragmented_array_to_ir(fa, result_ty))
+    else:  # scalar case.
+      assert out_layout is None
+      if fa.shape:
+        raise ValueError(f"Expected 0D shape, but got {fa.shape}")
+      if not isinstance(fa.layout, mgpu.WGSplatFragLayout):
+        raise ValueError(f"Expected WGSplatFragLayout, but got {fa.layout}")
+      value = fa.registers.item()
+      ir_ret.append(value)
+
   mgpu.dialect.ReturnOp(operands_=ir_ret)
 
 
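The scalar branch leans on how Mosaic GPU represents splats: if I read the internals right, a 0-D FragmentedArray with WGSplatFragLayout keeps its single replicated ir.Value in a 0-d numpy object array, which is why fa.registers.item() unwraps it to a plain value. A numpy analogue of that assumption:

import numpy as np

registers = np.empty((), dtype=object)  # 0-d carrier, as in the splat case.
registers[()] = "the-single-ir-value"   # stand-in for an ir.Value.
assert registers.item() == "the-single-ir-value"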
@@ -2893,7 +2891,7 @@ def _inline_mgpu_lowering_rule_wg_semantics(
       operands_=flat_transformed_args,
       in_layouts=in_layouts,
       in_transforms=in_transforms,
-      out_layouts=out_layouts,
+      out_layouts=[l for l in out_layouts if l is not None],
   )
   block: ir.Block = custom_op.body.blocks.append(*in_types)
   _populate_custom_primitive_op_block(
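One subtlety in this last hunk: the op construction strips the None placeholders, presumably because the op's out_layouts attribute only describes vector results, while _populate_custom_primitive_op_block still receives the unfiltered list and so stays positionally aligned with results_ty. A tiny illustration with made-up placeholder values:

out_layouts = [None, "layout_a", None, "layout_b"]  # made-up placeholders.
assert [l for l in out_layouts if l is not None] == ["layout_a", "layout_b"]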