@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
277277# ' of rank-normalized MCMC samples. Defaults to `20`.
278278# ' @param ref_line For the rank plots, whether to draw a horizontal line at the
279279# ' average number of ranks per bin. Defaults to `FALSE`.
280+ # ' @param split_chains Logical indicating whether to split each chain into two parts.
281+ # ' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
282+ # ' Defaults to `FALSE`.
280283# ' @export
281284mcmc_rank_overlay <- function (x ,
282285 pars = character (),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
285288 facet_args = list (),
286289 ... ,
287290 n_bins = 20 ,
288- ref_line = FALSE ) {
291+ ref_line = FALSE ,
292+ split_chains = FALSE ) {
289293 check_ignored_arguments(... )
290294 data <- mcmc_trace_data(
291295 x ,
@@ -294,7 +298,28 @@ mcmc_rank_overlay <- function(x,
294298 transformations = transformations
295299 )
296300
297- n_chains <- unique(data $ n_chains )
301+ # Split chains if requested
302+ if (split_chains ) {
303+ data $ n_chains = data $ n_chains / 2
304+ data $ n_iterations = data $ n_iterations / 2
305+ # Calculate midpoint for each chain
306+ n_samples <- length(unique(data $ iteration ))
307+ midpoint <- n_samples / 2
308+
309+ # Create new data frame with split chains
310+ data <- data %> %
311+ group_by(.data $ chain ) %> %
312+ mutate(
313+ chain = ifelse(
314+ iteration < = midpoint ,
315+ paste0(.data $ chain , " _1" ),
316+ paste0(.data $ chain , " _2" )
317+ )
318+ ) %> %
319+ ungroup()
320+ }
321+
322+ n_chains <- length(unique(data $ chain ))
298323 n_param <- unique(data $ n_parameters )
299324
300325 # We have to bin and count the data ourselves because
@@ -319,6 +344,7 @@ mcmc_rank_overlay <- function(x,
319344 bin_start = unique(histobins $ bin_start ),
320345 stringsAsFactors = FALSE
321346 ))
347+
322348 d_bin_counts <- all_combos %> %
323349 left_join(d_bin_counts , by = c(" parameter" , " chain" , " bin_start" )) %> %
324350 mutate(n = dplyr :: if_else(is.na(n ), 0L , n ))
@@ -331,7 +357,9 @@ mcmc_rank_overlay <- function(x,
331357 mutate(bin_start = right_edge ) %> %
332358 dplyr :: bind_rows(d_bin_counts )
333359
334- scale_color <- scale_color_manual(" Chain" , values = chain_colors(n_chains ))
360+ # Update legend title based on split_chains
361+ legend_title <- if (split_chains ) " Split Chains" else " Chain"
362+ scale_color <- scale_color_manual(legend_title , values = chain_colors(n_chains ))
335363
336364 layer_ref_line <- if (ref_line ) {
337365 geom_hline(
@@ -352,7 +380,7 @@ mcmc_rank_overlay <- function(x,
352380 }
353381
354382 ggplot(d_bin_counts ) +
355- aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
383+ aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
356384 geom_step() +
357385 layer_ref_line +
358386 facet_call +
0 commit comments