diff --git a/Project.toml b/Project.toml index 9185cb1c..514a1384 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,12 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "7.4.0" +version = "8.0.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -18,7 +19,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -30,6 +31,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4, 5" AxisArrays = "0.4.4" +Compat = "4.2.0" DataAPI = "1.16.0" Dates = "<0.0.1, 1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" @@ -40,7 +42,7 @@ MCMCDiagnosticTools = "0.3" MLJModelInterface = "0.3.5, 0.4, 1.0" NaturalSort = "1" OrderedCollections = "1.4" -PrettyTables = "0.9, 0.10, 0.11, 0.12, 1, 2" +PosteriorStats = "0.4" Random = "<0.0.1, 1" RecipesBase = "0.7, 0.8, 1.0" Statistics = "<0.0.1, 1" @@ -48,4 +50,4 @@ StatsBase = "0.33.2, 0.34" StatsFuns = "0.8, 0.9, 1" TableTraits = "0.4, 1" Tables = "1" -julia = "1.6.3" +julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 2b684d1d..37a3c464 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,10 +3,12 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" [compat] @@ -14,9 +16,11 @@ CairoMakie = "0.6 - 0.13.0, 0.13.2 - 0.15" CategoricalArrays = "0.8, 0.9, 0.10, 1" DataFrames = "0.22, 1" Documenter = "0.26, 0.27, 1" +DocumenterInterLinks = "1" Gadfly = "1.3.4" -MCMCChains = "7" +MCMCChains = "8" MLJBase = "0.19, 0.20, 0.21, 1" MLJDecisionTreeInterface = "0.3, 0.4" +StableRNGs = "1" StatsPlots = "0.14, 0.15" julia = "1.7" diff --git a/docs/make.jl b/docs/make.jl index 4ad18dcd..73ecde22 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,5 @@ using Documenter +using DocumenterInterLinks using MCMCChains DocMeta.setdocmeta!( @@ -8,6 +9,8 @@ DocMeta.setdocmeta!( recursive=true ) +links = InterLinks("PosteriorStats" => "https://julia.arviz.org/PosteriorStats/stable/") + makedocs( sitename = "MCMCChains.jl", pages = [ @@ -28,5 +31,6 @@ makedocs( ], format = Documenter.HTML(prettyurls = get(ENV, "CI", nothing) == "true"), modules = [MCMCChains], + plugins = [links], checkdocs = :exports, ) diff --git a/docs/src/stats.md b/docs/src/stats.md index 7b178ddb..4ed20b16 100644 --- a/docs/src/stats.md +++ b/docs/src/stats.md @@ -8,5 +8,6 @@ describe mean summarystats quantile -hpd +eti +hdi ``` diff --git a/docs/src/statsplots.md b/docs/src/statsplots.md index 038b7414..2c1ec565 100644 --- a/docs/src/statsplots.md +++ b/docs/src/statsplots.md @@ -174,13 +174,13 @@ ridgelineplot(chn, [:C, :B, :A]) ## Forest ```@example statsplots -forestplot(chn, [:C, :B, :A], hpd_val = [0.05, 0.15, 0.25]) +forestplot(chn, [:C, :B, :A], ci_probs = [0.05, 0.15, 0.25]) ``` ## Caterpillar ```@example statsplots -forestplot(chn, chn.name_map[:parameters], hpd_val = [0.05, 0.15, 0.25], ordered = true) +forestplot(chn, chn.name_map[:parameters], ci_fun=hdi, ci_probs = [0.05, 0.15, 0.25], ordered = true) ``` ## Posterior Predictive Checks (PPC) diff --git a/docs/src/summarize.md b/docs/src/summarize.md index e1865bd6..4e30ae2d 100644 --- a/docs/src/summarize.md +++ b/docs/src/summarize.md @@ -2,7 +2,6 @@ The methods listed below are defined in `src/summarize.jl`. -```@autodocs -Modules = [MCMCChains] -Pages = ["summarize.jl"] +```@docs +summarize ``` diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index be0dd502..d0637134 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -1,5 +1,6 @@ module MCMCChains +using Compat: stack using AxisArrays const axes = Base.axes import AbstractMCMC @@ -26,7 +27,7 @@ import MCMCDiagnosticTools import MLJModelInterface import NaturalSort import OrderedCollections -import PrettyTables +import PosteriorStats import StatsFuns import Tables import TableTraits @@ -41,8 +42,6 @@ export setrange, resetrange export set_section, get_params, sections, sort_sections, setinfo export replacenames, namesingroup, group export autocor, describe, sample, summarystats, AbstractWeights, mean, quantile -export ChainDataFrame -export summarize # Reexport diagnostics functions using MCMCDiagnosticTools: @@ -69,7 +68,9 @@ export mcse export rafterydiag export rstar -export hpd +# Reexport stats functions +using PosteriorStats: SummaryStats, eti, hdi, summarize +export SummaryStats, eti, hdi, summarize """ Chains diff --git a/src/chains.jl b/src/chains.jl index 39eeb718..1f9ae77b 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -345,7 +345,7 @@ function Base.show(io::IO, chains::Chains) end function Base.show(io::IO, mime::MIME"text/plain", chains::Chains) - print(io, "Chains ", chains, ":\n\n", header(chains)) + println(io, "Chains ", chains, ":\n\n", header(chains)) println(io, "\nUse `describe(chains)` for summary statistics and quantiles.") end diff --git a/src/discretediag.jl b/src/discretediag.jl index 67cec670..4babc245 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -17,16 +17,22 @@ function MCMCDiagnosticTools.discretediag( _permutedims_diagnostics(_chains.value.data); kwargs... ) - # Create dataframes - parameters = (parameters = names(_chains),) - between_chain_df = ChainDataFrame( - "Chisq diagnostic - Between chains", merge(parameters, between_chain_vals), + # Create SummaryStats + param_names = names(_chains) + between_chain_stats = SummaryStats( + between_chain_vals; + name = "Chisq diagnostic - Between chains", + labels = param_names, ) - within_chain_dfs = map(1:size(_chains, 3)) do i + within_chain_stats = map(1:size(_chains, 3)) do i vals = map(val -> val[:, i], within_chain_vals) - return ChainDataFrame("Chisq diagnostic - Chain $i", merge(parameters, vals)) + return SummaryStats( + vals; + name = "Chisq diagnostic - Chain $i", + labels = param_names, + ) end - dfs = vcat(between_chain_df, within_chain_dfs) + stats = vcat([between_chain_stats], within_chain_stats) - return dfs + return stats end diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 70d86d2b..e3a27632 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -24,9 +24,9 @@ function MCMCDiagnosticTools.ess( # Convert to NamedTuple ess_per_sec = ess ./ dur - nt = merge((parameters = names(_chains),), (; ess, ess_per_sec)) + nt = (; ess, ess_per_sec) - return ChainDataFrame("ESS", nt) + return SummaryStats(nt; name = "ESS", labels = names(_chains)) end """ @@ -48,9 +48,9 @@ function MCMCDiagnosticTools.rhat( ) # Convert to NamedTuple - nt = merge((parameters = names(_chains),), (; rhat)) + nt = (; rhat) - return ChainDataFrame("R-hat", nt) + return SummaryStats(nt; name = "R-hat", labels = names(_chains)) end """ @@ -79,7 +79,7 @@ function MCMCDiagnosticTools.ess_rhat( # Convert to NamedTuple ess_per_sec = ess_rhat.ess ./ dur - nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec)) + nt = merge(ess_rhat, (; ess_per_sec)) - return ChainDataFrame("ESS/R-hat", nt) + return SummaryStats(nt; name = "ESS/R-hat", labels = names(_chains)) end diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 3fc894f5..999ef519 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -12,12 +12,13 @@ function MCMCDiagnosticTools.gelmandiag( results = MCMCDiagnosticTools.gelmandiag(_permutedims_diagnostics(psi); kwargs...) # Create a data frame with the results. - df = ChainDataFrame( - "Gelman, Rubin, and Brooks diagnostic", - merge((parameters = names(_chains),), results), + stats = SummaryStats( + results; + name = "Gelman, Rubin, and Brooks diagnostic", + labels = names(_chains), ) - return df + return stats end function MCMCDiagnosticTools.gelmandiag_multivariate( @@ -36,11 +37,12 @@ function MCMCDiagnosticTools.gelmandiag_multivariate( kwargs..., ) - # Create a data frame with the results. - df = ChainDataFrame( - "Gelman, Rubin, and Brooks diagnostic", - (parameters = names(_chains), psrf = results.psrf, psrfci = results.psrfci), + # Create SummaryStats with the results. + stats = SummaryStats( + (psrf = results.psrf, psrfci = results.psrfci); + name = "Gelman, Rubin, and Brooks diagnostic", + labels = names(_chains), ) - return df, results.psrfmultivariate + return stats, results.psrfmultivariate end diff --git a/src/gewekediag.jl b/src/gewekediag.jl index 72cbb5f2..d93bd313 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -18,12 +18,14 @@ function MCMCDiagnosticTools.gewekediag( return namedtuple_of_vecs end - # Create data frames. - parameters = (parameters = names(_chains),) - dfs = [ - ChainDataFrame("Geweke diagnostic - Chain $i", merge(parameters, result)) - for (i, result) in enumerate(results) + # Create SummaryStats. + stats = [ + SummaryStats( + result; + name = "Geweke diagnostic - Chain $i", + labels = names(_chains), + ) for (i, result) in enumerate(results) ] - return dfs + return stats end diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 67def6a8..57a614f8 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -16,14 +16,14 @@ function MCMCDiagnosticTools.heideldiag( return namedtuple_of_vecs end - # Create data frames. - parameters = (parameters = names(_chains),) - dfs = [ - ChainDataFrame( - "Heidelberger and Welch diagnostic - Chain $i", merge(parameters, result) - ) - for (i, result) in enumerate(results) + # Create SummaryStats. + stats = [ + SummaryStats( + result; + name = "Heidelberger and Welch diagnostic - Chain $i", + labels = names(_chains), + ) for (i, result) in enumerate(results) ] - return dfs + return stats end diff --git a/src/mcse.jl b/src/mcse.jl index 78ae2552..a33b0af4 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -1,5 +1,5 @@ """ - mcse(chains::Chains; duration=compute_duration, kwargs...) + mcse(chains::Chains; kwargs...) Estimate the Monte Carlo standard error. """ @@ -16,7 +16,7 @@ function MCMCDiagnosticTools.mcse( kwargs..., ) - nt = merge((parameters = names(_chains),), (; mcse)) + nt = (; mcse) - return ChainDataFrame("MCSE", nt) + return SummaryStats(nt; name = "MCSE", labels = names(_chains)) end diff --git a/src/plot.jl b/src/plot.jl index 55c514ce..5f765961 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -108,7 +108,11 @@ The following options are available: - `q` (default: `[0.1, 0.9]`): The two quantiles used for plotting if `fill_q = true` or `show_qi = true`. -- `hpd_val` (default: `[0.05, 0.2]`): The complementary probability mass(es) of the highest posterior density intervals that are plotted if `fill_hpd = true` or `show_hpdi = true`. +- `ci_fun` (default: `eti`): The function used to compute the credible intervals. + (Can be [`eti`](@ref) or [`hdi`](@ref)) + +- `ci_probs` (default: `[$DEFAULT_CI_PROB, 0.8]`): The probability mass(es) of the credible + intervals that are plotted if `fill_ci = true` or `show_cii = true`. !!! note If a single parameter is provided, the generated plot is a density plot with all the elements described above. @@ -144,7 +148,11 @@ By default, all parameters are plotted. - `q` (default: `[0.1, 0.9]`): The two quantiles used for plotting if `fill_q = true` or `show_qi = true`. -- `hpd_val` (default: `[0.05, 0.2]`): The complementary probability mass(es) of the highest posterior density intervals that are plotted if `fill_hpd = true` or `show_hpdi = true`. +- `ci_fun` (default: `eti`): The function used to compute the credible intervals. + (Can be [`eti`](@ref) or [`hdi`](@ref)) + +- `ci_probs` (default: `[$DEFAULT_CI_PROB, 0.8]`): The probability mass(es) of the credible + intervals that are plotted if `fill_ci = true` or `show_cii = true`. """ @userplot ForestPlot @@ -235,7 +243,7 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor lags = 0:(maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag) # Chains are already appended in `c` if desired, hence we use `append_chains=false` ac = autocor(c; sections = nothing, lags = lags, append_chains = false) - ac_mat = convert(Array{Float64}, ac) + ac_mat = stack(map(stack ∘ Base.Fix2(Iterators.drop, 1), ac)) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) elseif st ∈ (:violinplot, :violin) @@ -777,7 +785,8 @@ function _compute_plot_data( i::Integer, chains::Chains, par_names::AbstractVector{Symbol}; - hpd_val = [0.05, 0.2], + ci_fun = eti, + ci_probs = [DEFAULT_CI_PROB, 0.8], q = [0.1, 0.9], spacer = 0.4, _riser = 0.2, @@ -785,29 +794,22 @@ function _compute_plot_data( show_mean = true, show_median = true, show_qi = false, - show_hpdi = true, + show_cii = true, fill_q = true, - fill_hpd = false, - ordered = false, + fill_ci = false, ) + probs_sorted = sort(ci_probs; rev = true) - chain_dic = Dict(zip(quantile(chains)[:, 1], quantile(chains)[:, 4])) - sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic)))) - sorted_par = [sorted_chain[i][2] for i = 1:length(par_names)] - par = (ordered ? sorted_par : par_names) - hpdi = sort(hpd_val) - - chain_sections = MCMCChains.group(chains, Symbol(par[i])) + chain_sections = MCMCChains.group(chains, Symbol(par_names[i])) chain_vec = vec(chain_sections.value.data) - lower_hpd = - [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.lower for j = 1:length(hpdi)] - upper_hpd = - [MCMCChains.hpd(chain_sections, alpha = hpdi[j]).nt.upper for j = 1:length(hpdi)] + ci_intervals = map(probs_sorted) do prob + only(Tables.getcolumn(ci_fun(chain_sections; prob), 2)) + end h = _riser + spacer * (i - 1) qs = quantile(chain_vec, q) k_density = kde(chain_vec) - if fill_hpd - x_int = filter(x -> lower_hpd[1][1] <= x <= upper_hpd[1][1], k_density.x) + if fill_ci + x_int = filter(in(ci_intervals[1]), k_density.x) val = pdf(k_density, x_int) .+ h elseif fill_q x_int = filter(x -> qs[1] <= x <= qs[2], k_density.x) @@ -821,10 +823,9 @@ function _compute_plot_data( min = minimum(k_density.density .+ h) q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med]) - return par, - hpdi, - lower_hpd, - upper_hpd, + return par_names, + probs_sorted, + ci_intervals, h, qs, k_density, @@ -836,29 +837,38 @@ function _compute_plot_data( q_int end +_intervalname(::typeof(PosteriorStats.eti)) = "ETI" +_intervalname(::typeof(PosteriorStats.hdi)) = "HDI" +_intervalname(f) = string(nameof(f)) + @recipe function f( p::RidgelinePlot; - hpd_val = [0.05, 0.2], + ci_probs = [DEFAULT_CI_PROB, 0.8], + ci_fun = eti, q = [0.1, 0.9], spacer = 0.5, _riser = 0.2, show_mean = true, show_median = true, show_qi = false, - show_hpdi = true, + show_cii = true, fill_q = true, - fill_hpd = false, + fill_ci = false, ordered = false, ) chn = p.args[1] par_names = p.args[2] + if ordered + par_table_names, par_medians = summarize(chn[:, par_names, :], median) + par_names = par_table_names[sortperm(par_medians)] + end + for i = 1:length(par_names) par, - hpdi, - lower_hpd, - upper_hpd, + cii, + ci_intervals, h, qs, k_density, @@ -871,17 +881,17 @@ end i, chn, par_names; - hpd_val = hpd_val, + ci_fun = ci_fun, + ci_probs = ci_probs, q = q, spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median, show_qi = show_qi, - show_hpdi = show_hpdi, + show_cii = show_cii, fill_q = fill_q, - fill_hpd = fill_hpd, - ordered = ordered, + fill_ci = fill_ci, ) yticks --> ( @@ -936,45 +946,53 @@ end label := nothing linecolor := "#000000" linewidth --> (show_qi ? 1.2 : 0) + seriesalpha --> (show_qi ? 1.0 : 0.0) [qs[1], qs[2]], [h, h] end @series begin seriestype := :path label := ( - show_hpdi ? (i == 1 ? "$(Integer((1-hpdi[1])*100))% HPDI" : nothing) : - nothing + show_cii ? + (i == 1 ? "$(round(Int, cii[i]*100))% $(_intervalname(ci_fun))" : nothing) : nothing ) - linewidth --> (show_hpdi ? 2 : 0) - seriesalpha --> 0.80 + linewidth --> (show_cii ? 2 : 0) + markersize --> 0 + seriesalpha --> (show_cii ? 0.80 : 0.0) linecolor --> :darkblue - [lower_hpd[1][1], upper_hpd[1][1]], [h, h] + offset := h + ci_intervals[1] end end end @recipe function f( p::ForestPlot; - hpd_val = [0.05, 0.2], + ci_probs = [DEFAULT_CI_PROB, 0.8], + ci_fun = eti, q = [0.1, 0.9], spacer = 0.5, _riser = 0.2, show_mean = true, show_median = true, show_qi = false, - show_hpdi = true, + show_cii = true, fill_q = true, - fill_hpd = false, + fill_ci = false, ordered = false, ) chn = p.args[1] par_names = p.args[2] + if ordered + par_table_names, par_medians = summarize(chn[:, par_names, :], median) + par_names = par_table_names[sortperm(par_medians)] + end + for i = 1:length(par_names) par, - hpdi, - lower_hpd, - upper_hpd, + cii, + ci_intervals, h, qs, k_density, @@ -987,17 +1005,17 @@ end i, chn, par_names; - hpd_val = hpd_val, + ci_fun = ci_fun, + ci_probs = ci_probs, q = q, spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median, show_qi = show_qi, - show_hpdi = show_hpdi, + show_cii = show_cii, fill_q = fill_q, - fill_hpd = fill_hpd, - ordered = ordered, + fill_ci = fill_ci, ) yticks --> ( @@ -1006,17 +1024,22 @@ end ) yaxis --> (length(par_names) > 1 ? "Parameters" : "Density") - for j = 1:length(hpdi) + for j = 1:length(cii) @series begin seriestype := :path label := ( - show_hpdi ? - (i == 1 ? "$(Integer((1-hpdi[j])*100))% HPDI" : nothing) : nothing + show_cii ? + ( + i == 1 ? "$(round(Int, cii[j]*100))% $(_intervalname(ci_fun))" : + nothing + ) : nothing ) linecolor --> j - linewidth --> (show_hpdi ? 1.5 * j : 0) - seriesalpha --> 0.80 - [lower_hpd[j][1], upper_hpd[j][1]], [h, h] + linewidth --> (show_cii ? 1.5 * j : 0) + markersize --> 0 + seriesalpha --> (show_cii ? 0.80 : 0.0) + offset := h + ci_intervals[j] end end @series begin @@ -1024,7 +1047,7 @@ end label := (show_median ? (i == 1 ? "Median" : nothing) : nothing) markershape --> :diamond markercolor --> "#000000" - markersize --> (show_median ? length(hpdi) : 0) + markersize --> (show_median ? length(cii) : 0) [chain_med], [h] end @series begin @@ -1032,7 +1055,7 @@ end label := (show_mean ? (i == 1 ? "Mean" : nothing) : nothing) markershape --> :circle markercolor --> :gray - markersize --> (show_mean ? length(hpdi) : 0) + markersize --> (show_mean ? length(cii) : 0) [chain_mean], [h] end @series begin @@ -1048,6 +1071,7 @@ end label := nothing linecolor := "#000000" linewidth --> (show_qi ? 1.2 : 0.0) + seriesalpha --> (show_qi ? 1.0 : 0.0) [qs[1], qs[2]], [h, h] end end diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl index 95126025..eb23716f 100644 --- a/src/rafterydiag.jl +++ b/src/rafterydiag.jl @@ -16,14 +16,14 @@ function MCMCDiagnosticTools.rafterydiag( return namedtuple_of_vecs end - # Create data frames. - parameters = (parameters = names(_chains),) - dfs = [ - ChainDataFrame( - "Raftery and Lewis diagnostic - Chain $i", merge(parameters, result) - ) - for (i, result) in enumerate(results) + # Create SummaryStats. + stats = [ + SummaryStats( + result; + name = "Raftery and Lewis diagnostic - Chain $i", + labels = names(_chains), + ) for (i, result) in enumerate(results) ] - return dfs + return stats end diff --git a/src/stats.jl b/src/stats.jl index 0bbc85fd..d15bd83e 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -1,5 +1,7 @@ #################### Posterior Statistics #################### +const DEFAULT_CI_PROB = 0.89f0 + """ autocor( chains; @@ -17,25 +19,41 @@ Setting `append_chains=false` will return a vector of dataframes containing the """ function autocor( chains::Chains; - append_chains = true, + sections = _default_sections(chains), + append_chains::Bool = true, demean::Bool = true, lags::AbstractVector{<:Integer} = _default_lags(chains, append_chains), + var_names = nothing, kwargs..., ) - funs = Function[] - func_names = @. Symbol("lag ", lags) - for i in lags - push!(funs, x -> autocor(x, [i], demean = demean)[1]) - end + chn = Chains(chains, _clean_sections(chains, sections)) - return summarize( - chains, - funs...; - func_names = func_names, - append_chains = append_chains, - name = "Autocorrelation", - kwargs..., - ) + # Obtain names of parameters. + names_of_params = var_names === nothing ? names(chn) : var_names + + # Construct column names for lags. + col_names = Symbol.("lag", lags) + + # avoids using summarize directly to support simultaneously computing a large number of + # lags without constructing a huge NamedTuple + if append_chains + data = _permutedims_diagnostics(chn.value.data) + vals = stack(map(eachslice(data; dims = 3)) do x + return autocor(vec(x), lags; demean = demean) + end) + table = Tables.table(vals'; header = col_names) + return SummaryStats(table; name = "Autocorrelation", labels = names_of_params) + else + data = to_vector_of_matrices(chn) + return map(enumerate(data)) do (i, x) + name_chain = "Autocorrelation (Chain $i)" + vals = stack(map(eachslice(x; dims = 2)) do xi + return autocor(xi, lags; demean = demean) + end) + table = Tables.table(vals'; header = col_names) + return SummaryStats(table; name = name_chain, labels = names_of_params) + end + end end """ @@ -54,7 +72,7 @@ function _default_lags(chains::Chains, append_chains::Bool) end """ - cor(chains[; sections, append_chains = true, kwargs...]) + cor(chains[; sections, append_chains = true]) Compute the Pearson correlation matrix for the chain. @@ -65,7 +83,6 @@ function cor( chains::Chains; sections = _default_sections(chains), append_chains = true, - kwargs..., ) # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) @@ -74,33 +91,32 @@ function cor( names_of_params = names(_chains) if append_chains - df = chaindataframe_cor("Correlation", names_of_params, to_matrix(_chains)) + df = summarystats_cor("Correlation", names_of_params, to_matrix(_chains)) return df else vector_of_df = [ - chaindataframe_cor("Correlation - Chain $i", names_of_params, data) for + summarystats_cor("Correlation - Chain $i", names_of_params, data) for (i, data) in enumerate(to_vector_of_matrices(_chains)) ] return vector_of_df end end -function chaindataframe_cor(name, names_of_params, chains::AbstractMatrix; kwargs...) +function summarystats_cor(name, names_of_params, chains::AbstractMatrix) # Compute the correlation matrix. cormat = cor(chains) - # Summarize the results in a named tuple. - nt = (; - parameters = names_of_params, - zip(names_of_params, (cormat[:, i] for i in axes(cormat, 2)))..., + # Summarize the results in a dict + dict = OrderedCollections.OrderedDict( + k => v for (k, v) in zip(names_of_params, eachcol(cormat)) ) - # Create a ChainDataFrame. - return ChainDataFrame(name, nt; kwargs...) + # Create a SummaryStats. + return SummaryStats(dict; name = name, labels = names_of_params) end """ - changerate(chains[; sections, append_chains = true, kwargs...]) + changerate(chains[; sections, append_chains = true]) Compute the change rate for the chain. @@ -111,7 +127,6 @@ function changerate( chains::Chains{<:Real}; sections = _default_sections(chains), append_chains = true, - kwargs..., ) # Subset the chain. _chains = Chains(chains, _clean_sections(chains, sections)) @@ -120,26 +135,26 @@ function changerate( names_of_params = names(_chains) if append_chains - df = chaindataframe_changerate("Change Rate", names_of_params, _chains.value.data) - return df + stats = summarystats_changerate("Change Rate", names_of_params, _chains.value.data) + return stats else - vector_of_df = [ - chaindataframe_changerate("Change Rate - Chain $i", names_of_params, data) - for (i, data) in enumerate(to_vector_of_matrices(_chains)) + vector_of_stats = [ + summarystats_changerate("Change Rate - Chain $i", names_of_params, data) for + (i, data) in enumerate(to_vector_of_matrices(_chains)) ] - return vector_of_df + return vector_of_stats end end -function chaindataframe_changerate(name, names_of_params, chains; kwargs...) +function summarystats_changerate(name, names_of_params, chains) # Compute the change rates. changerates, mvchangerate = changerate(chains) # Summarize the results in a named tuple. - nt = (; zip(names_of_params, changerates)..., multivariate = mvchangerate) + nt = (; label = names_of_params, changerate = changerates) - # Create a ChainDataFrame. - return ChainDataFrame(name, nt; kwargs...) + # Create a SummaryStats. + return SummaryStats(nt; name = name), mvchangerate end changerate(chains::AbstractMatrix{<:Real}) = changerate(reshape(chains, Val(3))) @@ -173,7 +188,6 @@ end """ describe(io, chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], - etype = :bm, kwargs...]) Print chain metadata, summary statistics, and quantiles. Use `describe(chains)` for REPL output to `stdout`, or specify `io` for other streams (e.g., file output). """ @@ -181,90 +195,119 @@ function DataAPI.describe( io::IO, chains::Chains; q = [0.025, 0.25, 0.5, 0.75, 0.975], - etype = :bm, kwargs..., ) print(io, "Chains ", chains, ":\n\n", header(chains)) - summstats = summarystats(chains; etype = etype, kwargs...) + summstats = summarystats(chains; kwargs...) println(io) show(io, MIME("text/plain"), summstats) qs = quantile(chains; q = q, kwargs...) println(io) + println(io) show(io, MIME("text/plain"), qs) end # Convenience method for default IO DataAPI.describe(chains::Chains; kwargs...) = DataAPI.describe(stdout, chains; kwargs...) -function _hpd(x::AbstractVector{<:Real}; alpha::Real = 0.05) - n = length(x) - m = max(1, ceil(Int, alpha * n)) +""" + eti(chn::Chains; prob::Real=$DEFAULT_CI_PROB, kwargs...) + +Return the equal-tailed interval (ETI) representing `prob` probability mass. + +The bounds of the ETI are the symmetric quantiles so that the interval contains `prob` +probability mass. + +Remaining keyword arguments are forwarded to [`summarize`](@ref). - y = sort(x) - a = y[1:m] - b = y[(n-m+1):n] - _, i = findmin(b - a) +See also [`quantile`](@ref), [`hdi`](@ref) - return [a[i], b[i]] +# Examples + +```jldoctest +julia> using StableRNGs; rng = StableRNG(42); + +julia> val = rand(rng, 500, 2, 3); + +julia> chn = Chains(val, [:a, :b]); + +julia> eti(chn) +ETI + eti89 + a 0.0620 .. 0.942 + b 0.0486 .. 0.939 +``` +""" +function PosteriorStats.eti(chn::Chains; prob::Real = DEFAULT_CI_PROB, kwargs...) + eti_name = Symbol("eti$(_prob_to_string(prob))") + return summarize(chn, eti_name => (x -> eti(x; prob)); name = "ETI", kwargs...) end """ - hpd(chn::Chains; alpha::Real=0.05, kwargs...) + hdi(chn::Chains; prob::Real=$DEFAULT_CI_PROB, method=:unimodal, kwargs...) + +Return the highest density interval (HDI) representing `prob` probability mass. -Return the highest posterior density interval representing `1-alpha` probability mass. +Note that for the default (`method=:unimodal`), this will return a single interval. +For multiple intervals for discontinuous regions, use `method=:multimodal`. +See [`PosteriorStats.hdi`](@extref) for more details. -Note that this will return a single interval and will not return multiple intervals for discontinuous regions. +Remaining keyword arguments are forwarded to [`summarize`](@ref). + +See also [`eti`](@ref) # Examples -```julia-repl -julia> val = rand(500, 2, 3); -julia> chn = Chains(val, [:a, :b]); +```jldoctest +julia> using StableRNGs; rng = StableRNG(42); -julia> hpd(chn) -HPD - parameters lower upper - Symbol Float64 Float64 +julia> val = rand(rng, 500, 2, 3); - a 0.0554 0.9944 - b 0.0114 0.9460 +julia> chn = Chains(val, [:a, :b]); + +julia> hdi(chn) +HDI + hdi89 + a 0.104 .. 0.977 + b 0.0827 .. 0.966 ``` """ -function hpd(chn::Chains; alpha::Real = 0.05, kwargs...) - labels = [:lower, :upper] - l(x) = _hpd(x, alpha = alpha)[1] - u(x) = _hpd(x, alpha = alpha)[2] - return summarize(chn, l, u; name = "HPD", func_names = labels, kwargs...) +function PosteriorStats.hdi(chn::Chains; prob::Real = DEFAULT_CI_PROB, kwargs...) + hdi_name = Symbol("hdi$(_prob_to_string(prob))") + return summarize(chn, hdi_name => (x -> hdi(x; prob)); name = "HDI", kwargs...) end +_prob_to_string(prob; digits = 2) = + replace(string(round(100 * prob; digits)), r"\.0+$" => "") + +@deprecate hpd(chn::Chains; alpha::Real = 0.05, kwargs...) hdi( + chn; + prob = 1 - alpha, + kwargs..., +) + """ - quantile(chains[; q = [0.025, 0.25, 0.5, 0.75, 0.975], append_chains = true, kwargs...]) + quantile(chains[; q = (0.025, 0.25, 0.5, 0.75, 0.975), append_chains = true, kwargs...]) Compute the quantiles for each parameter in the chain. Setting `append_chains=false` will return a vector of dataframes containing the quantiles for each chain. + +For intervals defined by symmetric quantiles, see [`eti`](@ref). """ function quantile( chains::Chains; - q::AbstractVector = [0.025, 0.25, 0.5, 0.75, 0.975], - append_chains = true, + q::Union{Tuple,AbstractVector} = (0.025, 0.25, 0.5, 0.75, 0.975), kwargs..., ) # compute quantiles - funs = Function[] - func_names = @. Symbol(100 * q, :%) - for i in q - push!(funs, x -> quantile(cskip(x), i)) - end - + func_names = Tuple(Symbol.(100 .* q, :%)) return summarize( chains, - funs...; - func_names = func_names, - append_chains = append_chains, + func_names => (Base.Fix2(quantile, q) ∘ cskip); name = "Quantiles", kwargs..., ) @@ -272,103 +315,15 @@ end """ - function summarystats( - chains; - sections = _default_sections(chains), - append_chains= true, - autocov_method::AbstractAutocovMethod = AutocovMethod(), - maxlag = 250, - kwargs... - ) + summarystats(chains; kwargs...) -Compute the mean, standard deviation, Monte Carlo standard error, bulk- and tail- effective -sample size, and ``\\widehat{R}`` diagnostic for each parameter in the chain. +Compute default summary statistics from the `chains`. -Setting `append_chains=false` will return a vector of dataframes containing the summary -statistics for each chain. - -When estimating the effective sample size, autocorrelations are computed for at most `maxlag` lags. +`kwargs` are forwarded to [`summarize`](@ref). To customize the summary statistics, see +`summarize`. """ -function summarystats( - chains::Chains; - sections = _default_sections(chains), - append_chains::Bool = true, - autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(), - maxlag = 250, - name = "Summary Statistics", - kwargs..., -) - # Store everything. - funs = [mean ∘ cskip, std ∘ cskip] - func_names = [:mean, :std] - - # Subset the chain. - _chains = Chains(chains, _clean_sections(chains, sections)) - - # Calculate MCSE and ESS/R-hat separately. - nt_additional = NamedTuple() - try - mcse_df = MCMCDiagnosticTools.mcse( - _chains; - sections = nothing, - autocov_method = autocov_method, - maxlag = maxlag, - ) - nt_additional = merge(nt_additional, (; mcse = mcse_df.nt.mcse)) - catch e - @warn "MCSE calculation failed: $e" - end - - try - ess_tail_df = MCMCDiagnosticTools.ess( - _chains; - sections = nothing, - autocov_method = autocov_method, - maxlag = maxlag, - kind = :tail, - ) - nt_additional = merge(nt_additional, (ess_tail = ess_tail_df.nt.ess,)) - catch e - @warn "Tail ESS calculation failed: $e" - end - - try - ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat( - _chains; - sections = nothing, - autocov_method = autocov_method, - maxlag = maxlag, - kind = :rank, - ) - nt_ess_rhat_rank = ( - ess_bulk = ess_rhat_rank_df.nt.ess, - rhat = ess_rhat_rank_df.nt.rhat, - ess_per_sec = ess_rhat_rank_df.nt.ess_per_sec, - ) - nt_additional = merge(nt_additional, nt_ess_rhat_rank) - catch e - @warn "Bulk ESS/R-hat calculation failed: $e" - end - - # Possibly re-order the columns to stay backwards-compatible. - additional_keys = (:mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec) - additional_df = ChainDataFrame( - "Additional", - (; ((k, nt_additional[k]) for k in additional_keys if k ∈ keys(nt_additional))...), - ) - - # Summarize. - summary_df = summarize( - _chains, - funs...; - func_names, - append_chains, - additional_df, - name, - sections = nothing, - ) - - return summary_df +function summarystats(chains::Chains; name = "Summary Statistics", kwargs...) + return summarize(chains; name, kwargs...) end """ @@ -377,15 +332,7 @@ end Calculate the mean of a chain. """ function mean(chains::Chains; kwargs...) - # Store everything. - funs = [mean ∘ cskip] - func_names = [:mean] - - # Summarize. - summary_df = - summarize(chains, funs...; func_names = func_names, name = "Mean", kwargs...) - - return summary_df + return summarize(chains, :mean => mean ∘ cskip; name = "Mean", kwargs...) end mean(chn::Chains, syms) = mean(chn[:, syms, :]) diff --git a/src/summarize.jl b/src/summarize.jl index 53af6364..59cca133 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -1,137 +1,29 @@ -struct ChainDataFrame{NT<:NamedTuple} - name::String - nt::NT - nrows::Int - ncols::Int - - function ChainDataFrame(name::String, nt::NamedTuple) - lengths = length(first(nt)) - all(x -> length(x) == lengths, nt) || error("Lengths must be equal.") - - return new{typeof(nt)}(name, nt, lengths, length(nt)) - end -end - -ChainDataFrame(nt::NamedTuple) = ChainDataFrame("", nt) - -Base.size(c::ChainDataFrame) = (c.nrows, c.ncols) -Base.names(c::ChainDataFrame) = collect(keys(c.nt)) - -# Display - -function Base.show(io::IO, df::ChainDataFrame) - print(io, df.name, " (", df.nrows, " x ", df.ncols, ")") -end - -function Base.show(io::IO, ::MIME"text/plain", df::ChainDataFrame) - digits = get(io, :digits, 4) - formatter = PrettyTables.ft_printf("%.$(digits)f") - - println(io, df.name) - # Support for PrettyTables 0.9 (`borderless`) and 0.10 (`tf_borderless`) - PrettyTables.pretty_table( - io, df.nt; - formatters = formatter, - tf = isdefined(PrettyTables, :borderless) ? PrettyTables.borderless : PrettyTables.tf_borderless, +""" + summarize( + chains[, stats_funs...]; + append_chains=true, + [sections, var_names], + kwargs..., ) -end -Base.isequal(c1::ChainDataFrame, c2::ChainDataFrame) = isequal(c1, c2) +Summarize `chains` in a [`PosteriorStats.SummaryStats`](@extref). -# Index functions -function Base.getindex(c::ChainDataFrame, s::Union{Colon, Integer, UnitRange}, g::Union{Colon, Integer, UnitRange}) - convert(Array, getindex(c, c.nt[:parameters][s], collect(keys(c.nt))[g])) -end +`stats_funs` is a collection of functions that reduces a matrix with shape `(draws, chains)` +to a scalar or a collection of scalars. Alternatively, an item in `stats_funs` may be a +`Pair` of the form `name => fun` specifying the name to be used for the statistic or of the +form `(name1, ...) => fun` when the function returns a collection. When the function returns +a collection, the names in this latter format must be provided. -Base.getindex(c::ChainDataFrame, s::Vector{Symbol}, ::Colon) = getindex(c, s) -function Base.getindex(c::ChainDataFrame, s::Union{Symbol, Vector{Symbol}}) - getindex(c, s, collect(keys(c.nt))) -end +# Keywords -function Base.getindex(c::ChainDataFrame, s::Union{Colon, Integer, UnitRange}, ks) - getindex(c, c.nt[:parameters][s], ks) -end - -# dispatches involing `String` and `AbstractVector{String}` -Base.getindex(c::ChainDataFrame, s::String, ks) = getindex(c, Symbol(s), ks) -function Base.getindex(c::ChainDataFrame, s::AbstractVector{String}, ks) - return getindex(c, Symbol.(s), ks) -end - -# dispatch for `Symbol` -Base.getindex(c::ChainDataFrame, s::Symbol, ks) = getindex(c, [s], ks) - -function Base.getindex(c::ChainDataFrame, s::AbstractVector{Symbol}, ks::Symbol) - return getindex(c, s, [ks]) -end - -function Base.getindex( - c::ChainDataFrame, - s::AbstractVector{Symbol}, - ks::AbstractVector{Symbol} -) - ind = indexin(s, c.nt[:parameters]) - - not_found = map(x -> x === nothing, ind) - - any(not_found) && error("Cannot find parameters $(s[not_found]) in chain") - - # If there are multiple columns, return a new CDF. - if length(ks) > 1 - if !(:parameters in ks) - ks = vcat(:parameters, ks) - end - nt = NamedTuple{tuple(ks...)}(tuple([c.nt[k][ind] for k in ks]...)) - return ChainDataFrame(c.name, nt) - else - # Otherwise, return a vector if there's multiple parameters - # or just a scalar if there's one parameter. - if length(s) == 1 - return c.nt[ks[1]][ind][1] - else - return c.nt[ks[1]][ind] - end - end -end - -function Base.lastindex(c::ChainDataFrame, i::Integer) - if i == 1 - return c.nrows - elseif i ==2 - return c.ncols - else - error("No such dimension") - end -end - -function Base.convert(::Type{Array}, c::ChainDataFrame) - T = promote_eltype_namedtuple_tail(c.nt) - return convert(Array{T}, c) -end -function Base.convert(::Type{Array{T}}, c::ChainDataFrame) where {T} - arr = Array{T, 2}(undef, c.nrows, c.ncols - 1) - - for (i, k) in enumerate(Iterators.drop(keys(c.nt), 1)) - arr[:, i] = c.nt[k] - end - - return arr -end - -function Base.convert(::Type{Array}, cs::Vector{ChainDataFrame{NamedTuple{K,V}}}) where {K,V} - T = promote_eltype_tuple_type(Base.tuple_type_tail(V)) - return convert(Array{T}, cs) -end -function Base.convert(::Type{Array{T}}, cs::Vector{<:ChainDataFrame}) where {T} - return mapreduce((x, y) -> cat(x, y; dims = Val(3)), cs) do c - reshape(convert(Array{T}, c), Val(3)) - end -end - -""" - summarize(chains, funs...[; sections, func_names = [], name = "", append_chains = true]) - -Summarize `chains` in a `ChainsDataFrame`. +- `section`: The sections of the chain to include in the summary. If not provided, defaults + to `:parameters`. +- `append_chains`: If `true`, a single `SummaryStats` for all chains is returned. If + `false`, a vector of `SummaryStats` (one for each chain) is returned. +- `var_names`: The names of the parameters in data. If not provided, the names are taken + from `chains`. +- `kwargs...`: Additional keyword arguments are forwarded to + [`PosteriorStats.summarize`](@extref). # Examples @@ -140,56 +32,32 @@ Summarize `chains` in a `ChainsDataFrame`. * `summarize(chns; sections=[:parameters])` : Chain summary of :parameters section * `summarize(chns; sections=[:parameters, :internals])` : Chain summary for multiple sections """ -function summarize( - chains::Chains, funs...; +function PosteriorStats.summarize( + chains::Chains, + funs...; sections = _default_sections(chains), - func_names::AbstractVector{Symbol} = Symbol[], append_chains::Bool = true, - name::String = "", - additional_df = nothing + var_names = nothing, + name::AbstractString = "SummaryStats", + kwargs..., ) - # If we weren't given any functions, fall back to summary stats. - if isempty(funs) - return summarystats(chains; sections, append_chains, name) - end - # Generate a chain to work on. chn = Chains(chains, _clean_sections(chains, sections)) # Obtain names of parameters. - names_of_params = names(chn) - - # If no function names were given, make a new list. - fnames = isempty(func_names) ? collect(nameof.(funs)) : func_names - - # Obtain the additional named tuple. - additional_nt = additional_df === nothing ? NamedTuple() : additional_df.nt + names_of_params = var_names === nothing ? names(chn) : var_names if append_chains # Evaluate the functions. - data = to_matrix(chn) - fvals = [[f(data[:, i]) for i in axes(data, 2)] for f in funs] - - # Build the ChainDataFrame. - nt = merge((; parameters = names_of_params, zip(fnames, fvals)...), additional_nt) - df = ChainDataFrame(name, nt) - - return df + data = _permutedims_diagnostics(chn.value.data) + summarize(data, funs...; var_names = names_of_params, name, kwargs...) else # Evaluate the functions. data = to_vector_of_matrices(chn) - vector_of_fvals = [[[f(x[:, i]) for i in axes(x, 2)] for f in funs] for x in data] - - # Build the ChainDataFrames. - vector_of_nt = [ - merge((; parameters = names_of_params, zip(fnames, fvals)...), additional_nt) - for fvals in vector_of_fvals - ] - vector_of_df = [ - ChainDataFrame(name * " (Chain $i)", nt) - for (i, nt) in enumerate(vector_of_nt) - ] - - return vector_of_df + return map(enumerate(data)) do (i, x) + z = reshape(x, size(x, 1), 1, size(x, 2)) + name_chain = name * " (Chain $i)" + summarize(z, funs...; var_names = names_of_params, name = name_chain, kwargs...) + end end end diff --git a/src/tables.jl b/src/tables.jl index d5ad2e67..7a70db04 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -69,42 +69,3 @@ function IteratorInterfaceExtensions.getiterator(chn::Chains) end TableTraits.isiterabletable(::Chains) = true - -#### -#### ChainDataFrame -#### - -#### Tables interface - -Tables.istable(::Type{<:ChainDataFrame}) = true - -# AbstractColumns interface - -Tables.columnaccess(::Type{<:ChainDataFrame}) = true - -Tables.columns(cdf::ChainDataFrame) = cdf - -Tables.columnnames(::ChainDataFrame{<:NamedTuple{names}}) where {names} = names - -Tables.getcolumn(cdf::ChainDataFrame, i::Int) = cdf.nt[i] -Tables.getcolumn(cdf::ChainDataFrame, nm::Symbol) = cdf.nt[nm] - -# row access - -Tables.rowaccess(::Type{<:ChainDataFrame}) = true - -Tables.rows(cdf::ChainDataFrame) = Tables.rows(Tables.columntable(cdf)) - -function Tables.schema(::ChainDataFrame{NamedTuple{names,T}}) where {names,T} - types = ntuple(i -> eltype(fieldtype(T, i)), fieldcount(T)) - return Tables.Schema(names, types) -end - -#### TableTraits interface - -IteratorInterfaceExtensions.isiterable(::ChainDataFrame) = true -function IteratorInterfaceExtensions.getiterator(cdf::ChainDataFrame) - return Tables.datavaluerows(Tables.columntable(cdf)) -end - -TableTraits.isiterabletable(::ChainDataFrame) = true diff --git a/test/Project.toml b/test/Project.toml index 4166a6e5..d0588054 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,8 +15,10 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" @@ -35,13 +37,15 @@ FFTW = "1.1" IteratorInterfaceExtensions = "1" KernelDensity = "0.6.2" Logging = "<0.0.1, 1" -MCMCChains = "7" +MCMCChains = "8" MCMCDiagnosticTools = "0.3.10" MLJBase = "1" MLJDecisionTreeInterface = "0.4" Plots = "1.40.2" +PosteriorStats = "0.4" Random = "<0.0.1, 1" Serialization = "<0.0.1, 1" +StableRNGs = "1" Statistics = "<0.0.1, 1" StatsBase = "0.34" StatsPlots = "0.15" diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl index a7a4d443..31f06817 100644 --- a/test/diagnostic_tests.jl +++ b/test/diagnostic_tests.jl @@ -1,6 +1,9 @@ using MCMCChains using AbstractMCMC: AbstractChains using Dates +using DataFrames +using PosteriorStats: SummaryStats +using Tables using Test ## CHAIN TESTS @@ -68,8 +71,8 @@ chn_disc = Chains(val_disc, start = 1, thin = 2) @test all(MCMCChains.indiscretesupport(chn) .== [false, false, false, true]) @test setinfo(chn, NamedTuple{(:A, :B)}((1,2))).info == NamedTuple{(:A, :B)}((1,2)) @test isa(set_section(chn, Dict(:internals => ["param_1"])), AbstractChains) - @test mean(chn) isa ChainDataFrame - @test mean(chn, ["param_1", "param_3"]) isa ChainDataFrame + @test mean(chn) isa SummaryStats + @test mean(chn, ["param_1", "param_3"]) isa SummaryStats @test 0.95 ≤ mean(chn, "param_1") ≤ 1.05 end @@ -167,7 +170,7 @@ end @testset "function tests" begin tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"])) - @test eltype(discretediag(chn_disc[:,2:2,:])) <: ChainDataFrame + @test eltype(discretediag(chn_disc[:, 2:2, :])) <: SummaryStats gelman = gelmandiag(tchain) gelmanmv = gelmandiag_multivariate(tchain) @@ -176,18 +179,23 @@ end raferty = rafterydiag(tchain) # test raw return values - @test typeof(gelman) <: ChainDataFrame - @test typeof(gelmanmv) <: Tuple{ChainDataFrame,Float64} - @test typeof(geweke) <: Array{<:ChainDataFrame} - @test typeof(heidel) <: Array{<:ChainDataFrame} - @test typeof(raferty) <: Array{<:ChainDataFrame} - - # test ChainDataFrame sizes - @test size(gelman) == (2,3) - @test size(gelmanmv[1]) == (2,3) - @test size(geweke[1]) == (2,3) - @test size(heidel[1]) == (2,7) - @test size(raferty[1]) == (2,6) + @test typeof(gelman) <: SummaryStats + @test typeof(gelmanmv) <: Tuple{SummaryStats,Float64} + @test typeof(geweke) <: Array{<:SummaryStats} + @test typeof(heidel) <: Array{<:SummaryStats} + @test typeof(raferty) <: Array{<:SummaryStats} + + # test SummaryStats sizes + for s in (gelman, gelmanmv[1], geweke[1], heidel[1], raferty[1]) + @test s isa SummaryStats + df = DataFrame(s) + @test size(df, 1) == 2 + end + @test size(DataFrame(gelman), 2) == 3 + @test size(DataFrame(gelmanmv[1]), 2) == 3 + @test size(DataFrame(geweke[1]), 2) == 3 + @test size(DataFrame(heidel[1]), 2) == 7 + @test size(DataFrame(raferty[1]), 2) == 6 end @testset "stats tests" begin @@ -203,31 +211,38 @@ end @test lags == filter!(x -> x < n, [1, 5, 10, 50]) acor = autocor(c; append_chains=append_chains) - # Number of columns in the ChainDataFrame(s): lags + parameters + # Number of columns in the SummaryStats: lags + parameters ncols = length(lags) + 1 if append_chains - @test acor isa ChainDataFrame - @test size(acor)[2] == ncols + @test acor isa SummaryStats + @test length(keys(acor)) == ncols else - @test acor isa Vector{<:ChainDataFrame} - @test all(size(a)[2] == ncols for a in acor) + @test acor isa Vector{<:SummaryStats} + @test all(length(keys(a)) == ncols for a in acor) end end - @test autocor(c) isa ChainDataFrame - @test convert(Array, autocor(c)) == convert(Array, autocor(c; append_chains=true)) + @test autocor(c) isa SummaryStats + @test autocor(c) == autocor(c; append_chains = true) end - @test MCMCChains.cor(chn) isa ChainDataFrame - @test MCMCChains.cor(chn; append_chains = false) isa Vector{<:ChainDataFrame} + @test MCMCChains.cor(chn) isa SummaryStats + @test MCMCChains.cor(chn; append_chains = false) isa Vector{<:SummaryStats} + + @test MCMCChains.changerate(chn) isa Tuple{SummaryStats,Float64} + @test MCMCChains.changerate(chn; append_chains = false) isa + Vector{<:Tuple{SummaryStats,Float64}} + + @test eti(chn) isa SummaryStats + @test eti(chn; append_chains = false) isa Vector{<:SummaryStats} - @test MCMCChains.changerate(chn) isa ChainDataFrame - @test MCMCChains.changerate(chn; append_chains = false) isa Vector{<:ChainDataFrame} + @test hdi(chn) isa SummaryStats + @test hdi(chn; append_chains = false) isa Vector{<:SummaryStats} - @test hpd(chn) isa ChainDataFrame - @test hpd(chn; append_chains = false) isa Vector{<:ChainDataFrame} + result = hdi(chn) + @test :hdi89 in Tables.columnnames(result) - result = hpd(chn) - @test all(result.nt.upper .> result.nt.lower) + @test_deprecated hpd(chn) + @test hpd(chn) == hdi(chn; prob = 0.95) end @testset "vector of vectors" begin diff --git a/test/ess_rhat_tests.jl b/test/ess_rhat_tests.jl index bfac0d09..27d94d5b 100644 --- a/test/ess_rhat_tests.jl +++ b/test/ess_rhat_tests.jl @@ -1,5 +1,7 @@ using MCMCChains +using DataFrames using FFTW +using PosteriorStats: SummaryStats using Random using Statistics @@ -18,8 +20,10 @@ using Test for f in (ess, ess_rhat) s = f(c) - @test length(s[:,:ess_per_sec]) == 5 - @test all(map(!ismissing, s[:,:ess_per_sec])) + @test s isa SummaryStats + df = DataFrame(s) + @test length(df[!, :ess_per_sec]) == 5 + @test all(map(!ismissing, df[!, :ess_per_sec])) end end @@ -29,16 +33,25 @@ end for autocov_method in (AutocovMethod(), FFTAutocovMethod(), BDAAutocovMethod()), kind in (:bulk, :basic), f in (ess, ess_rhat, rhat) # analyze chain - ess_df = ess(chain; autocov_method = autocov_method, kind = kind) - rhat_df = rhat(chain; kind = kind) - ess_rhat_df = ess_rhat(chain; autocov_method = autocov_method, kind = kind) + ess_stats = ess(chain; autocov_method = autocov_method, kind = kind) + rhat_stats = rhat(chain; kind = kind) + ess_rhat_stats = ess_rhat(chain; autocov_method = autocov_method, kind = kind) + + @test ess_stats isa SummaryStats + @test ess_stats.name == "ESS" + @test rhat_stats isa SummaryStats + @test rhat_stats.name == "R-hat" + @test ess_rhat_stats isa SummaryStats + @test ess_rhat_stats.name == "ESS/R-hat" + + ess_df, rhat_df, ess_rhat_df = DataFrame.((ess_stats, rhat_stats, ess_rhat_stats)) # analyze array ess_array, rhat_array = ess_rhat( permutedims(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind, ) - @test ess_df[:,2] == ess_rhat_df[:,2] == ess_array - @test rhat_df[:,2] == ess_rhat_df[:,3] == rhat_array + @test ess_df[!, :ess] == ess_rhat_df[!, :ess] == ess_array + @test rhat_df[!, :rhat] == ess_rhat_df[!, :rhat] == rhat_array end end @@ -48,16 +61,24 @@ end for autocov_method in (AutocovMethod(), FFTAutocovMethod(), BDAAutocovMethod()) # analyze chain - ess_df = ess(chain; autocov_method = autocov_method) - @test isequal(ess_df[:, :ess], fill(NaN, 5)) - @test isequal(ess_df[:, :ess_per_sec], fill(missing, 5)) - - ess_rhat_df = ess_rhat(chain; autocov_method = autocov_method) - @test isequal(ess_rhat_df[:, :ess], fill(NaN, 5)) - @test isequal(ess_rhat_df[:, :rhat], fill(NaN, 5)) - @test isequal(ess_rhat_df[:, :ess_per_sec], fill(missing, 5)) + ess_stats = ess(chain; autocov_method = autocov_method) + @test ess_stats isa SummaryStats + @test ess_stats.name == "ESS" + ess_df = DataFrame(ess_stats) + @test isequal(ess_df[!, :ess], fill(NaN, 5)) + @test isequal(ess_df[!, :ess_per_sec], fill(missing, 5)) + ess_rhat_stats = ess_rhat(chain; autocov_method = autocov_method) + @test ess_rhat_stats isa SummaryStats + @test ess_rhat_stats.name == "ESS/R-hat" + ess_rhat_df = DataFrame(ess_rhat_stats) + @test isequal(ess_rhat_df[!, :ess], fill(NaN, 5)) + @test isequal(ess_rhat_df[!, :rhat], fill(NaN, 5)) + @test isequal(ess_rhat_df[!, :ess_per_sec], fill(missing, 5)) end - rhat_df = rhat(chain) - @test isequal(rhat_df[:, :rhat], fill(NaN, 5)) + rhat_stats = rhat(chain) + @test rhat_stats isa SummaryStats + @test rhat_stats.name == "R-hat" + rhat_df = DataFrame(rhat_stats) + @test isequal(rhat_df[!, :rhat], fill(NaN, 5)) end diff --git a/test/mcse_tests.jl b/test/mcse_tests.jl index fdbefbb8..f4728cdb 100644 --- a/test/mcse_tests.jl +++ b/test/mcse_tests.jl @@ -1,4 +1,6 @@ using MCMCChains +using DataFrames +using PosteriorStats: SummaryStats using Random using Statistics @@ -14,21 +16,27 @@ mymean(x) = mean(x) if kind !== mymean for autocov_method in (AutocovMethod(), BDAAutocovMethod()) # analyze chain - mcse_df = mcse(chain; autocov_method = autocov_method, kind = kind) + mcse_stats = mcse(chain; autocov_method = autocov_method, kind = kind) + @test mcse_stats isa SummaryStats + @test mcse_stats.name == "MCSE" + mcse_df = DataFrame(mcse_stats) # analyze array mcse_array = mcse( PermutedDimsArray(x, (1, 3, 2)); autocov_method = autocov_method, kind = kind, ) - @test mcse_df[:,2] == mcse_array + @test mcse_df[!, :mcse] == mcse_array end else # analyze chain - mcse_df = mcse(chain; kind = kind) + mcse_stats = mcse(chain; kind = kind) + @test mcse_stats isa SummaryStats + @test mcse_stats.name == "MCSE" + mcse_df = DataFrame(mcse_stats) # analyze array mcse_array = mcse(PermutedDimsArray(x, (1, 3, 2)); kind = kind) - @test mcse_df[:,2] == mcse_array + @test mcse_df[!, :mcse] == mcse_array end end end diff --git a/test/missing_tests.jl b/test/missing_tests.jl index 46b6227c..8f738a0d 100644 --- a/test/missing_tests.jl +++ b/test/missing_tests.jl @@ -5,9 +5,7 @@ using Random # Tests for missing values. function testdiff(cdf1, cdf2) - m1 = convert(Array, cdf1) - m2 = convert(Array, cdf2) - return all(((x, y),) -> isapprox(x, y; atol=1e-2), zip(m1, m2)) + return all(((x, y),) -> isapprox(x, y; atol = 1e-2), Iterators.drop(zip(cdf1, cdf2), 1)) end @testset "utils" begin @@ -35,9 +33,9 @@ end rf_2 = rafterydiag(chn_m) @testset "diagnostics missing tests" for i in 1:nchains - @test testdiff(gw_1, gw_2) - @test testdiff(hd_1, hd_2) - @test testdiff(rf_1, rf_2) + @test all(Base.splat(testdiff), zip(gw_1, gw_2)) + @test all(Base.splat(testdiff), zip(hd_1, hd_2)) + @test all(Base.splat(testdiff), zip(rf_1, rf_2)) end @test_throws MethodError discretediag(chn_m) diff --git a/test/runtests.jl b/test/runtests.jl index f3a297ba..d1eb7cfc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,10 @@ Random.seed!(0) println("Model statistics") @time include("modelstats_test.jl") + # run tests for show methods + println("Show methods") + @time include("show_tests.jl") + # run tests for concatenation println("Concatenation") @time include("concatenation_tests.jl") diff --git a/test/show_tests.jl b/test/show_tests.jl new file mode 100644 index 00000000..8d72f280 --- /dev/null +++ b/test/show_tests.jl @@ -0,0 +1,29 @@ +using Test +using MCMCChains + +@testset "Show tests" begin + rng = MersenneTwister(1234) + val = rand(rng, 100, 4, 4) + parm_names = ["a", "b", "c", "d"] + chns = Chains(val, parm_names, Dict(:internals => ["b", "d"]))[1:2:99, :, :] + str = sprint(show, "text/plain", chns) + expected_str = """ + Chains MCMC chain (50×4×4 Array{Float64, 3}): + + Iterations = 1:2:99 + Number of chains = 4 + Samples per chain = 50 + parameters = a, c + internals = b, d + + + Use `describe(chains)` for summary statistics and quantiles. + """ + @test str == expected_str + + describe_str = sprint(describe, chns) + @test occursin("Summary Statistics", describe_str) + @test occursin("Quantiles", describe_str) + @test occursin("parameters", describe_str) + @test occursin("internals", describe_str) +end diff --git a/test/summarize_tests.jl b/test/summarize_tests.jl index d847ec43..60c76af8 100644 --- a/test/summarize_tests.jl +++ b/test/summarize_tests.jl @@ -1,54 +1,76 @@ +using DataFrames using MCMCChains, Test +using PosteriorStats: SummaryStats using Statistics: std -@testset "Summarize to DataFrame tests" begin +@testset "Summarize tests" begin val = rand(1000, 8, 4) - chns = Chains( - val, - ["a", "b", "c", "d", "e", "f", "g", "h"], - Dict(:internals => ["c", "d", "e", "f", "g", "h"]) - ) + parm_names = ["a", "b", "c", "d", "e", "f", "g", "h"] + chns = Chains(val, parm_names, Dict(:internals => ["c", "d", "e", "f", "g", "h"])) - parm_df = summarize(chns, sections=[:parameters]) + parm_stats = summarize(chns, sections = [:parameters]) + @test parm_stats isa SummaryStats + @test parm_stats.name == "SummaryStats" + parm_array_stats = + summarize(PermutedDimsArray(val[:, 1:2, :], (1, 3, 2)); var_names = [:a, :b]) - # check that display of ChainDataFrame does not error + # check that display of SummaryStats does not error println("compact display:") - show(stdout, parm_df) + show(stdout, parm_stats) println("\nverbose display:") - show(stdout, "text/plain", parm_df) + show(stdout, "text/plain", parm_stats) + + parm_df = DataFrame(parm_stats) + parm_array_df = DataFrame(parm_array_stats) - @test 0.48 < parm_df[:a, :mean][1] < 0.52 - @test names(parm_df) == [:parameters, :mean, :std, :mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec] + @test 0.48 < parm_df[1, :mean] < 0.52 + @test parm_df == parm_array_df # Indexing tests - @test isequal(convert(Array, parm_df[:a, :]), convert(Array, parm_df[:a])) - @test parm_df[:a, :][:,:parameters] == :a - @test parm_df[[:a, :b], :][:,:parameters] == [:a, :b] - - all_sections_df = summarize(chns, sections=[:parameters, :internals]) - @test all_sections_df isa ChainDataFrame - @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] - @test size(all_sections_df) == (8, 8) - @test all_sections_df.name == "" - - all_sections_dfs = summarize(chns, sections=[:parameters, :internals], name = "Summary", append_chains = false) - @test all_sections_dfs isa Vector{<:ChainDataFrame} - for (i, all_sections_df) in enumerate(all_sections_dfs) - @test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h] - @test size(all_sections_df) == (8, 8) - @test all_sections_df.name == "Summary (Chain $i)" + @test parm_df[!, 1] == [:a, :b] + + all_sections_stats = summarize(chns; sections = [:parameters, :internals]) + all_sections_array_stats = + summarize(PermutedDimsArray(val, (1, 3, 2)); var_names = Symbol.(parm_names)) + @test all_sections_stats isa SummaryStats + all_sections_df = DataFrame(all_sections_stats) + all_sections_array_df = DataFrame(all_sections_array_stats) + @test all_sections_df[!, 1] == Symbol.(parm_names) + @test all_sections_array_df == all_sections_df + @test all_sections_stats.name == "SummaryStats" + + all_sections_stats = summarize( + chns; + sections = [:parameters, :internals], + name = "Summary", + append_chains = false, + ) + @test all_sections_stats isa Vector{<:SummaryStats} + for (i, all_sections_stats_i) in enumerate(all_sections_stats) + all_sections_df_i = DataFrame(all_sections_stats_i) + @test all_sections_df_i[!, 1] == Symbol.(parm_names) + @test size(all_sections_df_i, 2) == size(all_sections_df, 2) + @test all_sections_stats_i.name == "Summary (Chain $i)" end - two_parms_two_funs_df = summarize(chns[[:a, :b]], mean, std) - @test two_parms_two_funs_df[:, :parameters] == [:a, :b] - @test size(two_parms_two_funs_df) == (2, 3) + two_parms_two_funs_df = DataFrame(summarize(chns[[:a, :b]], mean, std)) + @test two_parms_two_funs_df[!, 1] == [:a, :b] + @test propertynames(two_parms_two_funs_df) == [:label, :mean, :std] - three_parms_df = summarize(chns[[:a, :b, :c]], mean, std, sections=[:parameters, :internals]) - @test three_parms_df[:, :parameters] == [:a, :b, :c] - @test size(three_parms_df) == (3, 3) + three_parms_df = DataFrame( + summarize(chns[[:a, :b, :c]], mean, std; sections = [:parameters, :internals]), + ) + @test three_parms_df[!, 1] == [:a, :b, :c] + @test propertynames(three_parms_df) == [:label, :mean, :std] - three_parms_df_2 = summarize(chns[[:a, :b, :g]], mean, std, - sections=[:parameters, :internals], func_names=[:mean, :sd]) - @test three_parms_df_2[:, :parameters] == [:a, :b, :g] - @test size(three_parms_df_2) == (3, 3) + three_parms_df_2 = DataFrame( + summarize( + chns[[:a, :b, :g]], + :mymean => mean, + :mystd => std; + sections = [:parameters, :internals], + ), + ) + @test three_parms_df_2[!, 1] == [:a, :b, :g] + @test propertynames(three_parms_df_2) == [:label, :mymean, :mystd] end diff --git a/test/tables_tests.jl b/test/tables_tests.jl index 2dc99914..fc897c4f 100644 --- a/test/tables_tests.jl +++ b/test/tables_tests.jl @@ -131,111 +131,4 @@ using DataFrames @test isequal(Tables.columntable(df), Tables.columntable(chn)) end end - - @testset "ChainDataFrames" begin - val = rand(1000, 8, 4) - colnames = ["a", "b", "c", "d", "e", "f", "g", "h"] - internal_colnames = ["c", "d", "e", "f", "g", "h"] - chn = Chains(val, colnames, Dict(:internals => internal_colnames)) - - # Get ChainDataFrame objects - summstats = summarystats(chn) - qs = quantile(chn) - - # Helper function to test any ChainDataFrame - function test_chaindataframe(cdf::ChainDataFrame) - @testset "Tables interface" begin - @test Tables.istable(typeof(cdf)) - - @testset "column access" begin - @test Tables.columnaccess(typeof(cdf)) - @test Tables.columns(cdf) === cdf - @test Tables.columnnames(cdf) == keys(cdf.nt) - for (k, v) in pairs(cdf.nt) - @test isequal(Tables.getcolumn(cdf, k), v) - end - @test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1]) - @test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2]) - @test_throws Exception Tables.getcolumn(cdf, :blah) - @test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1) - end - - @testset "row access" begin - @test Tables.rowaccess(typeof(cdf)) - @test Tables.rows(cdf) isa Tables.RowIterator - @test eltype(Tables.rows(cdf)) <: Tables.AbstractRow - rows = collect(Tables.rows(cdf)) - @test eltype(rows) <: Tables.AbstractRow - @test size(rows) === (2,) - @testset for i = 1:2 - row = rows[i] - @test Tables.columnnames(row) == keys(cdf.nt) - for j in length(cdf.nt) - @test isequal(Tables.getcolumn(row, j), cdf.nt[j][i]) - @test isequal( - Tables.getcolumn(row, keys(cdf.nt)[j]), - cdf.nt[j][i], - ) - end - end - end - - @testset "integration tests" begin - @test length(Tables.rowtable(cdf)) == length(cdf.nt[1]) - @test isequal(Tables.columntable(cdf), cdf.nt) - nt = Tables.rowtable(cdf)[1] - @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) - @test isequal( - nt, - collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1], - ) - nt = Tables.rowtable(cdf)[2] - @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) - @test isequal( - nt, - collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2], - ) - @test isequal( - Tables.matrix(Tables.rowtable(cdf)), - Tables.matrix(Tables.columntable(cdf)), - ) - end - - @testset "schema" begin - schema = Tables.schema(cdf) - @test schema isa Tables.Schema - @test schema.names == keys(cdf.nt) - @test schema.types == eltype.(values(cdf.nt)) - end - end - - @testset "TableTraits interface" begin - @test IteratorInterfaceExtensions.isiterable(cdf) - @test TableTraits.isiterabletable(cdf) - nt = collect( - Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 1), - )[1] - @test isequal(nt, (; (k => v[1] for (k, v) in pairs(cdf.nt))...)) - nt = collect( - Iterators.take(IteratorInterfaceExtensions.getiterator(cdf), 2), - )[2] - @test isequal(nt, (; (k => v[2] for (k, v) in pairs(cdf.nt))...)) - end - - @testset "DataFrames.DataFrame constructor" begin - @inferred DataFrame(cdf) - df = DataFrame(cdf) - @test df isa DataFrame - @test isequal(Tables.columntable(df), cdf.nt) - end - end - - @testset "Summary Statistics" begin - test_chaindataframe(summstats) - end - - @testset "Quantiles" begin - test_chaindataframe(qs) - end - end end