File tree Expand file tree Collapse file tree 1 file changed +3
-8
lines changed
jax/_src/pallas/mosaic_gpu Expand file tree Collapse file tree 1 file changed +3
-8
lines changed Original file line number Diff line number Diff line change @@ -1907,21 +1907,16 @@ def _broadcast_in_dim_lowering_rule_wg(
19071907 sharding ,
19081908):
19091909 del sharding
1910-
19111910 [x_aval ] = ctx .avals_in
1912-
1911+ mlir_type = mgpu_utils .dtype_to_ir_type (x_aval .dtype )
1912+ result_ty = ir .VectorType .get (shape , mlir_type )
19131913 if not broadcast_dimensions :
19141914 # Even though we could implement this case by passing a 0D vector as input
19151915 # to mgpu.dialect.BroadcastInDimOp we don't want that. 0D vectors are
19161916 # generally problematic and so we avoid them by specializing that case
19171917 # directly here.
19181918 x = _ensure_ir_value (x , x_aval .dtype )
1919- return vector_dialect .broadcast (
1920- ir .VectorType .get (shape , mgpu_utils .dtype_to_ir_type (x_aval .dtype )),
1921- x ,
1922- )
1923- mlir_type = mgpu_utils .dtype_to_ir_type (x_aval .dtype )
1924- result_ty = ir .VectorType .get (shape , mlir_type )
1919+ return vector_dialect .broadcast (result_ty , x )
19251920 return mgpu .dialect .broadcast_in_dim (result_ty , x , broadcast_dimensions )
19261921
19271922
You can’t perform that action at this time.
0 commit comments