Skip to content

Commit d40dc47

Browse files
migrosserdpo
authored andcommitted
add chainrules for *-operator
1 parent 9025458 commit d40dc47

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1212

13+
[weakdeps]
14+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
15+
16+
[extensions]
17+
LinearOperatorsChainRulesCoreExt = "ChainRulesCore"
18+
1319
[compat]
1420
FastClosures = "0.2, 0.3"
1521
LDLFactorizations = "0.9, 0.10"
@@ -20,6 +26,7 @@ julia = "^1.6.0"
2026
Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
2127
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2228
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
29+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2330

2431
[targets]
25-
test = ["Arpack", "Test", "TestSetExtensions"]
32+
test = ["Arpack", "Test", "TestSetExtensions", "Zygote"]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
module LinearOperatorsChainRulesCoreExt
2+
3+
using LinearOperators
4+
import ChainRulesCore
5+
6+
function ChainRulesCore.frule((_, Δx, _), ::typeof(*), op::AbstractLinearOperator{T}, x::AbstractVector{S}) where {T, S}
7+
y = op*x
8+
Δy = op*Δx
9+
return y, Δy
10+
end
11+
function ChainRulesCore.rrule(::typeof(*), op::AbstractLinearOperator{T}, x::AbstractVector{S}) where {T, S}
12+
y = op*x
13+
project_x = ChainRulesCore.ProjectTo(x)
14+
function mul_pullback(ȳ)
15+
= project_x( adjoint(op)*ChainRulesCore.unthunk(ȳ) )
16+
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
17+
end
18+
return y, mul_pullback
19+
end
20+
21+
function ChainRulesCore.frule((_, Δx, _), ::typeof(*), x::Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V} }, op::AbstractLinearOperator{T}) where {T, S, V <: AbstractVector{S}}
22+
y = x*op
23+
Δy = Δx*op
24+
return y, Δy
25+
end
26+
function ChainRulesCore.rrule(::typeof(*), x::LinearOperators.Transpose{S, V}, op::AbstractLinearOperator{T}) where {T, S, V <: AbstractVector{S}}
27+
y = x*op
28+
project_x = ChainRulesCore.ProjectTo(x)
29+
function mul_pullback(ȳ)
30+
# needed to make sure that ȳ is recognized as Transposed
31+
# ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
32+
ȳ_ = transpose(vec(ChainRulesCore.unthunk(ȳ)))
33+
= project_x(ȳ_*adjoint(op))
34+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
35+
end
36+
return y, mul_pullback
37+
end
38+
function ChainRulesCore.rrule(::typeof(*), x::LinearOperators.Adjoint{S, V}, op::AbstractLinearOperator{T}) where {T, S, V <: AbstractVector{S}}
39+
y = x*op
40+
project_x = ChainRulesCore.ProjectTo(x)
41+
function mul_pullback(ȳ)
42+
# needed to make sure that ȳ is recognized as Adjoint
43+
# ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
44+
ȳ_ = adjoint(conj.(vec(ChainRulesCore.unthunk(ȳ))))
45+
= project_x(ȳ_*adjoint(op))
46+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
47+
end
48+
return y, mul_pullback
49+
end
50+
51+
end # module

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Arpack, Test, TestSetExtensions, LinearOperators
22
using LinearAlgebra, SparseArrays
3+
using Zygote
34
include("test_aux.jl")
45

56
include("test_linop.jl")
@@ -13,3 +14,4 @@ include("test_callable.jl")
1314
include("test_deprecated.jl")
1415
include("test_normest.jl")
1516
include("test_diag.jl")
17+
include("test_chainrules.jl")

test/test_chainrules.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using Zygote
2+
3+
function matmulOp(mat::AbstractArray{T}) where T
4+
function prod!(res,x)
5+
for i in axes(mat,1)
6+
res[i] = transpose(mat[i,:])*x
7+
end
8+
end
9+
10+
function ctprod!(res,x)
11+
for i in axes(mat,2)
12+
res[i] = dot(mat[:,i],x)
13+
end
14+
end
15+
16+
return LinearOperator{T}(size(mat,1),size(mat,2),false, false, prod!, nothing, ctprod!)
17+
end
18+
19+
function test_chainrules()
20+
@testset ExtendedTestSet "Chainrules" begin
21+
for (M,N) in zip([2,3,8,7], [2,4,8,16])
22+
for T in [Float64, ComplexF64]
23+
mat = simple_matrix(T, M, N)
24+
op = matmulOp(mat)
25+
x = rand(T,N)
26+
xᵀ = transpose(x[1:M])
27+
xᴴ = adjoint(x[1:M])
28+
29+
# test op*x
30+
y, g = Zygote.withgradient(v->sum(abs.(op*v)), x)
31+
y2, g2 = Zygote.withgradient(v->sum(abs.(mat*v)), x)
32+
@test isapprox(y, y2)
33+
@test isapprox(g[1], g2[1])
34+
35+
# test xᵀ*op
36+
yt, gt = Zygote.withgradient(v->sum(abs.(v*op)), xᵀ)
37+
yt2, gt2 = Zygote.withgradient(v->sum(abs.(v*mat)), xᵀ)
38+
@test isapprox(yt, yt2)
39+
@test isapprox(gt[1], gt2[1])
40+
41+
# test xᴴ*op
42+
yh, gh = Zygote.withgradient(v->sum(abs.(v*op)), xᴴ)
43+
yh2, gh2 = Zygote.withgradient(v->sum(abs.(v*mat)), xᴴ)
44+
@test isapprox(yh, yh2)
45+
@test isapprox(gh[1], gh2[1])
46+
end
47+
end
48+
end
49+
end
50+
51+
test_chainrules()

0 commit comments

Comments
 (0)