Skip to content
Closed

symm op #1777

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
cca8ff3
symm op
snonk Oct 19, 2025
5a4b3a4
fix: missing dep (#1779)
avik-pal Oct 20, 2025
322bd41
fix: mark enzymexla symbols as exported (#1781)
avik-pal Oct 20, 2025
0fd57e1
feat: julia api to access device properties (#1762)
avik-pal Oct 21, 2025
6c6fcc9
[deps] Make CUDA build more robust (#1782)
giordano Oct 21, 2025
3622657
chore: update ENZYMEXLA_COMMIT hash in WORKSPACE
avik-pal Oct 22, 2025
fdc56b4
Regenerate MLIR Bindings (#1784)
github-actions[bot] Oct 23, 2025
8b05cdc
Some 1.12 fixes
wsmoses Oct 24, 2025
e77039c
Apply suggestions from code review
wsmoses Oct 24, 2025
95e2217
Add ARM64 support flag to Bazel configuration
wsmoses Oct 27, 2025
2452151
Update ENZYMEXLA_COMMIT hash in WORKSPACE
avik-pal Oct 27, 2025
8a06166
[CI] Add workflow for automatically updating Enzyme-JAX (#1793)
giordano Oct 28, 2025
96a819d
Regenerate MLIR Bindings (#1795)
github-actions[bot] Oct 28, 2025
5751b52
Static cuda attempt
wsmoses Oct 19, 2025
180ee83
f
wsmoses Oct 19, 2025
adb655f
Update ENZYMEXLA_COMMIT and ml_toolchain_workspace
wsmoses Oct 26, 2025
024548c
fix
wsmoses Oct 26, 2025
a2abd76
fix
wsmoses Oct 26, 2025
1a86d57
fix
wsmoses Oct 28, 2025
fe6ee8e
bump
wsmoses Oct 28, 2025
54f5d5d
Update ENZYMEXLA_COMMIT and ml_toolchain_workspace
wsmoses Oct 28, 2025
f498044
feat: new jll version + new compiler passes (#1791)
avik-pal Oct 28, 2025
6e1a02c
feat: some more 1.12 support
avik-pal Oct 28, 2025
1737a70
Update ENZYMEXLA_COMMIT hash in WORKSPACE
avik-pal Oct 28, 2025
158b986
fix
wsmoses Oct 28, 2025
e5fc60f
symm op
snonk Oct 19, 2025
155630f
update build
snonk Oct 29, 2025
41e2a67
untrack some
snonk Oct 29, 2025
90412ef
merge
snonk Oct 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/update-enzyme-jax.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: "Open PR to update Enzyme-JAX commit"

on:
schedule:
- cron: '19 16 * * *'
workflow_dispatch:
inputs:
enzyme_jax_commit:
description: 'The Enzyme-JAX commit to update to (optional)'
default: ''
type: 'string'

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
pr-latest-enzyme-jax:
name: 'Update Enzyme-JAX'
uses: EnzymeAD/Enzyme-JAX/.github/workflows/update-dependency.yml@main
with:
upstream_repo: 'EnzymeAD/Enzyme-JAX'
upstream_commit: ${{ inputs.enzyme_jax_commit }}
variable_name: 'ENZYMEXLA_COMMIT'
workspace_path: 'deps/ReactantExtra/WORKSPACE'
secrets: inherit
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
version = "0.2.171"
version = "0.2.172"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -105,7 +105,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.16"
Reactant_jll = "0.0.251"
Reactant_jll = "0.0.254"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion benchmark/aggregate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ for backend in BACKENDS
end

open(joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), "w") do io
JSON3.pretty(io, JSON3.write(all_results))
return JSON3.pretty(io, JSON3.write(all_results))
end
2 changes: 1 addition & 1 deletion benchmark/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ for (i, (k, v)) in enumerate(results)
end

open(joinpath(filepath, filename), "w") do io
JSON3.pretty(io, JSON3.write(standardized_results))
return JSON3.pretty(io, JSON3.write(standardized_results))
end

@info "Saved results to $(joinpath(filepath, filename))"
18 changes: 11 additions & 7 deletions deps/ReactantExtra/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ build --repo_env=RULES_PYTHON_ENABLE_PYSTAR=0

build -c opt

common:macos --define ynn_enable_arm64_sme=false

common:cuda --repo_env TF_NEED_CUDA=1
common:cuda --repo_env TF_NVCC_CLANG=1
common:cuda --repo_env TF_NCCL_USE_STUB=1
common:cuda_static --@rules_ml_toolchain//common:link_cuda_static_libs=true
common:cuda_static --@rules_ml_toolchain//common:link_nvrtc_static_libs=true
common:cuda --@local_config_cuda//:enable_cuda
common:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain"
# Default hermetic CUDA and CUDNN versions.
Expand All @@ -40,19 +44,19 @@ common:cuda --@local_config_cuda//:cuda_compiler=nvcc
# common:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
# common:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true


common:cuda12 --config=cuda
common:cuda12 --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
common:cuda12 --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
common:cuda12 --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
common:cuda12 --config=cuda_static
common:cuda12 --repo_env=HERMETIC_CUDA_VERSION="12.9.1"
common:cuda12 --repo_env=HERMETIC_CUDNN_VERSION="9.14.0"
common:cuda12 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.9"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
common:cuda12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
common:cuda12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120"


common:cuda13 --config=cuda
common:cuda13 --repo_env=HERMETIC_CUDA_VERSION="13.0.0"
common:cuda13 --repo_env=HERMETIC_CUDNN_VERSION="9.12.0"
common:cuda13 --repo_env=HERMETIC_CUDA_VERSION="13.0.2"
common:cuda13 --repo_env=HERMETIC_CUDNN_VERSION="9.14.0"
common:cuda13 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.20"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
Expand Down
2 changes: 2 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ cc_library(
"//conditions:default": [],
"@bazel_tools//src/conditions:darwin": [
"-Wl,-exported_symbol,_stablehlo*",
"-Wl,-exported_symbol,_enzymexla*",
"-Wl,-exported_symbol,_mlir*",
"-Wl,-exported_symbol,_sdy*",
"-Wl,-exported_symbol,_EnzymeJaXMapSymbol",
Expand Down Expand Up @@ -1078,6 +1079,7 @@ cc_library(
"@llvm-project//llvm:X86CodeGen",
"@enzyme_ad//src/enzyme_ad/jax:TransformOps",
"@enzyme_ad//src/enzyme_ad/jax:XLADerivatives",
"@enzyme_ad//src/enzyme_ad/jax:CInterface",
# "@enzyme_ad//src/enzyme_ad/jax:gpu",
"@xla//xla/ffi/api:ffi",
"@xla//xla/ffi:ffi_api",
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "a7c38bce984c3adedafb8e03282c0e39640ab6f9"
ENZYMEXLA_COMMIT = "c5b0090d53998673b2f728b7590b97d7bc548d2b"

ENZYMEXLA_SHA256 = ""

Expand Down
38 changes: 19 additions & 19 deletions deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ end
src_dir = joinpath(dirname(dirname(@__DIR__)), "src")

for file in [
"Builtin.jl",
"Arith.jl",
"Affine.jl",
"Func.jl",
"Enzyme.jl",
# "Builtin.jl",
# "Arith.jl",
# "Affine.jl",
# "Func.jl",
# "Enzyme.jl",
"EnzymeXLA.jl",
"StableHLO.jl",
"CHLO.jl",
"VHLO.jl",
"Llvm.jl",
"Nvvm.jl",
"Gpu.jl",
"Affine.jl",
"TPU.jl",
"MosaicGPU.jl",
"Triton.jl",
"Shardy.jl",
"MPI.jl",
"MemRef.jl",
"SparseTensor.jl",
# "StableHLO.jl",
# "CHLO.jl",
# "VHLO.jl",
# "Llvm.jl",
# "Nvvm.jl",
# "Gpu.jl",
# "Affine.jl",
# "TPU.jl",
# "MosaicGPU.jl",
# "Triton.jl",
# "Shardy.jl",
# "MPI.jl",
# "MemRef.jl",
# "SparseTensor.jl",
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
Expand Down
93 changes: 60 additions & 33 deletions deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,36 +115,50 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
# --@local_config_cuda//:cuda_compiler=nvcc
# --crosstool_top="@local_config_cuda//crosstool:toolchain"

build_kind = parsed_args["debug"] ? "dbg" : "opt"
abstract type AbstractBackend end
struct CPUBackend <: AbstractBackend end
struct CUDABackend <: AbstractBackend
version::VersionNumber
CUDABackend(ver::VersionNumber) = new(VersionNumber(ver.major))
end

build_backend = parsed_args["backend"]
function parse_build_backend(str::String)::AbstractBackend
if str == "cpu"
return CPUBackend()
elseif str == "cuda12"
return CUDABackend(v"12")
elseif str == "cuda13"
return CUDABackend(v"13")
end

if build_backend == "auto" || build_backend == "cuda"
cuda_ver = get_cuda_version()
@show cuda_ver
if cuda_ver === nothing
if build_backend == "cuda"
throw(
AssertionError(
"Could not detect cuda version, but requested cuda with auto version build",
),
)
end
build_backend = "cpu"
else
if Int(get_cuda_version().major) == 13
build_backend = "cuda13"
if str in ("auto", "cuda")
cuda_ver = get_cuda_version()
if isnothing(cuda_ver)
if str == "cuda"
throw(
AssertionError(
"Could not detect cuda version, but requested cuda with auto version build",
),
)
end
return CPUBackend()
else
build_backend = "cuda12"
return CUDABackend(get_cuda_version())
end
else
error("Unknown backend '$(str)'")
end
end

arg = if build_backend == "cuda12"
build_kind = parsed_args["debug"] ? "dbg" : "opt"

build_backend = parse_build_backend(parsed_args["backend"])

arg = if build_backend == CUDABackend(v"12")
"--config=cuda12"
elseif build_backend == "cuda13"
elseif build_backend == CUDABackend(v"13")
"--config=cuda13"
elseif build_backend == "cpu"
elseif build_backend == CPUBackend()
""
else
throw(AssertionError("Unknown backend `$build_backend`"))
Expand Down Expand Up @@ -197,6 +211,8 @@ push!(build_cmd_list, "--jobs=$(parsed_args["jobs"])")
push!(build_cmd_list, "--experimental_ui_max_stdouterr_bytes=-1")
push!(build_cmd_list, "--sandbox_debug")

push!(build_cmd_list, "--linkopt=-fuse-ld=lld")

for opt in parsed_args["copt"]
push!(build_cmd_list, "--copt=$(opt)")
end
Expand Down Expand Up @@ -231,35 +247,46 @@ push!(build_cmd_list, "--copt=-Wno-private-header")
push!(build_cmd_list, "--color=$(parsed_args["color"])")
push!(build_cmd_list, ":libReactantExtra.so")

@info "About to run Bazel" build_cmd_list
run(Cmd(Cmd(build_cmd_list); dir=source_dir))

# Discover built libraries
built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file
endswith(file, "Extra.so") && startswith(file, "lib")
return endswith(file, "Extra.so") && startswith(file, "lib")
end

lib_path = joinpath(source_dir, "bazel-bin", only(built_libs))
isfile(lib_path) || error("Could not find library $lib_path in build directory")

if build_backend == "cuda"
if build_backend isa CUDABackend
for path in (
joinpath("bin", "ptxas"),
joinpath("bin", "fatbinary"),
joinpath("nvvm", "libdevice", "libdevice.10.bc"),
)
full_path = joinpath(source_dir, "bazel-bin", "cuda", path)
if !Base.Filesystem.ispath(full_path)
Base.Filesystem.mkpath(dirname(full_path))
Base.Filesystem.symlink(
joinpath(
source_dir,
"bazel-bin",
"libReactantExtra.so.runfiles",
"cuda_nvcc",
path,
),
full_path,
source = joinpath(
source_dir,
"bazel-bin",
"libReactantExtra.so.runfiles",
# libdevice's directory was moved in CUDA 13, before was in same
# dir as ptxas and fatbinary
if contains(basename(path), "libdevice") && build_backend.version >= v"13"
"cuda_nvvm"
else
"cuda_nvcc"
end,
path,
)
if !Base.Filesystem.ispath(source)
error(
"File $(source) does not exist, are you sure it is where you expect it to be?",
)
end
Base.Filesystem.mkpath(dirname(full_path))
@info "Symlinking $(full_path) -> $(source)"
Base.Filesystem.symlink(source, full_path)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1282,8 +1282,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

@assert length(restys) == length(aliases)
call = MLIR.Dialects.enzymexla.kernel_call(
blk_operands...,
mlir_args;
blk_operands...;
inputs=mlir_args,
result_0=restys,
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),
Expand Down
27 changes: 19 additions & 8 deletions ext/ReactantKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,26 @@ function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsi
return nothing
end

Reactant.@reactant_overlay @noinline Base.@nospecializeinfer function (
obj::KA.Kernel{ReactantBackend}
)(
args...; ndrange=nothing, workgroupsize=nothing
)
@nospecialize
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
@static if VERSION < v"1.12-"
Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function (
obj::KA.Kernel{ReactantBackend}
)(
@nospecialize args...; ndrange=nothing, workgroupsize=nothing
)
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
)
end
else
Reactant.@reactant_overlay function (obj::KA.Kernel{ReactantBackend})(
args...; ndrange=nothing, workgroupsize=nothing
)
Base.@_noinline_meta
Base.@_nospecializeinfer_meta
return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
)
end
end

end
Loading