Skip to content

Commit 7214c99

Browse files
authored
Merge pull request #60 from osorensen/copilot/create-trace-plot-functions
Add trace_plot() for visualizing parameter evolution over timepoints
2 parents 1e39054 + 7adc9e3 commit 7214c99

File tree

9 files changed

+475
-20
lines changed

9 files changed

+475
-20
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ LinkingTo:
2424
RcppArmadillo
2525
Imports:
2626
Rcpp,
27-
ggplot2
27+
ggplot2,
28+
Rdpack
2829
Depends:
2930
R (>= 4.1.0)
3031
Suggests:
3132
testthat (>= 3.0.0),
3233
label.switching (>= 1.8)
34+
RdMacros: Rdpack
3335
Config/testthat/edition: 3

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ export(compute_sequentially)
88
export(precompute_topological_sorts)
99
export(set_hyperparameters)
1010
export(set_smc_options)
11+
export(trace_plot)
1112
importFrom(Rcpp,sourceCpp)
13+
importFrom(Rdpack,reprompt)
1214
useDynLib(BayesMallowsSMC2, .registration = TRUE)

R/BayesMallowsSMC2-package.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## usethis namespace: start
2+
#' @importFrom Rdpack reprompt
23
#' @importFrom Rcpp sourceCpp
34
## usethis namespace: end
45
NULL

R/print.R

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,37 @@ print.BayesMallowsSMC2 <- function(x, ...) {
4343
if (!inherits(x, "BayesMallowsSMC2")) {
4444
stop("x must be an object of class 'BayesMallowsSMC2'")
4545
}
46-
46+
4747
required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood")
4848
missing_fields <- setdiff(required_fields, names(x))
4949
if (length(missing_fields) > 0) {
5050
stop("Object is missing required fields: ", paste(missing_fields, collapse = ", "))
5151
}
52-
52+
5353
# Extract dimensions
5454
n_particles <- ncol(x$alpha)
5555
n_timepoints <- length(x$ESS)
5656
n_items <- dim(x$rho)[1]
5757
n_clusters <- nrow(x$alpha)
58-
58+
5959
# Count resampling events
6060
n_resampling_events <- sum(x$resampling)
61-
61+
6262
# Create header
6363
cat("BayesMallowsSMC2 Model\n")
6464
cat(strrep("=", nchar("BayesMallowsSMC2 Model")), "\n\n", sep = "")
65-
65+
6666
# Display basic information
6767
cat(sprintf("%-25s %s\n", "Number of particles:", n_particles))
6868
cat(sprintf("%-25s %s\n", "Number of timepoints:", n_timepoints))
6969
cat(sprintf("%-25s %s\n", "Number of items:", n_items))
7070
cat(sprintf("%-25s %s\n\n", "Number of clusters:", n_clusters))
71-
71+
7272
# Display model fit information
7373
cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood))
7474
cat(sprintf("%-25s %.2f\n", "Final ESS:", x$ESS[n_timepoints]))
7575
cat(sprintf("%-25s %d/%d\n", "Resampling events:", n_resampling_events, n_timepoints))
76-
76+
7777
invisible(x)
7878
}
7979

