@@ -49,3 +49,91 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
4949
5050 return (I, V), findnz_pullback
5151end
52+
53+ if VERSION < v " 1.7"
54+ #=
55+ The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
56+ determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
57+ for the rrules for the determinants of sparse matrices below to work, they need to be
58+ able to compute the primals as well, so this import from the future is included. For
59+ more recent versions of Julia, this definition lives in:
60+ julia/stdlib/SuiteSparse/src/umfpack.jl
61+ =#
62+ using SuiteSparse. UMFPACK: UmfpackLU
63+
64+ # compute the sign/parity of a permutation
65+ function _signperm (p)
66+ n = length (p)
67+ result = 0
68+ todo = trues (n)
69+ while any (todo)
70+ k = findfirst (todo)
71+ todo[k] = false
72+ result += 1 # increment element count
73+ j = p[k]
74+ while j != k
75+ result += 1 # increment element count
76+ todo[j] = false
77+ j = p[j]
78+ end
79+ result += 1 # increment cycle count
80+ end
81+ return ifelse (isodd (result), - 1 , 1 )
82+ end
83+
84+ function LinearAlgebra. logabsdet (F:: UmfpackLU{T, TI} ) where {T<: Union{Float64,ComplexF64} ,TI<: Union{Int32, Int64} }
85+ n = checksquare (F)
86+ issuccess (F) || return log (zero (real (T))), zero (T)
87+ U = F. U
88+ Rs = F. Rs
89+ p = F. p
90+ q = F. q
91+ s = _signperm (p)* _signperm (q)* one (real (T))
92+ P = one (T)
93+ abs_det = zero (real (T))
94+ @inbounds for i in 1 : n
95+ dg_ii = U[i, i] / Rs[i]
96+ P *= sign (dg_ii)
97+ abs_det += log (abs (dg_ii))
98+ end
99+ return abs_det, s * P
100+ end
101+ end
102+
103+
104+ function rrule (:: typeof (logabsdet), x:: SparseMatrixCSC )
105+ F = cholesky (x)
106+ L, D, U, P = SparseInverseSubset. get_ldup (F)
107+ Ω = logabsdet (D)
108+ function logabsdet_pullback (ΔΩ)
109+ (Δy, Δsigny) = ΔΩ
110+ (_, signy) = Ω
111+ f = signy' * Δsigny
112+ imagf = f - real (f)
113+ g = real (Δy) + imagf
114+ Z, P = sparseinv (F, depermute= true )
115+ ∂x = g * Z'
116+ return (NoTangent (), ∂x)
117+ end
118+ return Ω, logabsdet_pullback
119+ end
120+
121+ function rrule (:: typeof (logdet), x:: SparseMatrixCSC )
122+ Ω = logdet (x)
123+ function logdet_pullback (ΔΩ)
124+ Z, p = sparseinv (x, depermute= true )
125+ ∂x = ΔΩ * Z'
126+ return (NoTangent (), ∂x)
127+ end
128+ return Ω, logdet_pullback
129+ end
130+
131+ function rrule (:: typeof (det), x:: SparseMatrixCSC )
132+ Ω = det (x)
133+ function det_pullback (ΔΩ)
134+ Z, _ = sparseinv (x, depermute= true )
135+ ∂x = Z' * dot (Ω, ΔΩ)
136+ return (NoTangent (), ∂x)
137+ end
138+ return Ω, det_pullback
139+ end
0 commit comments