@@ -112,15 +112,113 @@ end
112112function ChainRulesCore. rrule (s:: Sinus , x:: AbstractVector , y:: AbstractVector )
113113 d = x - y
114114 sind = sinpi .(d)
115- abs2_sind_r = abs2 .(sind) ./ s. r
115+ abs2_sind_r = abs2 .(sind) ./ s. r .^ 2
116116 val = sum (abs2_sind_r)
117- gradx = twoπ .* cospi .(d) .* sind ./ ( s. r .^ 2 )
117+ gradx = twoπ .* cospi .(d) .* sind ./ s. r .^ 2
118118 function evaluate_pullback (Δ:: Any )
119- return (r= - 2 Δ .* abs2_sind_r,), Δ * gradx, - Δ * gradx
119+ r̄ = - 2 Δ .* abs2_sind_r ./ s. r
120+ s̄ = ChainRulesCore. Tangent {typeof(s)} (; r= r̄)
121+ return s̄, Δ * gradx, - Δ * gradx
120122 end
121123 return val, evaluate_pullback
122124end
123125
126+ function ChainRulesCore. rrule (
127+ :: typeof (Distances. pairwise), d:: Sinus , x:: AbstractMatrix ; dims= 2
128+ )
129+ project_x = ProjectTo (x)
130+ function pairwise_pullback (z̄)
131+ Δ = unthunk (z̄)
132+ n = size (x, dims)
133+ x̄ = collect (zero (x))
134+ r̄ = zero (d. r)
135+ if dims == 1
136+ for j in 1 : n, i in 1 : n
137+ xi = view (x, i, :)
138+ xj = view (x, j, :)
139+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi .(xi .- xj) ./ d. r .^ 2
140+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
141+ x̄[i, :] += ds
142+ x̄[j, :] -= ds
143+ end
144+ elseif dims == 2
145+ for j in 1 : n, i in 1 : n
146+ xi = view (x, :, i)
147+ xj = view (x, :, j)
148+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- xj) .* cospi .(xi .- xj) ./ d. r .^ 2
149+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- xj) .^ 2 ./ d. r .^ 3
150+ x̄[:, i] += ds
151+ x̄[:, j] -= ds
152+ end
153+ end
154+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
155+ return NoTangent (), d̄, @thunk (project_x (x̄))
156+ end
157+ return Distances. pairwise (d, x; dims), pairwise_pullback
158+ end
159+
160+ function ChainRulesCore. rrule (
161+ :: typeof (Distances. pairwise), d:: Sinus , x:: AbstractMatrix , y:: AbstractMatrix ; dims= 2
162+ )
163+ project_x = ProjectTo (x)
164+ project_y = ProjectTo (y)
165+ function pairwise_pullback (z̄)
166+ Δ = unthunk (z̄)
167+ n = size (x, dims)
168+ m = size (y, dims)
169+ x̄ = collect (zero (x))
170+ ȳ = collect (zero (y))
171+ r̄ = zero (d. r)
172+ if dims == 1
173+ for j in 1 : m, i in 1 : n
174+ xi = view (x, i, :)
175+ yj = view (y, j, :)
176+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- yj) .* cospi .(xi .- yj) ./ d. r .^ 2
177+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
178+ x̄[i, :] += ds
179+ ȳ[j, :] -= ds
180+ end
181+ elseif dims == 2
182+ for j in 1 : m, i in 1 : n
183+ xi = view (x, :, i)
184+ yj = view (y, :, j)
185+ ds = twoπ .* Δ[i, j] .* sinpi .(xi .- yj) .* cospi .(xi .- yj) ./ d. r .^ 2
186+ r̄ .- = 2 .* Δ[i, j] .* sinpi .(xi .- yj) .^ 2 ./ d. r .^ 3
187+ x̄[:, i] += ds
188+ ȳ[:, j] -= ds
189+ end
190+ end
191+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
192+ return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
193+ end
194+ return Distances. pairwise (d, x, y; dims), pairwise_pullback
195+ end
196+
197+ function ChainRulesCore. rrule (
198+ :: typeof (Distances. colwise), d:: Sinus , x:: AbstractMatrix , y:: AbstractMatrix
199+ )
200+ project_x = ProjectTo (x)
201+ project_y = ProjectTo (y)
202+ function colwise_pullback (z̄)
203+ Δ = unthunk (z̄)
204+ n = size (x, 2 )
205+ x̄ = collect (zero (x))
206+ ȳ = collect (zero (y))
207+ r̄ = zero (d. r)
208+ for i in 1 : n
209+ xi = view (x, :, i)
210+ yi = view (y, :, i)
211+ ds = twoπ .* Δ[i] .* sinpi .(xi .- yi) .* cospi .(xi .- yi) ./ d. r .^ 2
212+ r̄ .- = 2 .* Δ[i] .* sinpi .(xi .- yi) .^ 2 ./ d. r .^ 3
213+ x̄[:, i] += ds
214+ ȳ[:, i] -= ds
215+ end
216+ d̄ = ChainRulesCore. Tangent {typeof(d)} (; r= r̄)
217+ return NoTangent (), d̄, @thunk (project_x (x̄)), @thunk (project_y (ȳ))
218+ end
219+ return Distances. colwise (d, x, y), colwise_pullback
220+ end
221+
124222# # Reverse Rules for matrix wrappers
125223
126224function ChainRulesCore. rrule (:: Type{<:ColVecs} , X:: AbstractMatrix )
0 commit comments