@@ -3,6 +3,13 @@ module ForwardDiffGPUArraysCoreExt
33using GPUArraysCore: AbstractGPUArray
44using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
55
# Callable wrapper around a single `Dual`, used by the `extract_gradient!` and
# `extract_gradient_chunk!` methods in this extension to broadcast over partial
# indices: calling `f(i)` evaluates `partials(T, f.dual, i)`.
#
# `T` is carried purely as a type parameter (no field has that type); `D` is the
# concrete type of the stored dual, so the field is concretely typed.
struct PartialsFn{T,D<:Dual}
    dual::D
end

# `T` appears in no field, so it cannot be deduced from the argument. This outer
# constructor lets callers supply only `T` and fills in `D` from the value.
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)

# Function-call overload: return the `i`-th partial of the wrapped dual.
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)
12+
613function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
714 seed:: Partials{N,V} ) where {T,V,N}
815 idxs = collect (ForwardDiff. structural_eachindex (duals, x))
4148# gradient
# Copy the partials of `dual` into `result` (a GPU array) and return `result`.
#
# Only the first `npartials(dual)` positions yielded by
# `ForwardDiff.structural_eachindex(result)` are written; position `k` of that
# selection receives `partials(T, dual, k)`.
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
                                       dual::Dual) where {T}
    # Callable struct that carries `dual` and the type parameter `T`, so the
    # broadcast below needs neither a local closure nor a `Ref(dual)` argument.
    fn = PartialsFn{T}(dual)
    # Materialise the selected indices so they can be used for indexed assignment.
    idxs = collect(
        Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)),
    )
    # `length(idxs)` (not `npartials(dual)`) bounds the range, so the right-hand
    # side always has exactly as many elements as there are destinations.
    result[idxs] .= fn.(1:length(idxs))
    return result
end
5156
# Copy one chunk of partials from `dual` into `result` (a GPU array) and return
# `result`.
#
# The destination is the run of up to `chunksize` positions of
# `ForwardDiff.structural_eachindex(result)` that starts at position `index`
# (the first `index - 1` positions are skipped). Position `k` of that run
# receives `partials(T, dual, k)`.
#
# NOTE(review): `dual` is untyped in this signature, but `PartialsFn{T}` only
# accepts a `Dual`; any other argument raises a `MethodError` at construction.
function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
                                             index, chunksize) where {T}
    # Callable struct that carries `dual` and the type parameter `T`, so the
    # broadcast below needs neither a local closure nor a `Ref(dual)` argument.
    fn = PartialsFn{T}(dual)
    offset = index - 1
    # Materialise the selected indices so they can be used for indexed assignment.
    idxs = collect(
        Iterators.take(
            Iterators.drop(ForwardDiff.structural_eachindex(result), offset),
            chunksize,
        )
    )
    # `length(idxs)` may be smaller than `chunksize` when fewer positions remain
    # after the drop, so the range is sized from the indices actually collected.
    result[idxs] .= fn.(1:length(idxs))
    return result
end
6770
0 commit comments