@@ -121,27 +121,27 @@ summary.BayesMallowsSMC2 <- function(object, ...) {
121121
if (!inherits(object, "BayesMallowsSMC2")) {
122122
stop("object must be an object of class 'BayesMallowsSMC2'")
123123
}
124-
124+
125125
required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood")
126126
missing_fields <- setdiff(required_fields, names(object))
127127
if (length(missing_fields) > 0) {
128128
stop("Object is missing required fields: ", paste(missing_fields, collapse = ", "))
129129
}
130-
130+
131131
# Extract dimensions
132132
n_particles <- ncol(object$alpha)
133133
n_timepoints <- length(object$ESS)
134134
n_items <- dim(object$rho)[1]
135135
n_clusters <- nrow(object$alpha)
136-
136+
137137
# Count resampling events
138138
n_resampling_events <- sum(object$resampling)
139-
139+
140140
# Compute posterior mean and standard deviation of alpha
141141
# alpha is a matrix where rows are clusters and columns are particles
142142
alpha_mean <- rowMeans(object$alpha)
143-
alpha_sd <- apply(object$alpha, 1, sd)
144-
143+
alpha_sd <- apply(object$alpha, 1, stats::sd)
144+
145145
# Create summary object
146146
summary_obj <- list(
147147
n_particles = n_particles,
@@ -154,7 +154,7 @@ summary.BayesMallowsSMC2 <- function(object, ...) {
154154
alpha_mean = alpha_mean,
155155
alpha_sd = alpha_sd
156156
)
157-
157+
158158
class(summary_obj) <- "summary.BayesMallowsSMC2"
159159
summary_obj
160160
}
@@ -174,25 +174,25 @@ print.summary.BayesMallowsSMC2 <- function(x, ...) {
174174
# Create header
175175
cat("BayesMallowsSMC2 Model Summary\n")
176176
cat(strrep("=", nchar("BayesMallowsSMC2 Model Summary")), "\n\n", sep = "")
177-
177+
178178
# Display basic information
179179
cat(sprintf("%-25s %s\n", "Number of particles:", x$n_particles))
180180
cat(sprintf("%-25s %s\n", "Number of timepoints:", x$n_timepoints))
181181
cat(sprintf("%-25s %s\n", "Number of items:", x$n_items))
182182
cat(sprintf("%-25s %s\n\n", "Number of clusters:", x$n_clusters))
183-
183+
184184
# Display model fit information
185185
cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood))
186186
cat(sprintf("%-25s %.2f\n", "Final ESS:", x$final_ess))
187187
cat(sprintf("%-25s %d/%d\n\n", "Resampling events:", x$n_resampling_events, x$n_timepoints))
188-
188+
189189
# Display posterior statistics for alpha
190190
cat("Posterior Statistics for Alpha:\n")
191191
cat(strrep("-", nchar("Posterior Statistics for Alpha:")), "\n", sep = "")
192192
for (i in seq_along(x$alpha_mean)) {
193-
cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n",
193+
cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n",
194194
i, x$alpha_mean[i], x$alpha_sd[i]))
195195
}
196-
196+
197197
invisible(x)
198198
}

