@@ -7,7 +7,8 @@ Compressed LBFGS implementation from:
77Implemented by Paul Raynaud (supervised by Dominique Orban)
88=#
99
10- using LinearAlgebra
10+ using LinearAlgebra, LinearAlgebra. BLAS
11+ using CUDA
1112
1213export CompressedLBFGS
1314
@@ -26,14 +27,16 @@ mutable struct CompressedLBFGS{T, M<:AbstractMatrix{T}, V<:AbstractVector{T}}
2627 intermediate_2:: LowerTriangular{T,M} # 2m * 2m
2728 inverse_intermediate_1:: UpperTriangular{T,M} # 2m * 2m
2829 inverse_intermediate_2:: LowerTriangular{T,M} # 2m * 2m
30+ intermediary_vector:: V # 2m
2931 sol:: V # m
30- inverse :: Bool
32+ intermediate_structure_updated :: Bool
3133end
3234
35+ default_gpu () = CUDA. functional () ? true : false
3336default_matrix_type (gpu:: Bool , T:: DataType ) = gpu ? CuMatrix{T} : Matrix{T}
3437default_vector_type (gpu:: Bool , T:: DataType ) = gpu ? CuVector{T} : Vector{T}
3538
36- function CompressedLBFGS (m:: Int , n:: Int ; T= Float64, gpu= false , M= default_matrix_type (gpu,T), V= default_vector_type (gpu,T))
39+ function CompressedLBFGS (m:: Int , n:: Int ; T= Float64, gpu= default_gpu () , M= default_matrix_type (gpu,T), V= default_vector_type (gpu,T))
3740 α = (T)(1 )
3841 k = 0
3942 Sₖ = M (undef,n,m)
@@ -46,9 +49,10 @@ function CompressedLBFGS(m::Int, n::Int; T=Float64, gpu=false, M=default_matrix_
4649 intermediate_2 = LowerTriangular (M (undef,2 * m,2 * m))
4750 inverse_intermediate_1 = UpperTriangular (M (undef,2 * m,2 * m))
4851 inverse_intermediate_2 = LowerTriangular (M (undef,2 * m,2 * m))
52+ intermediary_vector = V (undef,2 * m)
4953 sol = V (undef,2 * m)
50- inverse = false
51- return CompressedLBFGS {T,M,V} (m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, sol, inverse )
54+ intermediate_structure_updated = false
55+ return CompressedLBFGS {T,M,V} (m, n, k, α, Sₖ, Yₖ, Dₖ, Lₖ, chol_matrix, intermediate_1, intermediate_2, inverse_intermediate_1, inverse_intermediate_2, intermediary_vector, sol, intermediate_structure_updated )
5256end
5357
5458function Base. push! (op:: CompressedLBFGS{T,M,V} , s:: V , y:: V ) where {T,M,V<: AbstractVector{T} }
@@ -69,13 +73,14 @@ function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:Abstra
6973 circshift (op. Yₖ, (0 ,- 1 ))
7074 circshift (op. Dₖ, (- 1 ,- 1 ))
7175 # circshift doesn't work for a LowerTriangular matrix
76+ # for the time being, reinstantiate completely the Lₖ matrix
7277 for j in 2 : op. k
7378 for i in 1 : j- 1
7479 op. Lₖ. data[j, i] = dot (op. Sₖ[:,j],op. Yₖ[:,i])
7580 end
7681 end
7782 end
78- op. inverse = false
83+ op. intermediate_structure_updated = false
7984 return op
8085end
8186
@@ -98,46 +103,56 @@ function Base.Matrix(op::CompressedLBFGS{T,M,V}) where {T,M,V}
98103 return Bₖ
99104end
100105
106+ # step 4, Jₖ is computed only if needed
101107function inverse_cholesky (op:: CompressedLBFGS )
102- if ! op. inverse
103- op. chol_matrix[1 : op. k,1 : op. k] .= op. α .* (transpose (op. Sₖ[:,1 : op. k]) * op. Sₖ[:,1 : op. k]) .+ op. Lₖ[1 : op. k,1 : op. k] * inv (op. Dₖ[1 : op. k,1 : op. k]) * transpose (op. Lₖ[1 : op. k,1 : op. k])
104- cholesky! (view (op. chol_matrix,1 : op. k,1 : op. k))
105- op. inverse = true
106- end
107- Jₖ = transpose (UpperTriangular (op. chol_matrix[1 : op. k,1 : op. k]))
108+ view (op. chol_matrix, 1 : op. k, 1 : op. k) .= op. α .* (transpose (view (op. Sₖ, :, 1 : op. k)) * view (op. Sₖ, :, 1 : op. k)) .+ view (op. Lₖ, 1 : op. k, 1 : op. k) * inv (op. Dₖ[1 : op. k, 1 : op. k]) * transpose (view (op. Lₖ, 1 : op. k, 1 : op. k))
109+ cholesky! (view (op. chol_matrix,1 : op. k,1 : op. k))
110+ Jₖ = transpose (UpperTriangular (view (op. chol_matrix, 1 : op. k, 1 : op. k)))
108111 return Jₖ
109112end
110113
114+ # step 6, must be improve
115+ function precompile_iterated_structure! (op:: CompressedLBFGS )
116+ Jₖ = inverse_cholesky (op)
117+
118+ view (op. intermediate_1, 1 : op. k,1 : op. k) .= .- view (op. Dₖ, 1 : op. k, 1 : op. k)^ (1 / 2 )
119+ view (op. intermediate_1, 1 : op. k,op. k+ 1 : 2 * op. k) .= view (op. Dₖ, 1 : op. k, 1 : op. k)^ (- 1 / 2 ) * transpose (view (op. Lₖ, 1 : op. k, 1 : op. k))
120+ view (op. intermediate_1, op. k+ 1 : 2 * op. k, 1 : op. k) .= 0
121+ view (op. intermediate_1, op. k+ 1 : 2 * op. k, op. k+ 1 : 2 * op. k) .= transpose (Jₖ)
122+
123+ view (op. intermediate_2, 1 : op. k, 1 : op. k) .= view (op. Dₖ, 1 : op. k, 1 : op. k)^ (1 / 2 )
124+ view (op. intermediate_2, 1 : op. k, op. k+ 1 : 2 * op. k) .= 0
125+ view (op. intermediate_2, op. k+ 1 : 2 * op. k, 1 : op. k) .= .- view (op. Lₖ, 1 : op. k, 1 : op. k) * view (op. Dₖ, 1 : op. k, 1 : op. k)^ (- 1 / 2 )
126+ view (op. intermediate_2, op. k+ 1 : 2 * op. k, op. k+ 1 : 2 * op. k) .= Jₖ
127+
128+ view (op. inverse_intermediate_1, 1 : 2 * op. k, 1 : 2 * op. k) .= inv (op. intermediate_1[ 1 : 2 * op. k,1 : 2 * op. k])
129+ view (op. inverse_intermediate_2, 1 : 2 * op. k, 1 : 2 * op. k) .= inv (op. intermediate_2[ 1 : 2 * op. k,1 : 2 * op. k])
130+
131+ op. intermediate_structure_updated = true
132+ end
133+
111134# Algorithm 3.2 (p15)
112135function LinearAlgebra. mul! (Bv:: V , op:: CompressedLBFGS{T,M,V} , v:: V ) where {T,M,V<: AbstractVector{T} }
113136 # step 1-3 mainly done by Base.push!
114- # step 4, Jₖ is computed only if needed
115- Jₖ = inverse_cholesky (op:: CompressedLBFGS )
137+
138+ # steps 4 and 6, in case the intermediary required structure are not up to date
139+ (! op. intermediate_structure_updated) && (precompile_iterated_structure! (op))
116140
117141 # step 5, try views for mul!
118- # mul!(op.sol[1:op.k], transpose(op.Yₖ[:,1:op.k]), v) # wrong result
119- # mul!(op.sol[op.k+1:2*op.k], transpose(op.Yₖ[:,1:op.k]), v, (T)(1), op.α) # wrong result
120- op. sol[1 : op. k] .= transpose (op. Yₖ[:,1 : op. k]) * v
121- op. sol[op. k+ 1 : 2 * op. k] .= op. α .* transpose (op. Sₖ[:,1 : op. k]) * v
122-
123- # step 6, must be improve
124- op. intermediate_1[1 : op. k,1 : op. k] .= .- op. Dₖ[1 : op. k,1 : op. k]^ (1 / 2 )
125- op. intermediate_1[1 : op. k,op. k+ 1 : 2 * op. k] .= op. Dₖ[1 : op. k,1 : op. k]^ (- 1 / 2 ) * transpose (op. Lₖ[1 : op. k,1 : op. k])
126- op. intermediate_1[op. k+ 1 : 2 * op. k,1 : op. k] .= 0
127- op. intermediate_1[op. k+ 1 : 2 * op. k,op. k+ 1 : 2 * op. k] .= transpose (Jₖ)
128-
129- op. intermediate_2[1 : op. k,1 : op. k] .= op. Dₖ[1 : op. k,1 : op. k]^ (1 / 2 )
130- op. intermediate_2[1 : op. k,op. k+ 1 : 2 * op. k] .= 0
131- op. intermediate_2[op. k+ 1 : 2 * op. k,1 : op. k] .= .- op. Lₖ[1 : op. k,1 : op. k] * op. Dₖ[1 : op. k,1 : op. k]^ (- 1 / 2 )
132- op. intermediate_2[op. k+ 1 : 2 * op. k,op. k+ 1 : 2 * op. k] .= Jₖ
133-
134- op. inverse_intermediate_1[1 : 2 * op. k,1 : 2 * op. k] .= inv (op. intermediate_1[1 : 2 * op. k,1 : 2 * op. k])
135- op. inverse_intermediate_2[1 : 2 * op. k,1 : 2 * op. k] .= inv (op. intermediate_2[1 : 2 * op. k,1 : 2 * op. k])
136-
137- op. sol[1 : 2 * op. k] .= op. inverse_intermediate_1[1 : 2 * op. k,1 : 2 * op. k] * (op. inverse_intermediate_2[1 : 2 * op. k,1 : 2 * op. k] * op. sol[1 : 2 * op. k])
142+ mul! (view (op. sol, 1 : op. k), transpose (view (op. Yₖ, :, 1 : op. k)), v)
143+ mul! (view (op. sol, op. k+ 1 : 2 * op. k), transpose (view (op. Sₖ, :,1 : op. k)), v)
144+ # scal!(op.α, view(op.sol, op.k+1:2*op.k)) # more allocation, slower
145+ view (op. sol, op. k+ 1 : 2 * op. k) .*= op. α
146+
147+ # view(op.sol, 1:2*op.k) .= view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) * (view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) * view(op.sol, 1:2*op.k))
148+ mul! (view (op. intermediary_vector, 1 : 2 * op. k), view (op. inverse_intermediate_2, 1 : 2 * op. k, 1 : 2 * op. k), view (op. sol, 1 : 2 * op. k))
149+ mul! (view (op. sol, 1 : 2 * op. k), view (op. inverse_intermediate_1, 1 : 2 * op. k, 1 : 2 * op. k), view (op. intermediary_vector, 1 : 2 * op. k))
138150
139151 # step 7
140- Bv .= op. α .* v .- (op. Yₖ[:,1 : op. k] * op. sol[1 : op. k] .+ op. α .* op. Sₖ[:,1 : op. k] * op. sol[op. k+ 1 : 2 * op. k])
141-
152+ # Bv .= op.α .* v .- (view(op.Yₖ, :,1:op.k) * view(op.sol, 1:op.k) .+ op.α .* view(op.Sₖ, :, 1:op.k) * view(op.sol, op.k+1:2*op.k))
153+
154+ mul! (Bv, view (op. Yₖ, :, 1 : op. k), view (op. sol, 1 : op. k))
155+ mul! (Bv, view (op. Sₖ, :, 1 : op. k), view (op. sol, op. k+ 1 : 2 * op. k), - op. α, (T)(- 1 ))
156+ Bv .+ = op. α .* v
142157 return Bv
143158end
0 commit comments