Skip to content

Commit 9df3d85

Browse files
feat: sharding group (#1811)
* feat: sharding group * Update src/Ops.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 5c5f0a2 commit 9df3d85

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

src/Compiler.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,8 @@ function compile_mlir!(
15801580
args,
15811581
compile_options::CompileOptions,
15821582
callcache=default_callcache(),
1583-
sdycache=default_sdycache();
1583+
sdycache=default_sdycache(),
1584+
sdygroupidcache=default_sdygroupidcache();
15841585
fn_kwargs=(),
15851586
backend="gpu",
15861587
runtime::Union{Val{:PJRT},Val{:IFRT}},
@@ -1597,6 +1598,7 @@ function compile_mlir!(
15971598
MLIR.IR.activate!(MLIR.IR.body(mod))
15981599
activate_callcache!(callcache)
15991600
activate_sdycache!(sdycache)
1601+
activate_sdygroupidcache!(sdygroupidcache)
16001602

16011603
# Save in the TLS whether we are raising. We identify that condition by
16021604
# checking whether the user set an explicit list of passes, or chose
@@ -1623,6 +1625,7 @@ function compile_mlir!(
16231625
finally
16241626
deactivate_raising!(is_raising)
16251627
deactivate_sdycache!(sdycache)
1628+
deactivate_sdygroupidcache!(sdygroupidcache)
16261629
deactivate_callcache!(callcache)
16271630
MLIR.IR.deactivate!(MLIR.IR.body(mod))
16281631
MLIR.IR.deactivate!(mod)
@@ -3832,7 +3835,7 @@ function register_thunk(
38323835
)
38333836
end
38343837

3835-
for cache_type in (:callcache, :sdycache)
3838+
for cache_type in (:callcache, :sdycache, :sdygroupidcache)
38363839
activate_fn = Symbol(:activate_, cache_type, :!)
38373840
deactivate_fn = Symbol(:deactivate_, cache_type, :!)
38383841
has_fn = Symbol(:_has_, cache_type)
@@ -3879,6 +3882,14 @@ function default_sdycache()
38793882
}()
38803883
end
38813884

3885+
"""
    SdyGroupIDCounter{T}

Mutable holder for a sharding-group id counter of type `T`.

The single field `group_id` is declared `@atomic`, so every read or update of it
must go through the `@atomic` macro family.
"""
mutable struct SdyGroupIDCounter{T}
    @atomic group_id::T
end
3888+
3889+
"""
    default_sdygroupidcache()

Create a fresh sharding-group id cache.

Returns a 2-tuple `(counter, assigned)`:

  - `counter`: a `SdyGroupIDCounter{Int}` starting at `0`.
  - `assigned`: an empty `IdDict` keyed by traced values (`TracedRArray` or
    `TracedRNumber`, compared by identity) with `Int` values.
"""
function default_sdygroupidcache()
    counter = SdyGroupIDCounter{Int}(0)
    assigned = Base.IdDict{Union{TracedRArray,TracedRNumber},Int}()
    return (counter, assigned)
end
3892+
38823893
function default_callcache()
38833894
return Dict{
38843895
Vector,

src/Ops.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This module reflects the HLO ops defined in the openxla/stablehlo repo (plus some extras).
22
# If you want to add some check or test, the StableHLO spec should be taken as the source of truth, not the Julia or Reactant semantics.
3-
# Julia and Reactant semantics should be considered on the higher abstractions that use these
3+
# Julia and Reactant semantics should be considered on the higher abstractions that use these
44
module Ops
55
using ..MLIR: MLIR
66
using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
@@ -3474,4 +3474,45 @@ end
34743474
)
34753475
end
34763476

3477+
"""
    sharding_group(inputs::Union{TracedRArray,TracedRNumber}...; group_id=nothing, location)

Place all `inputs` in one sharding group by emitting an `sdy.sharding_group` op
for every input that has not been assigned to a group yet.

Group membership is tracked in the cache returned by
`Reactant.Compiler.sdygroupidcache()`, which is destructured into a counter and
a per-value mapping to group ids.

# Keywords

  - `group_id`: id to use for the group. When `nothing`, the id already recorded
    for any of the inputs is reused; if none of the inputs has one, a fresh id is
    taken from the counter.
  - `location`: MLIR location attached to the emitted ops.

# Throws

  - `AssertionError` if fewer than two inputs are given.
  - `ErrorException` if the inputs already belong to different groups, or if an
    explicit `group_id` disagrees with the group already recorded for the inputs.

Returns `nothing`.
"""
@noinline function sharding_group(
    inputs::Union{TracedRArray,TracedRNumber}...;
    group_id::Union{Integer,Nothing}=nothing,
    location=mlir_stacktrace("sharding_group", @__FILE__, @__LINE__),
)
    @assert length(inputs) > 1 "At least two inputs are required to form a sharding group, \
                                got $(length(inputs))"

    counter, cache = Reactant.Compiler.sdygroupidcache()

    # Distinct group ids already recorded for any of the inputs.
    group_ids = unique([cache[input] for input in inputs if haskey(cache, input)])
    if length(group_ids) > 1
        error("All inputs must belong to the same sharding group. Found multiple group \
               ids: $(group_ids)")
    end

    if isempty(group_ids)
        if group_id === nothing
            # Claim a fresh id with ONE atomic read-modify-write. `+=` returns the
            # updated value, so the id claimed is the value before the increment.
            # A separate load followed by an increment would let two concurrent
            # callers observe, and both use, the same id.
            group_id = (@atomic counter.group_id += 1) - 1
        end
    else
        found_group_id = only(group_ids)
        if group_id !== nothing && found_group_id != group_id
            error("Provided group_id $(group_id) does not match the existing group_id \
                   $(found_group_id) for the inputs. All inputs must belong to the same \
                   sharding group.")
        end
        group_id = found_group_id
    end

    # Only inputs that are not yet tracked get recorded and have an op emitted;
    # inputs already in the cache are left untouched.
    for input in inputs
        if !haskey(cache, input)
            cache[input] = group_id
            MLIR.Dialects.sdy.sharding_group(input.mlir_data; group_id=group_id, location)
        end
    end

    return nothing
end
3517+
34773518
end # module Ops

test/sharding.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,3 +509,29 @@ end
509509
@warn "Not enough addressable devices to run sharding tests"
510510
end
511511
end
512+
513+
@testset "Sharding Group" begin
    # The test builds a 2×2 mesh from devices 0:3, so it needs at least four
    # devices; it is also restricted to the IFRT runtime.
    if length(Reactant.devices()) ≥ 4 && Reactant.XLA.runtime() isa Val{:IFRT}
        mesh = Sharding.Mesh(reshape(0:3, 2, 2), (:x, :y))
        sharding = Sharding.NamedSharding(mesh, (:x, :y))

        function shard_groups(x)
            y = (x' * x)[1:4, :]
            # Neither x nor y is grouped yet: both get tagged with a fresh group id.
            Reactant.Ops.sharding_group(x, y)
            z = y .+ x
            # y already belongs to a group, so only z is tagged, with y's group id.
            Reactant.Ops.sharding_group(z, y)
            return z
        end

        x = Reactant.to_rarray(
            Reactant.TestUtils.construct_test_array(Float32, 4, 128); sharding
        )

        hlo = repr(@code_hlo shard_groups(x))

        # One op each for x, y and z, all carrying the first id handed out (0).
        @test count("sharding_group", hlo) == 3
        @test count("group_id=0", hlo) == 3
    else
        @warn "Not enough addressable devices to run sharding tests"
    end
end

0 commit comments

Comments
 (0)