R/trace_plot.R

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#' Create Trace Plots for BayesMallowsSMC2 Objects
2+
#'
3+
#' Visualize the timeseries dynamics of the alpha and tau parameters across
4+
#' timepoints. This function creates trace plots similar to Figure 4 (left) in
5+
#' \insertRef{10.1214/25-BA1564}{BayesMallowsSMC2}.
6+
#'
7+
#' @param x An object of class `BayesMallowsSMC2`, returned from
8+
#' [compute_sequentially()] with `trace = TRUE` in [set_smc_options()].
9+
#' @param parameter Character string defining the parameter to plot. Available
10+
#' options are `"alpha"` (default) and `"tau"`.
11+
#' @param ... Other arguments (currently unused).
12+
#'
13+
#' @return A ggplot object showing the evolution of the parameter over time.
14+
#' For each timepoint, the plot shows:
15+
#' \itemize{
16+
#' \item The weighted mean (solid line)
17+
#' \item The weighted 0.025 and 0.975 quantiles (shaded area representing
18+
#' the 95% credible interval)
19+
#' }
20+
#'
21+
#' @details
22+
#' This function requires that the model was fitted with `trace = TRUE` in the
23+
#' `smc_options`. The trace contains the parameter values at each timepoint,
24+
#' which allows visualization of how the posterior distribution evolves as more
25+
#' data arrives sequentially.
26+
#'
27+
#' For mixture models (multiple clusters), separate trace plots are created for
28+
#' each cluster using faceting.
29+
#'
30+
#' The shaded area represents the 95% credible interval (from 2.5% to 97.5%
31+
#' quantiles) of the posterior distribution at each timepoint, computed using
32+
#' the importance weights from the SMC algorithm.
33+
#'
34+
#' @export
35+
#'
36+
#' @references
37+
#' \insertRef{10.1214/25-BA1564}{BayesMallowsSMC2}
38+
#'
39+
#' @examples
40+
#' \dontrun{
41+
#' # Fit a model with trace enabled
42+
#' mod <- compute_sequentially(
43+
#' complete_rankings,
44+
#' hyperparameters = set_hyperparameters(n_items = 5),
45+
#' smc_options = set_smc_options(
46+
#' n_particles = 100,
47+
#' n_particle_filters = 1,
48+
#' trace = TRUE
49+
#' )
50+
#' )
51+
#'
52+
#' # Create trace plot for alpha (default)
53+
#' trace_plot(mod)
54+
#' }
55+
trace_plot <- function(x, parameter = "alpha", ...) {
56+
# Validate parameter
57+
parameter <- match.arg(parameter, c("alpha", "tau"))
58+
59+
# Basic validation
60+
if (!inherits(x, "BayesMallowsSMC2")) {
61+
stop("x must be an object of class 'BayesMallowsSMC2'")
62+
}
63+
64+
# Check if trace was enabled
65+
trace_field <- paste0(parameter, "_traces")
66+
if (!trace_field %in% names(x)) {
67+
stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().")
68+
}
69+
70+
traces <- x[[trace_field]]
71+
72+
if (length(traces) == 0) {
73+
stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().")
74+
}
75+
76+
# Check for importance weights trace
77+
if (!"log_importance_weights_traces" %in% names(x)) {
78+
stop("Importance weights trace not found. This should not happen if trace = TRUE was used.")
79+
}
80+
81+
log_weights_traces <- x$log_importance_weights_traces
82+
83+
if (parameter == "alpha") {
84+
plot_trace_alpha_tau(traces, log_weights_traces, parameter_name = "alpha",
85+
parameter_label = expression(alpha))
86+
} else if (parameter == "tau") {
87+
plot_trace_alpha_tau(traces, log_weights_traces, parameter_name = "tau",
88+
parameter_label = expression(tau))
89+
}
90+
}
91+
92+
# Avoid R CMD check NOTE about undefined global variables used in ggplot2::aes()
93+
utils::globalVariables(c("timepoint", "mean", "lower", "upper", "cluster"))
94+
95+
96+
# Internal function to plot trace for alpha or tau parameter
97+
# @param traces List of matrices, one per timepoint. Each matrix is [n_clusters x n_particles]
98+
# @param log_weights_traces List of vectors, one per timepoint. Each vector is length n_particles
99+
# @param parameter_name Character string, name of the parameter
100+
# @param parameter_label Expression for axis label
101+
plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
102+
parameter_label) {
103+
n_timepoints <- length(traces)
104+
105+
# Get dimensions from first trace
106+
# Need to infer n_clusters and n_particles from the trace
107+
first_trace <- traces[[1]]
108+
109+
# If trace is a vector, need to infer dimensions
110+
if (is.vector(first_trace)) {
111+
# Get n_particles from log_weights
112+
n_particles <- length(log_weights_traces[[1]])
113+
114+
# If trace is a vector, infer n_clusters from its length
115+
trace_length <- length(first_trace)
116+
if (trace_length %% n_particles == 0) {
117+
n_clusters <- trace_length %/% n_particles
118+
} else {
119+
stop(sprintf("Trace length (%d) is not divisible by n_particles (%d). ",
120+
trace_length, n_particles),
121+
"This indicates inconsistent dimensions in the trace data.")
122+
}
123+
124+
# Convert all traces to matrices [n_clusters x n_particles]
125+
# The C++ code stores traces as: alpha is [n_clusters x n_particles] matrix per timepoint
126+
# When passed to R as vector, elements are in column-major order:
127+
# cluster1_particle1, cluster2_particle1, cluster1_particle2, cluster2_particle2, ...
128+
traces <- lapply(traces, function(t) {
129+
matrix(t, nrow = n_clusters, ncol = n_particles, byrow = FALSE)
130+
})
131+
first_trace <- traces[[1]]
132+
} else if (is.matrix(first_trace)) {
133+
n_clusters <- nrow(first_trace)
134+
n_particles <- ncol(first_trace)
135+
} else {
136+
stop("Trace elements must be vectors or matrices")
137+
}
138+
139+
# Create data frame for plotting
140+
plot_data_list <- vector("list", n_timepoints * n_clusters)
141+
idx <- 1
142+
143+
for (t in seq_len(n_timepoints)) {
144+
param_matrix <- traces[[t]]
145+
log_weights <- log_weights_traces[[t]]
146+
147+
# Normalize weights
148+
weights <- exp(log_weights - max(log_weights))
149+
weights <- weights / sum(weights)
150+
151+
for (cluster in seq_len(n_clusters)) {
152+
param_values <- param_matrix[cluster, ]
153+
154+
# Compute weighted statistics
155+
weighted_mean <- stats::weighted.mean(param_values, weights)
156+
weighted_quantiles <- weighted_quantile(param_values, weights,
157+
probs = c(0.025, 0.975))
158+
159+
plot_data_list[[idx]] <- data.frame(
160+
timepoint = t,
161+
mean = weighted_mean,
162+
lower = weighted_quantiles[1],
163+
upper = weighted_quantiles[2],
164+
cluster = if (n_clusters > 1) paste0("Cluster ", cluster) else "All"
165+
)
166+
idx <- idx + 1
167+
}
168+
}
169+
170+
plot_data <- do.call(rbind, plot_data_list)
171+
172+
# Create line plot with credible interval
173+
p <- ggplot2::ggplot(plot_data, ggplot2::aes(x = timepoint, y = mean)) +
174+
ggplot2::geom_ribbon(ggplot2::aes(ymin = lower, ymax = upper),
175+
alpha = 0.3, fill = "steelblue") +
176+
ggplot2::geom_line(color = "darkblue", linewidth = 1) +
177+
ggplot2::xlab("Timepoint") +
178+
ggplot2::ylab(parameter_label) +
179+
ggplot2::theme_minimal() +
180+
ggplot2::theme(
181+
panel.grid.minor = ggplot2::element_blank()
182+
)
183+
184+
# Add faceting if multiple clusters
185+
if (n_clusters > 1) {
186+
p <- p + ggplot2::facet_wrap(~ cluster, scales = "free_y")
187+
}
188+
189+
p
190+
}
191+
192+
193+
# Internal helper function to compute weighted quantiles
194+
# @param x Numeric vector of values
195+
# @param weights Numeric vector of weights
196+
# @param probs Numeric vector of probabilities
197+
# @return Numeric vector of quantiles
198+
weighted_quantile <- function(x, weights, probs) {
199+
# Sort x and weights by x
200+
ord <- order(x)
201+
x_sorted <- x[ord]
202+
weights_sorted <- weights[ord]
203+
204+
# Compute cumulative weights
205+
cum_weights <- cumsum(weights_sorted) / sum(weights_sorted)
206+
207+
# Find quantiles
208+
quantiles <- numeric(length(probs))
209+
for (i in seq_along(probs)) {
210+
# Find first position where cumulative weight exceeds prob
211+
idx <- which(cum_weights >= probs[i])[1]
212+
if (is.na(idx)) {
213+
idx <- length(x_sorted)
214+
}
215+
quantiles[i] <- x_sorted[idx]
216+
}
217+
218+
quantiles
219+
}

man/print.summary.BayesMallowsSMC2.Rd

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)