From f8d3203666fb717422e439c48b58191aa1f50008 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 3 Nov 2025 10:28:55 -0600 Subject: [PATCH 1/2] Add halo test --- test/optimize_comm.jl | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 9977646daf..3dc4a73761 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -22,6 +22,15 @@ function dus2(x, y) return nothing end +function halo!(x) + x[9:3065, 7, :] = x[9:3065, 8, :] + x[9:3065, 3066, :] = x[9:3065, 3065, :] + + x[1:8, :, :] = x[(3065-8+1):3065, :, :] + x[3065:(3065+8-1), :, :] = x[8:15, :, :] + return nothing +end + if length(addressable_devices) ≥ 8 @testset "Rotate" begin N = min((length(Reactant.devices()) ÷ 2) * 2, 8) @@ -110,4 +119,30 @@ if length(addressable_devices) ≥ 8 @test all(x .== convert(Array, rx)) @test all(y .== convert(Array, ry)) end + + + @testset "Halo" begin + N = min((length(Reactant.devices()) ÷ 2) * 2, 8) + + mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :, 1), (:x, :y, :z)) + sharding = Sharding.NamedSharding(mesh, (:x, :y, :z)) + + M = 3072 + x = reshape(collect(Int, 1:(M * M)), M, M, 1) + rx = Reactant.to_rarray(x; sharding) + + hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings halo!(rx)) + + @test !contains(hlo, "all-to-all") + @test !contains(hlo, "all-reduce") broken = true + @test !contains(hlo, "all-gather") broken = + Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT" + @test contains(hlo, "collective-permute") broken = + Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT" + + dus2(x, y) + @jit shardy_passes = :to_mhlo_shardings dus2(rx, ry) + @test all(x .== convert(Array, rx)) + @test all(y .== convert(Array, ry)) + end end From bf1b85cb1f9412942f80399b6282873249162ab4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 3 Nov 2025 10:30:52 -0600 Subject: [PATCH 2/2] Update test/optimize_comm.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/optimize_comm.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 3dc4a73761..e7fe206a0f 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -26,8 +26,8 @@ function halo!(x) x[9:3065, 7, :] = x[9:3065, 8, :] x[9:3065, 3066, :] = x[9:3065, 3065, :] - x[1:8, :, :] = x[(3065-8+1):3065, :, :] - x[3065:(3065+8-1), :, :] = x[8:15, :, :] + x[1:8, :, :] = x[(3065 - 8 + 1):3065, :, :] + x[3065:(3065 + 8 - 1), :, :] = x[8:15, :, :] return nothing end