Skip to content

Commit 4d72ab2

Browse files
[Mosaic GPU] Add more info to exception messages in core.py
PiperOrigin-RevId: 836233068
1 parent 3d79bc8 commit 4d72ab2

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,13 @@ def _mosaic_gpu_lowering_rule(
181181
not np.array_equal(mesh.device_ids.ravel(), np.arange(mesh.size))):
182182
raise NotImplementedError(
183183
"Mosaic GPU only supports meshes with device ordering that follows"
184-
" row-major device ids."
184+
f" row-major device ids. Got: {mesh.device_ids.ravel()} device ids."
185185
)
186186
elif isinstance(axis_context, sharding_impls.ShardingContext):
187187
if axis_context.num_devices != 1:
188188
raise NotImplementedError(
189189
"Mosaic GPU only supports single-device meshes in ShardingContext."
190+
f" Got: {axis_context.num_devices} devices."
190191
)
191192
else:
192193
raise NotImplementedError(f"Unsupported sharding context: {axis_context}")
@@ -218,7 +219,7 @@ def _mosaic_gpu_lowering_rule(
218219
# SHA256 that it shouldn't be a problem.
219220
if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None:
220221
if kernel_text != module_asm:
221-
raise RuntimeError("Hash collision!")
222+
raise RuntimeError("Kernel hash collision!")
222223
else:
223224
KNOWN_KERNELS[kernel_id] = module_asm
224225

@@ -513,7 +514,11 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
513514
| Barrier(_, num_barriers=num_barriers)
514515
):
515516
if size % utils.MBARRIER_BYTES:
516-
raise NotImplementedError("Misaligned barrier allocation")
517+
raise NotImplementedError(
518+
"Misaligned barrier allocation. Expected smem size"
519+
f" ({size} bytes) to be divisible by the size of the barrier:"
520+
f" {utils.MBARRIER_BYTES} bytes."
521+
)
517522
size += num_barriers * utils.MBARRIER_BYTES
518523
case TMEM(_):
519524
# TODO(justinfu): This can trigger misaligned barrier allocations
@@ -538,7 +543,10 @@ def _launch(
538543
maybe_prof_buffer: ir.Value | None = None,
539544
):
540545
if (profiler_spec is None) != (maybe_prof_buffer is None):
541-
raise ValueError
546+
raise ValueError(
547+
"Both profiler_spec and maybe_prof_buffer must be specified or"
548+
" left unspecified."
549+
)
542550
index = ir.IndexType.get()
543551
i32 = ir.IntegerType.get_signless(32)
544552
i8 = ir.IntegerType.get_signless(8)
@@ -571,7 +579,7 @@ def _launch(
571579
f"{smem_bytes=} > {max_smem_bytes=}")
572580
if math.prod(cluster) != 1:
573581
if len(cluster) != 3:
574-
raise ValueError("Clusters must be 3D")
582+
raise ValueError(f"Clusters must be 3D. Got: {cluster}")
575583
cluster_kwargs = {
576584
"clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ")
577585
}
@@ -663,8 +671,9 @@ def _launch(
663671
collective = True in collective_types
664672
if collective and math.prod(cluster) % 2:
665673
raise ValueError(
666-
"Collective TMEM allocations are only supported for"
667-
" clusters with an even number of blocks in them."
674+
"Collective TMEM allocations are only supported for clusters"
675+
" with an even number of blocks in them. Got cluster:"
676+
f" {cluster}"
668677
)
669678
if lowering_semantics == LoweringSemantics.Warpgroup:
670679
dialect.tmem_relinquish_alloc_permit(collective=collective)

0 commit comments

Comments (0)