diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 9977646daf..e7fe206a0f 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -22,6 +22,15 @@ function dus2(x, y) return nothing end +function halo!(x) + x[9:3065, 7, :] = x[9:3065, 8, :] + x[9:3065, 3066, :] = x[9:3065, 3065, :] + + x[1:8, :, :] = x[(3065 - 8 + 1):3065, :, :] + x[3065:(3065 + 8 - 1), :, :] = x[8:15, :, :] + return nothing +end + if length(addressable_devices) ≥ 8 @testset "Rotate" begin N = min((length(Reactant.devices()) ÷ 2) * 2, 8) @@ -110,4 +119,30 @@ if length(addressable_devices) ≥ 8 @test all(x .== convert(Array, rx)) @test all(y .== convert(Array, ry)) end + + + @testset "Halo" begin + N = min((length(Reactant.devices()) ÷ 2) * 2, 8) + + mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :, 1), (:x, :y, :z)) + sharding = Sharding.NamedSharding(mesh, (:x, :y, :z)) + + M = 3072 + x = reshape(collect(Int, 1:(M * M)), M, M, 1) + rx = Reactant.to_rarray(x; sharding) + + hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings halo!(rx)) + + @test !contains(hlo, "all-to-all") + @test !contains(hlo, "all-reduce") broken = true + @test !contains(hlo, "all-gather") broken = + Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT" + @test contains(hlo, "collective-permute") broken = + Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT" + + dus2(x, y) + @jit shardy_passes = :to_mhlo_shardings dus2(rx, ry) + @test all(x .== convert(Array, rx)) + @test all(y .== convert(Array, ry)) + end end