66
77const TAG = :ReverseAD
88
9+ """
10+ const MAX_CHUNK::Int = 10
11+
12+ An upper bound on the chunk sie for forward-over-reverse. Increasing this could
13+ improve performance at the cost of extra memory allocation. It has been 10 for a
14+ long time, and nobody seems to have complained.
15+ """
16+ const MAX_CHUNK = 10
17+
918"""
1019 _eval_hessian(
1120 d::NLPEvaluator,
@@ -23,46 +32,30 @@ Returns the number of non-zeros in the computed Hessian, which will be used to
2332update the offset for the next call.
2433"""
2534function _eval_hessian (
26- d:: NLPEvaluator ,
27- f:: _FunctionStorage ,
28- H:: AbstractVector{Float64} ,
29- λ:: Float64 ,
30- offset:: Int ,
31- ):: Int
32- chunk = min (size (f. seed_matrix, 2 ), d. max_chunk)
33- # As a performance optimization, skip dynamic dispatch if the chunk is 1.
34- if chunk == 1
35- return _eval_hessian_inner (d, f, H, λ, offset, Val (1 ))
36- else
37- return _eval_hessian_inner (d, f, H, λ, offset, Val (chunk))
38- end
39- end
40-
41- function _eval_hessian_inner (
4235 d:: NLPEvaluator ,
4336 ex:: _FunctionStorage ,
4437 H:: AbstractVector{Float64} ,
4538 scale:: Float64 ,
4639 nzcount:: Int ,
47- :: Val{CHUNK} ,
48- ) where {CHUNK}
40+ ):: Int
4941 if ex. linearity == LINEAR
5042 @assert length (ex. hess_I) == 0
5143 return 0
5244 end
45+ chunk = min (size (ex. seed_matrix, 2 ), d. max_chunk)
5346 Coloring. prepare_seed_matrix! (ex. seed_matrix, ex. rinfo)
5447 # Compute hessian-vector products
5548 num_products = size (ex. seed_matrix, 2 ) # number of hessian-vector products
56- num_chunks = div (num_products, CHUNK )
49+ num_chunks = div (num_products, chunk )
5750 @assert size (ex. seed_matrix, 1 ) == length (ex. rinfo. local_indices)
58- for offset in 1 : CHUNK : (CHUNK * num_chunks)
59- _eval_hessian_chunk (d, ex, offset, CHUNK, Val (CHUNK) )
51+ for offset in 1 : chunk : (chunk * num_chunks)
52+ _eval_hessian_chunk (d, ex, offset, chunk, chunk )
6053 end
6154 # leftover chunk
62- remaining = num_products - CHUNK * num_chunks
55+ remaining = num_products - chunk * num_chunks
6356 if remaining > 0
64- offset = CHUNK * num_chunks + 1
65- _eval_hessian_chunk (d, ex, offset, remaining, Val (CHUNK) )
57+ offset = chunk * num_chunks + 1
58+ _eval_hessian_chunk (d, ex, offset, remaining, chunk )
6659 end
6760 want, got = nzcount + length (ex. hess_I), length (H)
6861 if want > got
@@ -90,32 +83,45 @@ function _eval_hessian_chunk(
9083 ex:: _FunctionStorage ,
9184 offset:: Int ,
9285 chunk:: Int ,
93- :: Val{CHUNK} ,
94- ) where {CHUNK}
86+ chunk_size :: Int ,
87+ )
9588 for r in eachindex (ex. rinfo. local_indices)
9689 # set up directional derivatives
9790 @inbounds idx = ex. rinfo. local_indices[r]
9891 # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
9992 for s in 1 : chunk
100- # If `chunk < CHUNK `, leaves junk in the unused components
101- d. input_ϵ[(idx- 1 )* CHUNK + s] = ex. seed_matrix[r, offset+ s- 1 ]
93+ # If `chunk < chunk_size `, leaves junk in the unused components
94+ d. input_ϵ[(idx- 1 )* chunk_size + s] = ex. seed_matrix[r, offset+ s- 1 ]
10295 end
10396 end
104- _hessian_slice_inner (d, ex, Val (CHUNK) )
97+ _hessian_slice_inner (d, ex, chunk_size )
10598 fill! (d. input_ϵ, 0.0 )
10699 # collect directional derivatives
107100 for r in eachindex (ex. rinfo. local_indices)
108101 @inbounds idx = ex. rinfo. local_indices[r]
109102 # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
110103 for s in 1 : chunk
111- ex. seed_matrix[r, offset+ s- 1 ] = d. output_ϵ[(idx- 1 )* CHUNK + s]
104+ ex. seed_matrix[r, offset+ s- 1 ] = d. output_ϵ[(idx- 1 )* chunk_size + s]
112105 end
113106 end
114107 return
115108end
116109
117- function _hessian_slice_inner (d, ex, :: Val{CHUNK} ) where {CHUNK}
118- T = ForwardDiff. Partials{CHUNK,Float64} # This is our element type.
110+ # A wrapper function to avoid dynamic dispatch.
111+ function _generate_hessian_slice_inner ()
112+ exprs = map (1 : MAX_CHUNK) do id
113+ T = ForwardDiff. Partials{id,Float64}
114+ return :(return _hessian_slice_inner (d, ex, $ T))
115+ end
116+ return MOI. Nonlinear. _create_binary_switch (1 : MAX_CHUNK, exprs)
117+ end
118+
119+ @eval function _hessian_slice_inner (d, ex, id:: Int )
120+ $ (_generate_hessian_slice_inner ())
121+ return error (" Invalid chunk size: $id " )
122+ end
123+
124+ function _hessian_slice_inner (d, ex, :: Type{T} ) where {T}
119125 fill! (d. output_ϵ, 0.0 )
120126 output_ϵ = _reinterpret_unsafe (T, d. output_ϵ)
121127 subexpr_forward_values_ϵ =
0 commit comments