@@ -100,12 +100,12 @@ struct ∂⃖weaveInnerOdd{N, O}; b̄; end
100100end
101101@Base . constprop :aggressive function (w: :∂⃖weaveInnerOdd {N, O})(Δ) where {N, O}
102102 @destruct c, c̄ = w. b̄ (Δ... )
103- return (c̄, c), ∂⃖weaveInnerEven {plus1(N) , O} ()
103+ return (c̄, c), ∂⃖weaveInnerEven {N+1 , O} ()
104104end
105105struct ∂⃖weaveInnerEven{N, O}; end
106106@Base . constprop :aggressive function (w: :∂⃖weaveInnerEven {N, O})(Δ′, x... ) where {N, O}
107107 @destruct y, ȳ = Δ′ (x... )
108- return y, ∂⃖weaveInnerOdd {plus1(N) , O} (ȳ)
108+ return y, ∂⃖weaveInnerOdd {N+1 , O} (ȳ)
109109end
110110
111111struct ∂⃖weaveOuterOdd{N, O}; end
@@ -114,15 +114,15 @@ struct ∂⃖weaveOuterOdd{N, O}; end
114114end
115115@Base . constprop :aggressive function (w: :∂⃖weaveOuterOdd {N, O})((Δ′′, Δ′′′)) where {N, O}
116116 @destruct α, ᾱ = Δ′′′ (Δ′′)
117- return (NoTangent (), α... ), ∂⃖weaveOuterEven {plus1(N) , O} (ᾱ)
117+ return (NoTangent (), α... ), ∂⃖weaveOuterEven {N+1 , O} (ᾱ)
118118end
119119struct ∂⃖weaveOuterEven{N, O}; ᾱ end
120120@Base . constprop :aggressive function (w: :∂⃖weaveOuterEven {N, O})(Δ⁴... ) where {N, O}
121- return w. ᾱ (Base. tail (Δ⁴)... ), ∂⃖weaveOuterOdd {plus1(N) , O} ()
121+ return w. ᾱ (Base. tail (Δ⁴)... ), ∂⃖weaveOuterOdd {N+1 , O} ()
122122end
123123
124124function (:: ∂⃖{N})(:: ∂⃖{1 }, args... ) where {N}
125- @destruct (a, ā) = ∂⃖ {plus1(N) } ()(args... )
125+ @destruct (a, ā) = ∂⃖ {N+1 } ()(args... )
126126 let O = c_order (N)
127127 (a, Protected {N} (@opaque Δ-> begin
128128 (b, b̄) = ā (Δ)
@@ -187,10 +187,10 @@ end
187187(:: ∂⃖rruleD{N, N})(Δ... ) where {N} = error (" Should not be reached" )
188188
189189# ∂⃖rrule
190- @Base . pure term_depth (N) = 2 ^ (N- 2 )
190+ term_depth (N) = 1 << (N- 2 )
191191function (:: ∂⃖rrule{N})(z, z̄) where {N}
192192 @destruct (y, ȳ) = z
193- y, ∂⃖rruleA {term_depth(N), 1} (∂⃖ {minus1(N) } (), ȳ, z̄)
193+ y, ∂⃖rruleA {term_depth(N), 1} (∂⃖ {N-1 } (), ȳ, z̄)
194194end
195195
196196function (:: ∂⃖{N})(f:: Core.IntrinsicFunction , args... ) where {N}
@@ -216,7 +216,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
216216 end
217217 return z
218218 else
219- ∂⃖p = ∂⃖ {minus1(N) } ()
219+ ∂⃖p = ∂⃖ {N-1 } ()
220220 @destruct z, z̄ = ∂⃖p (rrule, f, args... )
221221 if z === nothing
222222 return ∂⃖recurse {N} ()(f, args... )
@@ -230,7 +230,7 @@ function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) wher
230230 Tuple {Any, Any} (∂⃖ {1} ()(f, args... ))
231231end
232232
233- @Base . pure function (:: ∂⃖{1 })(:: typeof (Core. apply_type), head, args... )
233+ @Base . assume_effects :total function (:: ∂⃖{1 })(:: typeof (Core. apply_type), head, args... )
234234 return rrule (Core. apply_type, head, args... )
235235end
236236
@@ -283,8 +283,8 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
283283EvenOddEven {O, P} (f:: F , g:: G ) where {O, P, F, G} = EvenOddEven {O, P, F, G} (f, g)
284284struct EvenOddOdd{O, P, F, G}; f:: F ; g:: G ; end
285285EvenOddOdd {O, P} (f:: F , g:: G ) where {O, P, F, G} = EvenOddOdd {O, P, F, G} (f, g)
286- @Base . constprop :aggressive (o:: EvenOddOdd{O, P, F, G} )(Δ) where {O, P, F, G} = (o. f (Δ), EvenOddEven {plus1(O) , P, F, G} (o. f, o. g))
287- @Base . constprop :aggressive (e:: EvenOddEven{O, P, F, G} )(Δ... ) where {O, P, F, G} = (e. g (Δ... ), EvenOddOdd {plus1(O) , P, F, G} (e. f, e. g))
286+ @Base . constprop :aggressive (o:: EvenOddOdd{O, P, F, G} )(Δ) where {O, P, F, G} = (o. f (Δ), EvenOddEven {O+1 , P, F, G} (o. f, o. g))
287+ @Base . constprop :aggressive (e:: EvenOddEven{O, P, F, G} )(Δ... ) where {O, P, F, G} = (e. g (Δ... ), EvenOddOdd {O+1 , P, F, G} (e. f, e. g))
288288@Base . constprop :aggressive (o:: EvenOddOdd{O, O} )(Δ) where {O} = o. f (Δ)
289289
290290
@@ -362,11 +362,11 @@ struct ApplyOdd{O, P}; u; ∂⃖f; end
362362struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
363363@Base . constprop :aggressive function (a:: ApplyOdd{O, P} )(Δ) where {O, P}
364364 r, ∂⃖∂⃖f = a.∂⃖f (Δ)
365- (a. u (r), ApplyEven {plus1(O) , P} (a. u, ∂⃖∂⃖f))
365+ (a. u (r), ApplyEven {O+1 , P} (a. u, ∂⃖∂⃖f))
366366end
367367@Base . constprop :aggressive function (a:: ApplyEven{O, P} )(_, _, ff, args... ) where {O, P}
368368 r, ∂⃖∂⃖∂⃖f = Core. _apply_iterate (iterate, a.∂⃖∂⃖f, (ff,), args... )
369- (r, ApplyOdd {plus1(O) , P} (a. u, ∂⃖∂⃖∂⃖f))
369+ (r, ApplyOdd {O+1 , P} (a. u, ∂⃖∂⃖∂⃖f))
370370end
371371@Base . constprop :aggressive function (a:: ApplyOdd{O, O} )(Δ) where {O}
372372 r = a.∂⃖f (Δ)
@@ -380,10 +380,10 @@ function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Unio
380380end
381381
382382
383- @Base . pure c_order (N:: Int ) = 2 ^ N - 1
383+ c_order (N:: Int ) = 1 << N - 1
384384
385- @Base . pure function (:: ∂⃖{N})(:: typeof (Core. apply_type), head, args... ) where {N}
386- Core. apply_type (head, args... ), NonDiffOdd {plus1(plus1( length(args))) , 1, c_order(N)} ()
385+ @Base . assume_effects :total function (:: ∂⃖{N})(:: typeof (Core. apply_type), head, args... ) where {N}
386+ Core. apply_type (head, args... ), NonDiffOdd {length(args)+2 , 1, c_order(N)} ()
387387end
388388
389389@Base . constprop :aggressive lifted_getfield (x, s) = getfield (x, s)
0 commit comments