1010# ' @param ... Currently unused.
1111# ' @param lw A matrix of (smoothed) log weights with the same dimensions as
1212# ' `yrep`. See [loo::psis()] and the associated `weights()` method as well as
13- # ' the **Examples** section, below.
13+ # ' the **Examples** section, below. If `lw` is not specified then
14+ # ' `psis_object` can be provided and log weights will be extracted.
15+ # ' @param psis_object If using **loo** version `2.0.0` or greater, an
16+ # ' object returned by the `psis()` function (or by the `loo()` function
17+ # ' with argument `save_psis` set to `TRUE`).
1418# ' @param alpha,size,fatten,linewidth Arguments passed to code geoms to control plot
1519# ' aesthetics. For `ppc_loo_pit_qq()` and `ppc_loo_pit_overlay()`, `size` and
1620# ' `alpha` are passed to [ggplot2::geom_point()] and
7175# ' log_radon ~ floor + log_uranium + floor:log_uranium
7276# ' + (1 + floor | county),
7377# ' data = radon,
74- # ' iter = 1000 ,
78+ # ' iter = 100 ,
7579# ' chains = 2,
7680# ' cores = 2
7781# ' )
8993# ' ppc_loo_pit_qq(y, yrep, lw = lw)
9094# ' ppc_loo_pit_qq(y, yrep, lw = lw, compare = "normal")
9195# '
96+ # ' # can use the psis object instead of lw
97+ # ' ppc_loo_pit_qq(y, yrep, psis_object = psis1)
9298# '
9399# ' # loo predictive intervals vs observations
94100# ' keep_obs <- 1:50
138144# '
139145ppc_loo_pit_overlay <- function (y ,
140146 yrep ,
141- lw ,
147+ lw = NULL ,
142148 ... ,
149+ psis_object = NULL ,
143150 pit = NULL ,
144151 samples = 100 ,
145152 size = 0.25 ,
@@ -158,6 +165,7 @@ ppc_loo_pit_overlay <- function(y,
158165 y = y ,
159166 yrep = yrep ,
160167 lw = lw ,
168+ psis_object = psis_object ,
161169 pit = pit ,
162170 samples = samples ,
163171 bw = bw ,
@@ -253,8 +261,9 @@ ppc_loo_pit_overlay <- function(y,
253261ppc_loo_pit_data <-
254262 function (y ,
255263 yrep ,
256- lw ,
264+ lw = NULL ,
257265 ... ,
266+ psis_object = NULL ,
258267 pit = NULL ,
259268 samples = 100 ,
260269 bw = " nrd0" ,
@@ -267,6 +276,7 @@ ppc_loo_pit_data <-
267276 suggested_package(" rstantools" )
268277 y <- validate_y(y )
269278 yrep <- validate_predictions(yrep , length(y ))
279+ lw <- .get_lw(lw , psis_object )
270280 stopifnot(identical(dim(yrep ), dim(lw )))
271281 pit <- rstantools :: loo_pit(object = yrep , y = y , lw = lw )
272282 }
@@ -295,22 +305,24 @@ ppc_loo_pit_data <-
295305# ' @export
296306ppc_loo_pit_qq <- function (y ,
297307 yrep ,
298- lw ,
299- pit ,
300- compare = c(" uniform" , " normal" ),
308+ lw = NULL ,
301309 ... ,
310+ psis_object = NULL ,
311+ pit = NULL ,
312+ compare = c(" uniform" , " normal" ),
302313 size = 2 ,
303314 alpha = 1 ) {
304315 check_ignored_arguments(... )
305316
306317 compare <- match.arg(compare )
307- if (! missing (pit )) {
318+ if (! is.null (pit )) {
308319 stopifnot(is.numeric(pit ), is_vector_or_1Darray(pit ))
309320 inform(" 'pit' specified so ignoring 'y','yrep','lw' if specified." )
310321 } else {
311322 suggested_package(" rstantools" )
312323 y <- validate_y(y )
313324 yrep <- validate_predictions(yrep , length(y ))
325+ lw <- .get_lw(lw , psis_object )
314326 stopifnot(identical(dim(yrep ), dim(lw )))
315327 pit <- rstantools :: loo_pit(object = yrep , y = y , lw = lw )
316328 }
@@ -352,7 +364,7 @@ ppc_loo_pit <-
352364 function (y ,
353365 yrep ,
354366 lw ,
355- pit ,
367+ pit = NULL ,
356368 compare = c(" uniform" , " normal" ),
357369 ... ,
358370 size = 2 ,
@@ -374,18 +386,14 @@ ppc_loo_pit <-
374386# ' @rdname PPC-loo
375387# ' @export
376388# ' @template args-prob-prob_outer
377- # ' @param psis_object If using **loo** version `2.0.0` or greater, an
378- # ' object returned by the `psis()` function (or by the `loo()` function
379- # ' with argument `save_psis` set to `TRUE`).
380- # ' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`,
381- # ' optionally a matrix of precomputed LOO predictive intervals
382- # ' that can be specified instead of `yrep` and `lw` (these are both
383- # ' ignored if `intervals` is specified). If not specified the intervals
384- # ' are computed internally before plotting. If specified, `intervals`
385- # ' must be a matrix with number of rows equal to the number of data points and
386- # ' five columns in the following order: lower outer interval, lower inner
387- # ' interval, median (50%), upper inner interval and upper outer interval
388- # ' (column names are ignored).
389+ # ' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`, optionally
390+ # ' a matrix of pre-computed LOO predictive intervals that can be specified
391+ # ' instead of `yrep` (ignored if `intervals` is specified). If not specified
392+ # ' the intervals are computed internally before plotting. If specified,
393+ # ' `intervals` must be a matrix with number of rows equal to the number of
394+ # ' data points and five columns in the following order: lower outer interval,
395+ # ' lower inner interval, median (50%), upper inner interval and upper outer
396+ # ' interval (column names are ignored).
389397# ' @param order For `ppc_loo_intervals()`, a string indicating how to arrange
390398# ' the plotted intervals. The default (`"index"`) is to plot them in the
391399# ' order of the observations. The alternative (`"median"`) arranges them
@@ -403,9 +411,9 @@ ppc_loo_intervals <-
403411 function (y ,
404412 yrep ,
405413 psis_object ,
414+ ... ,
406415 subset = NULL ,
407416 intervals = NULL ,
408- ... ,
409417 prob = 0.5 ,
410418 prob_outer = 0.9 ,
411419 alpha = 0.33 ,
@@ -498,11 +506,10 @@ ppc_loo_intervals <-
498506ppc_loo_ribbon <-
499507 function (y ,
500508 yrep ,
501- lw ,
502509 psis_object ,
510+ ... ,
503511 subset = NULL ,
504512 intervals = NULL ,
505- ... ,
506513 prob = 0.5 ,
507514 prob_outer = 0.9 ,
508515 alpha = 0.33 ,
@@ -720,3 +727,17 @@ ppc_loo_ribbon <-
720727
721728 list (xs = xs , unifs = bc_mat )
722729}
730+
731+ # Extract log weights from psis_object if provided
732+ .get_lw <- function (lw = NULL , psis_object = NULL ) {
733+ if (is.null(lw ) && is.null(psis_object )) {
734+ abort(" One of 'lw' and 'psis_object' must be specified." )
735+ } else if (is.null(lw )) {
736+ suggested_package(" loo" , min_version = " 2.0.0" )
737+ if (! loo :: is.psis(psis_object )) {
738+ abort(" If specified, 'psis_object' must be a PSIS object from the loo package." )
739+ }
740+ lw <- loo :: weights.importance_sampling(psis_object )
741+ }
742+ lw
743+ }
0 commit comments