@@ -10,6 +10,8 @@ PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
1010
1111(f:: PartialsFn{T} )(i) where {T} = partials (T, f. dual, i)
1212
13+ _take (itr, N:: Integer ) = Iterators. take (itr, min (length (itr), N))
14+
1315function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
1416 seed:: Partials{N,V} ) where {T,V,N}
1517 idxs = collect (ForwardDiff. structural_eachindex (duals, x))
1921
2022function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
2123 seeds:: NTuple{N,Partials{N,V}} ) where {T,V,N}
22- idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (duals, x), N))
24+ idxs = collect (_take (ForwardDiff. structural_eachindex (duals, x), N))
2325 duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
2426 return duals
2527end
@@ -36,10 +38,7 @@ function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
3638 seeds:: NTuple{N,Partials{N,V}} , chunksize) where {T,V,N}
3739 offset = index - 1
3840 idxs = collect (
39- Iterators. take (
40- Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset),
41- chunksize,
42- ),
41+ _take (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset), chunksize)
4342 )
4443 duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
4544 return duals
4948function ForwardDiff. extract_gradient! (:: Type{T} , result:: AbstractGPUArray ,
5049 dual:: Dual ) where {T}
5150 fn = PartialsFn {T} (dual)
52- idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (result), npartials (dual)))
51+ idxs = collect (_take (ForwardDiff. structural_eachindex (result), npartials (dual)))
5352 result[idxs] .= fn .(1 : length (idxs))
5453 return result
5554end
@@ -59,10 +58,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
5958 fn = PartialsFn {T} (dual)
6059 offset = index - 1
6160 idxs = collect (
62- Iterators. take (
63- Iterators. drop (ForwardDiff. structural_eachindex (result), offset),
64- chunksize,
65- )
61+ _take (Iterators. drop (ForwardDiff. structural_eachindex (result), offset), chunksize)
6662 )
6763 result[idxs] .= fn .(1 : length (idxs))
6864 return result
0 commit comments