@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
277277# ' of rank-normalized MCMC samples. Defaults to `20`.
278278# ' @param ref_line For the rank plots, whether to draw a horizontal line at the
279279# ' average number of ranks per bin. Defaults to `FALSE`.
280+ # ' @param split_chains Logical indicating whether to split each chain into two parts.
281+ # ' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
282+ # ' Defaults to `FALSE`.
280283# ' @export
281284mcmc_rank_overlay <- function (x ,
282285 pars = character (),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
285288 facet_args = list (),
286289 ... ,
287290 n_bins = 20 ,
288- ref_line = FALSE ) {
291+ ref_line = FALSE ,
292+ split_chains = FALSE ) {
289293 check_ignored_arguments(... )
290294 data <- mcmc_trace_data(
291295 x ,
@@ -294,7 +298,26 @@ mcmc_rank_overlay <- function(x,
294298 transformations = transformations
295299 )
296300
297- n_chains <- unique(data $ n_chains )
301+ # Split chains if requested
302+ if (split_chains ) {
303+ # Calculate midpoint for each chain
304+ n_samples <- length(unique(data $ iteration ))
305+ midpoint <- n_samples / 2
306+
307+ # Create new data frame with split chains
308+ data <- data %> %
309+ group_by(.data $ chain ) %> %
310+ mutate(
311+ chain = ifelse(
312+ iteration < = midpoint ,
313+ paste0(.data $ chain , " _1" ),
314+ paste0(.data $ chain , " _2" )
315+ )
316+ ) %> %
317+ ungroup()
318+ }
319+
320+ n_chains <- length(unique(data $ chain ))
298321 n_param <- unique(data $ n_parameters )
299322
300323 # We have to bin and count the data ourselves because
@@ -319,6 +342,7 @@ mcmc_rank_overlay <- function(x,
319342 bin_start = unique(histobins $ bin_start ),
320343 stringsAsFactors = FALSE
321344 ))
345+
322346 d_bin_counts <- all_combos %> %
323347 left_join(d_bin_counts , by = c(" parameter" , " chain" , " bin_start" )) %> %
324348 mutate(n = dplyr :: if_else(is.na(n ), 0L , n ))
@@ -331,7 +355,9 @@ mcmc_rank_overlay <- function(x,
331355 mutate(bin_start = right_edge ) %> %
332356 dplyr :: bind_rows(d_bin_counts )
333357
334- scale_color <- scale_color_manual(" Chain" , values = chain_colors(n_chains ))
358+ # Update legend title based on split_chains
359+ legend_title <- if (split_chains ) " Split Chains" else " Chain"
360+ scale_color <- scale_color_manual(legend_title , values = chain_colors(n_chains ))
335361
336362 layer_ref_line <- if (ref_line ) {
337363 geom_hline(
@@ -352,7 +378,7 @@ mcmc_rank_overlay <- function(x,
352378 }
353379
354380 ggplot(d_bin_counts ) +
355- aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
381+ aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
356382 geom_step() +
357383 layer_ref_line +
358384 facet_call +
0 commit comments