Skip to content

Commit 7b052da

Browse files
allanrenucci authored and Google-ML-Automation committed
[Mosaic GPU] Support inference of additional layouts for BroadcastInDimOp.
PiperOrigin-RevId: 836237946
1 parent 4d72ab2 commit 7b052da

File tree

3 files changed

+41
-41
lines changed

3 files changed

+41
-41
lines changed

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,23 @@ def extract_constant_from_broadcast_in_dim_expression_for_hint(
209209
if not isinstance(e.expression, eqns.RegisterLayout):
210210
return None
211211

212-
reduced_layout = e.expression.value
213-
214-
wgmma_tm, wgmma_tn = fa.WGMMA_LAYOUT.base_tile_shape
215-
# TODO(bchetioui): enable generators to handle TCGEN05 layout from WGMMA_COL.
216-
if reduced_layout == fa.WGMMA_COL_LAYOUT and e.axes == (1,) and e.shape[0] % wgmma_tm == 0:
217-
return eqns.RegisterLayout(fa.WGMMA_LAYOUT)
218-
219-
if reduced_layout == fa.WGMMA_ROW_LAYOUT and e.axes == (0,) and e.shape[1] % wgmma_tn == 0:
220-
return eqns.RegisterLayout(fa.WGMMA_LAYOUT)
221-
222-
tcgen05_tm, _ = fa.TCGEN05_LAYOUT.base_tile_shape
223-
if reduced_layout == fa.TCGEN05_ROW_LAYOUT and e.axes == (0,) and e.shape[0] % tcgen05_tm == 0:
224-
return eqns.RegisterLayout(fa.TCGEN05_LAYOUT)
212+
candidates = [
213+
fa.WGMMA_LAYOUT,
214+
fa.WGMMA_TRANSPOSED_LAYOUT,
215+
fa.TCGEN05_LAYOUT,
216+
fa.TCGEN05_TRANSPOSED_LAYOUT,
217+
tcgen05.TMEM_NATIVE_LAYOUT,
218+
]
219+
if e.shape[-1] % 16 == 0:
220+
candidates.append(tcgen05.fa_m64_collective_layout(e.shape[-1]))
225221

222+
# TODO(allanrenucci): Allow returning multiple valid candidates.
223+
reduction_dims = tuple(d for d in range(len(e.shape)) if d not in e.axes)
224+
for candidate in candidates:
225+
if len(candidate.base_tile_shape) > len(e.shape):
226+
continue
227+
if candidate.reduce(reduction_dims) == e.expression.value:
228+
return eqns.RegisterLayout(candidate)
226229
return None
227230

228231

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -293,38 +293,39 @@ def test_infer_layout_cast_layout(self):
293293
self.checkInLayouts(cast, [wgmma_layout])
294294
self.checkOutLayouts(cast, [wgmma_layout])
295295

296-
@parameterized.parameters(
297-
(0, mgpu.WGMMA_ROW_LAYOUT, None),
298-
(1, mgpu.WGMMA_COL_LAYOUT, None),
299-
(0, None, mgpu.WGMMA_LAYOUT),
300-
(1, None, mgpu.WGMMA_LAYOUT),
301-
(0, mgpu.TCGEN05_ROW_LAYOUT, None),
302-
(0, None, mgpu.TCGEN05_LAYOUT),
303-
(1, None, mgpu.TCGEN05_LAYOUT),
296+
@parameterized.product(
297+
layout=(
298+
mtu.RegisterLayout.WGMMA,
299+
mtu.RegisterLayout.TCGEN05,
300+
mtu.RegisterLayout.TCGEN05_TMEM_NATIVE,
301+
mtu.RegisterLayout.TCGEN05_M64_COLLECTIVE,
302+
),
303+
axis=(0, 1),
304+
hint_on_input=(True, False),
304305
)
305-
def test_infer_broadcast_in_dim_layout(self, broadcast_dim, in_cast, out_cast):
306+
def test_infer_broadcast_in_dim_layout(self, layout, axis, hint_on_input):
306307
in_shape = (128,)
307308
out_shape = (128, 128)
309+
dtype = ir.F32Type.get()
310+
out_layout = layout.to_mgpu(out_shape, dtype)
311+
in_layout = out_layout.reduce((1 - axis,))
308312

309313
with ir.InsertionPoint(self.module.body):
310-
[x] = undefs(ir.VectorType.get(in_shape, ir.F32Type.get()))
311-
x = layout_cast(x, in_cast) if in_cast is not None else x
312-
out_type = ir.VectorType.get(out_shape, ir.F32Type.get())
313-
bcast = mgpu.dialect.BroadcastInDimOp(out_type, x, [broadcast_dim])
314-
if out_cast is not None:
315-
layout_cast(bcast.result, out_cast)
316-
317-
# The tests always expect WGMMA or TCGEN05 as the out layout.
318-
if out_cast == mgpu.TCGEN05_LAYOUT or in_cast == mgpu.TCGEN05_ROW_LAYOUT:
319-
out_layout = mgpu.TCGEN05_LAYOUT
320-
else:
321-
out_layout = mgpu.WGMMA_LAYOUT
314+
[x] = undefs(ir.VectorType.get(in_shape, dtype))
315+
if hint_on_input:
316+
x = layout_cast(x, in_layout)
317+
out_type = ir.VectorType.get(out_shape, dtype)
318+
bcast = mgpu.dialect.BroadcastInDimOp(out_type, x, [axis])
319+
if not hint_on_input:
320+
layout_cast(bcast.result, out_layout)
322321

323-
in_layout = out_layout.reduce((1 - broadcast_dim,))
322+
if hint_on_input and axis == 1 and layout == mtu.RegisterLayout.TCGEN05:
323+
# Both TCGEN05 and WGMMA are valid layout candidates. WGMMA is tried first.
324+
out_layout = fa.WGMMA_LAYOUT
324325

325326
mgpu.infer_layout(self.module)
326-
self.checkInLayouts(bcast, [layouts.to_layout_attr(in_layout)])
327-
self.checkOutLayouts(bcast, [layouts.to_layout_attr(out_layout)])
327+
self.checkInLayouts(bcast, [in_layout])
328+
self.checkOutLayouts(bcast, [out_layout])
328329

329330
@parameterized.parameters(
330331
(1, mgpu.WGMMA_LAYOUT, None, None),

tests/pallas/mosaic_gpu_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2345,10 +2345,6 @@ def kernel(x_ref, y_ref, smem_ref, smem_reduced_ref, barrier_ref):
23452345
np.testing.assert_allclose(x_result, op(x, axis=axis), atol=1e-5)
23462346

23472347
def _test_broadcast_in_dim_base(self, shape, layout, *, axis, hint):
2348-
if not hint:
2349-
# When the hint is not set, inference may choose incompatible layouts.
2350-
# TODO(bchetioui): investigate and fix.
2351-
self.skip_if_wg_semantics()
23522348
assert len(shape) == 2
23532349

23542350
@functools.partial(

0 commit comments

Comments
 (0)