@@ -4,34 +4,37 @@ using GPUArraysCore: AbstractGPUArray
44using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
55
66function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
7- seed:: Partials{N,V} = zero (Partials{N,V}) ) where {T,V,N}
7+ seed:: Partials{N,V} ) where {T,V,N}
88 idxs = collect (ForwardDiff. structural_eachindex (duals, x))
9- duals[idxs] .= Dual {T,V,N} .(x[ idxs] , Ref (seed))
9+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs) , Ref (seed))
1010 return duals
1111end
1212
1313function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
1414 seeds:: NTuple{N,Partials{N,V}} ) where {T,V,N}
15- idxs = collect (ForwardDiff. structural_eachindex (duals, x))[ 1 : N]
16- duals[idxs] .= Dual {T,V,N} .(x[ idxs], seeds[ 1 : N] )
15+ idxs = collect (Iterators . take ( ForwardDiff. structural_eachindex (duals, x), N))
16+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .( Ref ( seeds), 1 : length (idxs)) )
1717 return duals
1818end
1919
2020function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x, index,
21- seed:: Partials{N,V} = zero (Partials{N,V}) ) where {T,V,N}
21+ seed:: Partials{N,V} ) where {T,V,N}
2222 offset = index - 1
2323 idxs = collect (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset))
24- duals[idxs] .= Dual {T,V,N} .(x[ idxs] , Ref (seed))
24+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs) , Ref (seed))
2525 return duals
2626end
2727
2828function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x, index,
29- seeds:: NTuple{N,Partials{N,V}} , chunksize = N ) where {T,V,N}
29+ seeds:: NTuple{N,Partials{N,V}} , chunksize) where {T,V,N}
3030 offset = index - 1
3131 idxs = collect (
32- Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset)
33- )[1 : chunksize]
34- duals[idxs] .= Dual {T,V,N} .(x[idxs], seeds[1 : chunksize])
32+ Iterators. take (
33+ Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset),
34+ chunksize,
35+ ),
36+ )
37+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
3538 return duals
3639end
3740
@@ -41,8 +44,8 @@ function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
4144 # this closure is needed for gpu compilation
4245 partial_fn (dual, i) = partials (T, dual, i)
4346
44- idxs = ForwardDiff. structural_eachindex (result)
45- result[idxs] .= partial_fn .(Ref (dual), 1 : npartials (dual ))
47+ idxs = collect (Iterators . take ( ForwardDiff. structural_eachindex (result), npartials (dual)) )
48+ result[idxs] .= partial_fn .(Ref (dual), 1 : length (idxs ))
4649 return result
4750end
4851
@@ -53,9 +56,12 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
5356
5457 offset = index - 1
5558 idxs = collect (
56- Iterators. drop (ForwardDiff. structural_eachindex (result), offset)
57- )[1 : chunksize]
58- result[idxs] .= partial_fn .(Ref (dual), 1 : chunksize)
59+ Iterators. take (
60+ Iterators. drop (ForwardDiff. structural_eachindex (result), offset),
61+ chunksize,
62+ )
63+ )
64+ result[idxs] .= partial_fn .(Ref (dual), 1 : length (idxs))
5965 return result
6066end
6167
0 commit comments