Commit d41210e

allanrenucci authored and Google-ML-Automation committed
[Pallas:MGPU][NFC] Style nit.
PiperOrigin-RevId: 835259173
1 parent 8c75b51 · commit d41210e

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 3 additions & 8 deletions
@@ -1907,21 +1907,16 @@ def _broadcast_in_dim_lowering_rule_wg(
     sharding,
 ):
   del sharding
-
   [x_aval] = ctx.avals_in
-
+  mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
+  result_ty = ir.VectorType.get(shape, mlir_type)
   if not broadcast_dimensions:
     # Even though we could implement this case by passing a 0D vector as input
     # to mgpu.dialect.BroadcastInDimOp we don't want that. 0D vectors are
     # generally problematic and so we avoid them by specializing that case
     # directly here.
     x = _ensure_ir_value(x, x_aval.dtype)
-    return vector_dialect.broadcast(
-        ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)),
-        x,
-    )
-  mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
-  result_ty = ir.VectorType.get(shape, mlir_type)
+    return vector_dialect.broadcast(result_ty, x)
   return mgpu.dialect.broadcast_in_dim(result_ty, x, broadcast_dimensions)
 
 
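For context, here is how the rule reads after this commit: result_ty is now computed once before the branch, so both the scalar specialization and the general case reuse it, and the scalar case's vector.broadcast call collapses to a single line. This is a sketch reconstructed from the hunk above, not verbatim source; only the body lines are confirmed by the diff, while the signature and import paths are assumptions inferred from the identifiers the hunk uses.

# A sketch of the post-commit rule; only the function body is taken from the
# diff. The signature and import paths below are assumptions.
from jax._src.lib.mlir import ir                                 # assumed path
from jax._src.lib.mlir.dialects import vector as vector_dialect  # assumed path
# mgpu, mgpu_utils, and _ensure_ir_value come from the surrounding module
# (jax/_src/pallas/mosaic_gpu/lowering.py) and are not redeclared here.

def _broadcast_in_dim_lowering_rule_wg(
    ctx, x, *, shape, broadcast_dimensions, sharding,  # inferred signature
):
  del sharding
  [x_aval] = ctx.avals_in
  # Compute the result vector type once, up front, so that both the scalar
  # specialization and the general case below can reuse it.
  mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
  result_ty = ir.VectorType.get(shape, mlir_type)
  if not broadcast_dimensions:
    # Scalar input: splat with vector.broadcast directly rather than feeding
    # a problematic 0D vector into mgpu.dialect.BroadcastInDimOp.
    x = _ensure_ir_value(x, x_aval.dtype)
    return vector_dialect.broadcast(result_ty, x)
  return mgpu.dialect.broadcast_in_dim(result_ty, x, broadcast_dimensions)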

0 commit comments