From 0ca22861d8cce1e216971f259c63b96117750947 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Tue, 3 Aug 2021 09:47:01 -0700 Subject: [PATCH 1/4] Draft for moment matching --- src/ImportanceSampling.jl | 46 +++++++++------- src/InternalHelpers.jl | 8 +++ src/LeaveOneOut.jl | 8 +-- src/LooStructs.jl | 6 +-- src/MomentMatch.jl | 108 ++++++++++++++++++++++++++++++++++++++ src/PublicHelpers.jl | 6 +-- 6 files changed, 149 insertions(+), 33 deletions(-) create mode 100644 src/MomentMatch.jl diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index a34e581..a9bbb9e 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -51,13 +51,14 @@ function psis( dims = size(log_ratios) data_size = dims[1] - post_sample_size = dims[2] * dims[3] + mcmc_count = dims[2] * dims[3] + # Reshape to matrix (easier to deal with) - log_ratios = reshape(log_ratios, data_size, post_sample_size) + log_ratios = reshape(log_ratios, data_size, mcmc_count) weights = similar(log_ratios) - # Shift ratios by maximum to prevent overflow - @tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2)) + # Shift ratios by maximum to avoid overflow, and log(mcmc_count) to avoid subnormals + @tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2) + log(mcmc_count)) r_eff = _generate_r_eff(weights, dims, r_eff, source) _check_input_validity_psis(reshape(log_ratios, dims), r_eff) @@ -65,21 +66,20 @@ function psis( tail_length = similar(log_ratios, Int, data_size) ξ = similar(log_ratios, data_size) @inbounds Threads.@threads for i in eachindex(tail_length) - tail_length[i] = @views _def_tail_length(post_sample_size, r_eff[i]) - ξ[i] = @views ParetoSmooth._do_psis_i!(weights[i,:], tail_length[i]) + tail_length[i] = @views _def_tail_length(mcmc_count, r_eff[i]) + ξ[i] = @views ParetoSmooth._psis_smooth!(weights[i,:], tail_length[i]) end - - @tullio norm_const[i] := weights[i, j] - @turbo weights .= weights ./ norm_const - ess = psis_ess(weights, r_eff) + _normalize!(weights) + + ess = psis_ess(weights, r_eff) weights = reshape(weights, dims) if log_weights @tturbo @. weights = log(weights) end - return Psis(weights, ξ, ess, r_eff, tail_length, post_sample_size, data_size) + return Psis(weights, ξ, ess, r_eff, tail_length, mcmc_count, data_size) end @@ -95,24 +95,25 @@ end """ - _do_psis_i!(is_ratios::AbstractVector{Real}, tail_length::Integer) -> T + _psis_smooth!(is_ratios::AbstractVector{AbstractFloat}, tail_length::Integer) -> T -Do PSIS on a single vector, smoothing its tail values. +Do PSIS on a single vector, smoothing its tail values in place before returning ξ. # Arguments -- `is_ratios::AbstractVector{Real}`: A vector of importance sampling ratios, -scaled to have a maximum of 1. + - `is_ratios::AbstractVector{AbstractFloat}`: A vector of importance sampling ratios, + scaled to have a maximum of 1. # Returns -- `T<:Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails. + - `T<:AbstractFloat`: ξ, the estimated shape parameter for the GPD. Bigger numbers + indicate thicker tails. # Extended help Additional information can be found in the LOO package from R. """ -function _do_psis_i!( +function _psis_smooth!( is_ratios::AbstractVector{T}, tail_length::Integer ) where {T<:Real} @@ -135,8 +136,8 @@ function _do_psis_i!( cutoff = is_ratios[tail_start - 1] ξ = _psis_smooth_tail!(tail, cutoff) - # truncate at max of raw weights (1 after scaling) - clamp!(is_ratios, 0, 1) + # truncate at max of raw weights (equal to len after scaling) + clamp!(is_ratios, 0, len) # unsort the ratios to their original position: is_ratios .= @views is_ratios[invperm(ordering)] @@ -175,6 +176,11 @@ function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} end +function _normalize!(weights::AbstractArray) + @tullio norm_const[i] := weights[i, j] + @turbo @. weights /= norm_const +end + ########################## #### HELPER FUNCTIONS #### @@ -251,7 +257,7 @@ function _check_tail(tail::AbstractVector{T}) where {T<:Real} throw( ArgumentError( "Unable to fit generalized Pareto distribution: tail length was too " * - "short. Likely causese are: \n$LIKELY_ERROR_CAUSES" + "short. Likely causes are: \n$LIKELY_ERROR_CAUSES" ), ) end diff --git a/src/InternalHelpers.jl b/src/InternalHelpers.jl index f25335c..f683169 100644 --- a/src/InternalHelpers.jl +++ b/src/InternalHelpers.jl @@ -39,6 +39,14 @@ function _assume_one_chain(matrix) end +""" +Safely exponentiate a vector for a scale-invariant operation (exponentiate x - maximum(x)) +""" +function _safe_exp(x::AbstractVector) + return exp.(x .- maximum(x)) +end + + """ Convert a matrix+chain_index representation to a 3d array representation. """ diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl index 8fb2b62..34464c1 100644 --- a/src/LeaveOneOut.jl +++ b/src/LeaveOneOut.jl @@ -52,7 +52,7 @@ score. See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref). """ function psis_loo( - log_likelihood::T, args...; + log_likelihood::AbstractArray, args...; kwargs... ) where {F<:Real, T<:AbstractArray{F, 3}} @@ -114,12 +114,6 @@ function psis_loo( new_log_ratios = _convert_to_array(log_likelihood, chain_index) return psis_loo(new_log_ratios, args...; kwargs...) end - -# function psis_loo(log_likelihood, args...; -# subsamples::Integer, rng::AbstractRNG=MersenneTwister(1776), kwargs... -# ) -# return log_likelihood = rand() -# end function _generate_loo_table( diff --git a/src/LooStructs.jl b/src/LooStructs.jl index dce2430..c53d1ab 100644 --- a/src/LooStructs.jl +++ b/src/LooStructs.jl @@ -24,9 +24,9 @@ const CV_DESC = """ - `:ess` is the effective sample size, which measures the simulation error caused by using Monte Carlo estimates. It is *not* related to the actual sample size, and it does not measure how accurate your predictions are. - - `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto - distribution. Values above .7 indicate that PSIS has failed to approximate the true - distribution. + - `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto + distribution. Values above .7 indicate that PSIS has failed to approximate the true + distribution. - `psis_object::Psis`: A `Psis` object containing the results of Pareto-smoothed importance sampling. diff --git a/src/MomentMatch.jl b/src/MomentMatch.jl new file mode 100644 index 0000000..61334d0 --- /dev/null +++ b/src/MomentMatch.jl @@ -0,0 +1,108 @@ +using AxisKeys +using LoopVectorization +using Tullio + + +function adapt_weights!( + log_target::Function, + psis_object::Psis, + samples::AbstractArray, + data +) + + dims = size(samples) + n_steps, n_params, n_chains = dims + mcmc_count = n_steps * n_chains + weights = psis_object.weights + resample_count = size(weights, 1) + + Threads.@threads for resample in 1:resample_count + @views log_proposal = weights[resample, :, :] + end + +end + + +function _moment_match_i!( + log_target::Function, + log_proposal::AbstractArray, + θ_hats::AbstractVector, # parameter vector + ξ::Real, + hard_thresh::Real = 2/3, + soft_thresh::Real = 1/2, + soft_cap::Integer = 10, +) + dims = size(θ_hats) + mcmc_count = size(θ_hats, :parameter) + + # initialize variables + num_iter = 0 # iterations of IWMM + transform = 1 + θ_proposed = similar(θ_hats) + ξ_proposed = soft_thresh + μ = mean(θ_hats; dims=:parameter) + μ_proposed = similar(μ) + σ = std(θ_hats; dims=:parameter) + σ_proposed = similar(σ) + + + while _keep_going(ξ, hard_thresh, soft_thresh, soft_cap, num_iter) + + μ_proposed = _calc_loc(weights, θ_hats, mcmc_count) + if transform == 1 + σ_proposed .= σ + elseif transform == 2 + σ_proposed = std(θ_hats) + elseif transform == 3 + σ_proposed = _calc_scatter(weights, θ_hats, mcmc_count) + elseif transform == 4 + break + end + + θ_proposed = (θ_hats + μ_proposed - μ) * (σ_proposed * inv(σ)) + log_like_proposed = log_target(θ_proposed) - log_proposal + @. weights_proposed = _safe_exp(log_like_proposed) + ξ_proposed = _psis_smooth!(weights_proposed) + + if ξ_proposed < ξ + num_iter += 1 + _normalize!(weights_proposed) + + ξ = ξ_proposed + μ = μ_proposed + σ = σ_proposed + θ_hats = θ_proposed + log_likelihood = log_like_proposed + else + transform += 1 + end + + end + +end + + +function _keep_going( + ξ::Real, + hard_thresh::Real, + soft_thresh::Real, + soft_cap::Int, + num_iter::Integer +) + if ξ > hard_thresh + return true + elseif (ξ > soft_thresh) && (num_iter ≤ soft_cap) + return true + else + return false + end +end + + + +""" +Safely exponentiate -- subtract maximum to prevent overflow +""" +function _safe_exp(x) + return exp(x - $maximum(x; dims=2)) +end \ No newline at end of file diff --git a/src/PublicHelpers.jl b/src/PublicHelpers.jl index a88d411..34afd4e 100644 --- a/src/PublicHelpers.jl +++ b/src/PublicHelpers.jl @@ -1,8 +1,8 @@ export pointwise_log_likelihoods -const ARRAY_DIMS_WARNING = "The supplied array of mcmc samples indicates you have more -parameters than mcmc samples.This is possible, but highly unusual. Please check that your -array of mcmc samples has the following dimensions: [n_samples,n_parms,n_chains]." +const ARRAY_DIMS_WARNING = "The supplied array of MCMC samples indicates you have more " * +"parameters than samples. This is possible, but unusual. Please check your array of MCMC " * +"samples has dimensions `[iter, var, chain]`." """ pointwise_log_likelihoods( From 3696413e7bdea70f64fd4b3d11d442aaeb45afc7 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Tue, 3 Aug 2021 11:26:52 -0700 Subject: [PATCH 2/4] stuff --- Project.toml | 2 ++ src/LeaveOneOut.jl | 4 ++-- src/MomentMatch.jl | 58 ++++++++++++++++++++++++++++++++++------------ 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 9f30cbb..440a548 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,8 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" [compat] diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl index 34464c1..3e5a90b 100644 --- a/src/LeaveOneOut.jl +++ b/src/LeaveOneOut.jl @@ -35,8 +35,8 @@ end [, chain_index::Vector{Integer}, kwargs...] ) -> PsisLoo -Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation -score. +Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross-validation +estimate. # Arguments diff --git a/src/MomentMatch.jl b/src/MomentMatch.jl index 61334d0..f6386e2 100644 --- a/src/MomentMatch.jl +++ b/src/MomentMatch.jl @@ -1,32 +1,56 @@ using AxisKeys using LoopVectorization +using StatsBase +using Tables using Tullio - -function adapt_weights!( +""" + adapt_moments( + log_target::Function, + psis_object::Psis, + samples::AbstractArray, + data; + hard_thresh::Real = 2/3, + soft_thresh::Real = 1/2, + soft_cap::Integer = 10 + ) + +Perform importance-weighted moment matching, adapting a sample from a proposal distribution +to more closely match the target distribution. + +# Arguments + - `log_target`: The log-pdf of the target distribution, described as a function of the +""" +function adapt_moments( log_target::Function, psis_object::Psis, samples::AbstractArray, - data + data; + hard_thresh::Real = 2/3, + soft_thresh::Real = 1/2, + soft_cap::Integer = 10 ) + psis_object = Psis() dims = size(samples) n_steps, n_params, n_chains = dims mcmc_count = n_steps * n_chains weights = psis_object.weights + ξ = psis_object.pareto_k resample_count = size(weights, 1) - Threads.@threads for resample in 1:resample_count + Threads.@threads @inbounds for resample in 1:resample_count @views log_proposal = weights[resample, :, :] + _match!(log_target, log_proposal, samples, ξ; hard_thresh, soft_thresh, soft_cap) end end -function _moment_match_i!( +function _match!( log_target::Function, log_proposal::AbstractArray, - θ_hats::AbstractVector, # parameter vector + θ_hats::AbstractArray, ξ::Real, hard_thresh::Real = 2/3, soft_thresh::Real = 1/2, @@ -44,24 +68,28 @@ function _moment_match_i!( μ_proposed = similar(μ) σ = std(θ_hats; dims=:parameter) σ_proposed = similar(σ) + weights = + weights_proposed = similar(log_proposal) while _keep_going(ξ, hard_thresh, soft_thresh, soft_cap, num_iter) - μ_proposed = _calc_loc(weights, θ_hats, mcmc_count) + if transform == 1 + μ_proposed = mean(θ_hats, weights; dims=2) σ_proposed .= σ elseif transform == 2 - σ_proposed = std(θ_hats) + μ_proposed = mean(θ_hats, weights; dims=2) + σ_proposed = std(θ_hats, weights; mean=μ_proposed, dims=2) elseif transform == 3 - σ_proposed = _calc_scatter(weights, θ_hats, mcmc_count) + μ_proposed, Σ_proposed = mean_and_cov(θ_hats, weights; dims=2) + σ_proposed = sqrt(Σ_proposed) elseif transform == 4 break end - θ_proposed = (θ_hats + μ_proposed - μ) * (σ_proposed * inv(σ)) - log_like_proposed = log_target(θ_proposed) - log_proposal - @. weights_proposed = _safe_exp(log_like_proposed) + θ_proposed = (θ_hats + μ_proposed - μ) * (σ_proposed * inv(σ)) + @. weights_proposed = _safe_exp(log_target(θ_proposed) - log_proposal) ξ_proposed = _psis_smooth!(weights_proposed) if ξ_proposed < ξ @@ -72,7 +100,6 @@ function _moment_match_i!( μ = μ_proposed σ = σ_proposed θ_hats = θ_proposed - log_likelihood = log_like_proposed else transform += 1 end @@ -101,8 +128,9 @@ end """ -Safely exponentiate -- subtract maximum to prevent overflow +Safely exponentiate x, preventing underflow/overflow by rescaling all elements +by a common factor """ function _safe_exp(x) - return exp(x - $maximum(x; dims=2)) + return exp(x - $maximum(x; dims=2) + log(length(x))) end \ No newline at end of file From 8bda020a376c58affe34ea807046fc21309bec3b Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Tue, 3 Aug 2021 19:06:26 -0700 Subject: [PATCH 3/4] stuff --- src/MomentMatch.jl | 64 +++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/src/MomentMatch.jl b/src/MomentMatch.jl index f6386e2..41d4724 100644 --- a/src/MomentMatch.jl +++ b/src/MomentMatch.jl @@ -5,7 +5,7 @@ using Tables using Tullio """ - adapt_moments( + adapt_cv( log_target::Function, psis_object::Psis, samples::AbstractArray, @@ -19,11 +19,13 @@ Perform importance-weighted moment matching, adapting a sample from a proposal d to more closely match the target distribution. # Arguments - - `log_target`: The log-pdf of the target distribution, described as a function of the + - `log_target`: The log-pdf of the target distribution. This should be a function having + θ, a vector of parameters, as its input; and `x`, the data set, as its second input. + """ -function adapt_moments( +function adapt_cv( log_target::Function, - psis_object::Psis, + psis_object::AbstractCV, samples::AbstractArray, data; hard_thresh::Real = 2/3, @@ -31,17 +33,28 @@ function adapt_moments( soft_cap::Integer = 10 ) - psis_object = Psis() dims = size(samples) n_steps, n_params, n_chains = dims mcmc_count = n_steps * n_chains + log_count = mcmc_count + weights = psis_object.weights ξ = psis_object.pareto_k resample_count = size(weights, 1) Threads.@threads @inbounds for resample in 1:resample_count - @views log_proposal = weights[resample, :, :] - _match!(log_target, log_proposal, samples, ξ; hard_thresh, soft_thresh, soft_cap) + log_proposal = log_proposal + _match!( + log_target, + log_proposal, + view(weights, resample, :, :), + samples, + ξ[resample], + log_count, + hard_thresh, + soft_thresh, + soft_cap + ) end end @@ -50,14 +63,14 @@ end function _match!( log_target::Function, log_proposal::AbstractArray, + weights::AbstractArray, θ_hats::AbstractArray, ξ::Real, - hard_thresh::Real = 2/3, - soft_thresh::Real = 1/2, - soft_cap::Integer = 10, + log_count::Real, + hard_thresh::Real, + soft_thresh::Real, + soft_cap::Integer, ) - dims = size(θ_hats) - mcmc_count = size(θ_hats, :parameter) # initialize variables num_iter = 0 # iterations of IWMM @@ -68,13 +81,11 @@ function _match!( μ_proposed = similar(μ) σ = std(θ_hats; dims=:parameter) σ_proposed = similar(σ) - weights = - weights_proposed = similar(log_proposal) - + weights_proposed = similar(weights) + while _keep_going(ξ, hard_thresh, soft_thresh, soft_cap, num_iter) - - + if transform == 1 μ_proposed = mean(θ_hats, weights; dims=2) σ_proposed .= σ @@ -89,23 +100,28 @@ function _match!( end θ_proposed = (θ_hats + μ_proposed - μ) * (σ_proposed * inv(σ)) - @. weights_proposed = _safe_exp(log_target(θ_proposed) - log_proposal) - ξ_proposed = _psis_smooth!(weights_proposed) + @. weights_proposed = _safe_exp(log_target(θ_proposed) - log_proposal, log_count) + ξ_proposed = _psis_smooth!(weights) if ξ_proposed < ξ num_iter += 1 _normalize!(weights_proposed) ξ = ξ_proposed - μ = μ_proposed - σ = σ_proposed - θ_hats = θ_proposed + @. begin + μ = μ_proposed + σ = σ_proposed + θ_hats = θ_proposed + weights = weights_proposed + end else transform += 1 end end + return weights + end @@ -131,6 +147,6 @@ end Safely exponentiate x, preventing underflow/overflow by rescaling all elements by a common factor """ -function _safe_exp(x) - return exp(x - $maximum(x; dims=2) + log(length(x))) +function _safe_exp(x, log_count) + return @. exp(x - $maximum(x; dims=2) + log_count) end \ No newline at end of file From 033cda4d0eec322e0eb5903a17bd8daf83daa29b Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Tue, 3 Aug 2021 20:16:44 -0700 Subject: [PATCH 4/4] stuff --- src/MomentMatch.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/MomentMatch.jl b/src/MomentMatch.jl index 41d4724..6b983b7 100644 --- a/src/MomentMatch.jl +++ b/src/MomentMatch.jl @@ -25,7 +25,8 @@ to more closely match the target distribution. """ function adapt_cv( log_target::Function, - psis_object::AbstractCV, + log_p::AbstractArray, + cv_object::AbstractCV, samples::AbstractArray, data; hard_thresh::Real = 2/3, @@ -39,15 +40,15 @@ function adapt_cv( log_count = mcmc_count weights = psis_object.weights + original_weights = deepcopy(psis_object.weights) ξ = psis_object.pareto_k resample_count = size(weights, 1) Threads.@threads @inbounds for resample in 1:resample_count - log_proposal = log_proposal - _match!( + @views _match!( log_target, - log_proposal, - view(weights, resample, :, :), + log_p[resample, :, :], + weights[resample, :, :], samples, ξ[resample], log_count,