Skip to content

Commit 58e654a

Browse files
auto-unthunk partials
1 parent ea13bc0 commit 58e654a

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/chainrules.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
1+
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out, unthunk
22
using Base.Broadcast: broadcasted
33

44
function rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T}
@@ -26,6 +26,9 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
2626
function partials_pullback(::ZeroTangent)
2727
NoTangent(), TaylorScalar(z, ntuple(j -> zero(T), Val(N)))
2828
end
29+
function partials_pullback(v̄::ChainRulesCore.AbstractThunk)
30+
partials_pullback(unthunk(v̄))
31+
end
2932
return partials(t), partials_pullback
3033
end
3134

0 commit comments

Comments
 (0)