Skip to content

Commit 2d0e0e3

Browse files
wsmosesavik-palgithub-actions[bot]
authored
Use Enzyme.ignore_derivatives, now that landed (#1707)
* Use Enzyme.ignore_derivatives, now that landed * move * fix * Apply suggestion from @github-actions[bot] Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update ReactantABI to use EnzymeCore directly * Change ignore_derivatives reference to EnzymeCore * fix: docs --------- Co-authored-by: Avik Pal <avikpal@mit.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 35c74e2 commit 2d0e0e3

File tree

10 files changed

+38
-31
lines changed

10 files changed

+38
-31
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ CUDA = "5.6"
8181
DLFP8Types = "0.1"
8282
Downloads = "1.6"
8383
EnumX = "1"
84-
Enzyme = "0.13.78"
85-
EnzymeCore = "0.8.13"
84+
Enzyme = "0.13.81"
85+
EnzymeCore = "0.8.14"
8686
FillArrays = "1.13"
8787
Float8s = "0.1"
8888
Functors = "0.5"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
44
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
56
OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
67
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
78
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Reactant, ReactantCore
1+
using Reactant, ReactantCore, EnzymeCore
22
using Documenter, DocumenterVitepress
33

44
DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)
@@ -7,6 +7,7 @@ makedocs(;
77
modules=[
88
Reactant,
99
ReactantCore,
10+
EnzymeCore,
1011
Reactant.XLA,
1112
Reactant.MLIR,
1213
Reactant.MLIR.API,

docs/src/api/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ Reactant.addressable_devices
7878
## Differentiation Specific API
7979

8080
```@docs
81-
Reactant.ignore_derivatives
81+
EnzymeCore.ignore_derivatives
8282
```
8383

8484
## Persistent Compilation Cache

docs/src/api/internal.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,14 @@ Reactant.Compiler.codegen_unflatten!
1414
Reactant.Compiler.codegen_flatten!
1515
Reactant.Compiler.codegen_xla_call
1616
```
17+
18+
## Other Docstrings
19+
20+
!!! warning "Private"
21+
22+
These docstrings are present here to prevent missing docstring warnings. For official
23+
Enzyme documentation checkout https://enzymead.github.io/Enzyme.jl/stable/.
24+
25+
```@autodocs
26+
Modules = [EnzymeCore, EnzymeCore.EnzymeRules]
27+
```

docs/src/tutorials/automatic-differentiation.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,12 @@ nothing # hide
190190

191191
### Ignoring Derivatives
192192

193-
Use [`Reactant.ignore_derivatives`](@ref) to exclude parts of computation from gradient:
193+
Use [`EnzymeCore.ignore_derivatives`](@ref) to exclude parts of computation from gradient:
194194

195195
```@example autodiff_tutorial
196196
function func_with_ignore(x)
197197
# This part won't contribute to gradient
198-
ignored_sum = Reactant.ignore_derivatives(sum(x))
198+
ignored_sum = Enzyme.ignore_derivatives(sum(x))
199199
# This part will contribute
200200
return sum(x .^ 2) + ignored_sum
201201
end

src/Enzyme.jl

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -585,24 +585,4 @@ function overload_autodiff(
585585
end
586586
end
587587

588-
"""
589-
ignore_derivatives(args...)
590-
591-
Prevents the flow of gradients (and higher-order derivatives) by creating a new value that
592-
is detached from the original value. This is an identity operation on the primal. This can
593-
be applied on a nested structure of arrays and we will apply the operation on each of the
594-
leaves.
595-
"""
596-
function ignore_derivatives(args...)
597-
res = map(ignore_derivatives_internal, args)
598-
length(args) == 1 && return only(res)
599-
return res
600-
end
601-
602-
function ignore_derivatives_internal(arg)
603-
return Functors.fmap(arg) do argᵢ
604-
argᵢ isa AnyTracedRArray && (argᵢ = materialize_traced_array(argᵢ))
605-
argᵢ isa TracedType && return @opcall ignore_derivatives(argᵢ)
606-
return argᵢ
607-
end
608-
end
588+
const ignore_derivatives = EnzymeCore.ignore_derivatives

src/Overlay.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ end
2121
return overload_autodiff(rmode, f, rt, args...)
2222
end
2323

24+
@reactant_overlay function EnzymeCore.ignore_derivatives(args...)
25+
res = map(args) do arg
26+
return Functors.fmap(arg) do argᵢ
27+
argᵢ isa AnyTracedRArray &&
28+
(argᵢ = call_with_reactant(materialize_traced_array, argᵢ))
29+
argᵢ isa TracedType && return @opcall ignore_derivatives(argᵢ)
30+
return argᵢ
31+
end
32+
end
33+
length(args) == 1 && return only(res)
34+
return res
35+
end
36+
2437
# Random.jl overlays
2538
@reactant_overlay @noinline function Random.default_rng()
2639
return call_with_reactant(TracedRandom.default_rng)

src/Reactant.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using Enzyme:
2626
DuplicatedNoNeed,
2727
EnzymeRules,
2828
Reverse
29+
using EnzymeCore: EnzymeCore
2930

3031
export allowscalar, @allowscalar # re-exported from GPUArraysCore
3132

@@ -40,7 +41,7 @@ function precompiling()
4041
return (@ccall jl_generating_output()::Cint) == 1
4142
end
4243

43-
struct ReactantABI <: Enzyme.EnzymeCore.ABI end
44+
struct ReactantABI <: EnzymeCore.ABI end
4445

4546
include("PrimitiveTypes.jl")
4647

test/autodiff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,15 +286,15 @@ function simple_grad_without_ignore(x::AbstractArray{T}) where {T}
286286
end
287287

288288
function simple_grad_with_ignore(x::AbstractArray{T}) where {T}
289-
return Reactant.ignore_derivatives(sum(x; dims=1), x .- 1, (x, x .+ 2)), sum(abs2, x)
289+
return Enzyme.ignore_derivatives(sum(x; dims=1), x .- 1, (x, x .+ 2)), sum(abs2, x)
290290
end
291291

292292
function zero_grad(x)
293-
return Reactant.ignore_derivatives(sum(x))
293+
return Enzyme.ignore_derivatives(sum(x))
294294
end
295295

296296
function zero_grad2(x)
297-
return Reactant.ignore_derivatives(sum(x), x)
297+
return Enzyme.ignore_derivatives(sum(x), x)
298298
end
299299

300300
@testset "ignore_derivatives" begin

0 commit comments

Comments
 (0)