diff --git a/DESCRIPTION b/DESCRIPTION index 74ffdb8..777ce7a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -24,10 +24,12 @@ LinkingTo: RcppArmadillo Imports: Rcpp, - ggplot2 + ggplot2, + Rdpack Depends: R (>= 4.1.0) Suggests: testthat (>= 3.0.0), label.switching (>= 1.8) +RdMacros: Rdpack Config/testthat/edition: 3 diff --git a/NAMESPACE b/NAMESPACE index 8595ed9..0fc5b3c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,5 +8,7 @@ export(compute_sequentially) export(precompute_topological_sorts) export(set_hyperparameters) export(set_smc_options) +export(trace_plot) importFrom(Rcpp,sourceCpp) +importFrom(Rdpack,reprompt) useDynLib(BayesMallowsSMC2, .registration = TRUE) diff --git a/R/BayesMallowsSMC2-package.R b/R/BayesMallowsSMC2-package.R index e492a96..1955c63 100644 --- a/R/BayesMallowsSMC2-package.R +++ b/R/BayesMallowsSMC2-package.R @@ -1,4 +1,5 @@ ## usethis namespace: start +#' @importFrom Rdpack reprompt #' @importFrom Rcpp sourceCpp ## usethis namespace: end NULL diff --git a/R/print.R b/R/print.R index db8d1e3..62f019a 100644 --- a/R/print.R +++ b/R/print.R @@ -43,37 +43,37 @@ print.BayesMallowsSMC2 <- function(x, ...) { if (!inherits(x, "BayesMallowsSMC2")) { stop("x must be an object of class 'BayesMallowsSMC2'") } - + required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood") missing_fields <- setdiff(required_fields, names(x)) if (length(missing_fields) > 0) { stop("Object is missing required fields: ", paste(missing_fields, collapse = ", ")) } - + # Extract dimensions n_particles <- ncol(x$alpha) n_timepoints <- length(x$ESS) n_items <- dim(x$rho)[1] n_clusters <- nrow(x$alpha) - + # Count resampling events n_resampling_events <- sum(x$resampling) - + # Create header cat("BayesMallowsSMC2 Model\n") cat(strrep("=", nchar("BayesMallowsSMC2 Model")), "\n\n", sep = "") - + # Display basic information cat(sprintf("%-25s %s\n", "Number of particles:", n_particles)) cat(sprintf("%-25s %s\n", "Number of timepoints:", n_timepoints)) cat(sprintf("%-25s %s\n", "Number of items:", n_items)) cat(sprintf("%-25s %s\n\n", "Number of clusters:", n_clusters)) - + # Display model fit information cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood)) cat(sprintf("%-25s %.2f\n", "Final ESS:", x$ESS[n_timepoints])) cat(sprintf("%-25s %d/%d\n", "Resampling events:", n_resampling_events, n_timepoints)) - + invisible(x) } @@ -121,27 +121,27 @@ summary.BayesMallowsSMC2 <- function(object, ...) { if (!inherits(object, "BayesMallowsSMC2")) { stop("object must be an object of class 'BayesMallowsSMC2'") } - + required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood") missing_fields <- setdiff(required_fields, names(object)) if (length(missing_fields) > 0) { stop("Object is missing required fields: ", paste(missing_fields, collapse = ", ")) } - + # Extract dimensions n_particles <- ncol(object$alpha) n_timepoints <- length(object$ESS) n_items <- dim(object$rho)[1] n_clusters <- nrow(object$alpha) - + # Count resampling events n_resampling_events <- sum(object$resampling) - + # Compute posterior mean and standard deviation of alpha # alpha is a matrix where rows are clusters and columns are particles alpha_mean <- rowMeans(object$alpha) - alpha_sd <- apply(object$alpha, 1, sd) - + alpha_sd <- apply(object$alpha, 1, stats::sd) + # Create summary object summary_obj <- list( n_particles = n_particles, @@ -154,7 +154,7 @@ summary.BayesMallowsSMC2 <- function(object, ...) { alpha_mean = alpha_mean, alpha_sd = alpha_sd ) - + class(summary_obj) <- "summary.BayesMallowsSMC2" summary_obj } @@ -174,25 +174,25 @@ print.summary.BayesMallowsSMC2 <- function(x, ...) { # Create header cat("BayesMallowsSMC2 Model Summary\n") cat(strrep("=", nchar("BayesMallowsSMC2 Model Summary")), "\n\n", sep = "") - + # Display basic information cat(sprintf("%-25s %s\n", "Number of particles:", x$n_particles)) cat(sprintf("%-25s %s\n", "Number of timepoints:", x$n_timepoints)) cat(sprintf("%-25s %s\n", "Number of items:", x$n_items)) cat(sprintf("%-25s %s\n\n", "Number of clusters:", x$n_clusters)) - + # Display model fit information cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood)) cat(sprintf("%-25s %.2f\n", "Final ESS:", x$final_ess)) cat(sprintf("%-25s %d/%d\n\n", "Resampling events:", x$n_resampling_events, x$n_timepoints)) - + # Display posterior statistics for alpha cat("Posterior Statistics for Alpha:\n") cat(strrep("-", nchar("Posterior Statistics for Alpha:")), "\n", sep = "") for (i in seq_along(x$alpha_mean)) { - cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n", + cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n", i, x$alpha_mean[i], x$alpha_sd[i])) } - + invisible(x) } diff --git a/R/trace_plot.R b/R/trace_plot.R new file mode 100644 index 0000000..95eab9f --- /dev/null +++ b/R/trace_plot.R @@ -0,0 +1,219 @@ +#' Create Trace Plots for BayesMallowsSMC2 Objects +#' +#' Visualize the timeseries dynamics of the alpha and tau parameters across +#' timepoints. This function creates trace plots similar to Figure 4 (left) in +#' \insertRef{10.1214/25-BA1564}{BayesMallowsSMC2}. +#' +#' @param x An object of class `BayesMallowsSMC2`, returned from +#' [compute_sequentially()] with `trace = TRUE` in [set_smc_options()]. +#' @param parameter Character string defining the parameter to plot. Available +#' options are `"alpha"` (default) and `"tau"`. +#' @param ... Other arguments (currently unused). +#' +#' @return A ggplot object showing the evolution of the parameter over time. +#' For each timepoint, the plot shows: +#' \itemize{ +#' \item The weighted mean (solid line) +#' \item The weighted 0.025 and 0.975 quantiles (shaded area representing +#' the 95% credible interval) +#' } +#' +#' @details +#' This function requires that the model was fitted with `trace = TRUE` in the +#' `smc_options`. The trace contains the parameter values at each timepoint, +#' which allows visualization of how the posterior distribution evolves as more +#' data arrives sequentially. +#' +#' For mixture models (multiple clusters), separate trace plots are created for +#' each cluster using faceting. +#' +#' The shaded area represents the 95% credible interval (from 2.5% to 97.5% +#' quantiles) of the posterior distribution at each timepoint, computed using +#' the importance weights from the SMC algorithm. +#' +#' @export +#' +#' @references +#' \insertRef{10.1214/25-BA1564}{BayesMallowsSMC2} +#' +#' @examples +#' \dontrun{ +#' # Fit a model with trace enabled +#' mod <- compute_sequentially( +#' complete_rankings, +#' hyperparameters = set_hyperparameters(n_items = 5), +#' smc_options = set_smc_options( +#' n_particles = 100, +#' n_particle_filters = 1, +#' trace = TRUE +#' ) +#' ) +#' +#' # Create trace plot for alpha (default) +#' trace_plot(mod) +#' } +trace_plot <- function(x, parameter = "alpha", ...) { + # Validate parameter + parameter <- match.arg(parameter, c("alpha", "tau")) + + # Basic validation + if (!inherits(x, "BayesMallowsSMC2")) { + stop("x must be an object of class 'BayesMallowsSMC2'") + } + + # Check if trace was enabled + trace_field <- paste0(parameter, "_traces") + if (!trace_field %in% names(x)) { + stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().") + } + + traces <- x[[trace_field]] + + if (length(traces) == 0) { + stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().") + } + + # Check for importance weights trace + if (!"log_importance_weights_traces" %in% names(x)) { + stop("Importance weights trace not found. This should not happen if trace = TRUE was used.") + } + + log_weights_traces <- x$log_importance_weights_traces + + if (parameter == "alpha") { + plot_trace_alpha_tau(traces, log_weights_traces, parameter_name = "alpha", + parameter_label = expression(alpha)) + } else if (parameter == "tau") { + plot_trace_alpha_tau(traces, log_weights_traces, parameter_name = "tau", + parameter_label = expression(tau)) + } +} + +# Avoid R CMD check NOTE about undefined global variables used in ggplot2::aes() +utils::globalVariables(c("timepoint", "mean", "lower", "upper", "cluster")) + + +# Internal function to plot trace for alpha or tau parameter +# @param traces List of matrices, one per timepoint. Each matrix is [n_clusters x n_particles] +# @param log_weights_traces List of vectors, one per timepoint. Each vector is length n_particles +# @param parameter_name Character string, name of the parameter +# @param parameter_label Expression for axis label +plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name, + parameter_label) { + n_timepoints <- length(traces) + + # Get dimensions from first trace + # Need to infer n_clusters and n_particles from the trace + first_trace <- traces[[1]] + + # If trace is a vector, need to infer dimensions + if (is.vector(first_trace)) { + # Get n_particles from log_weights + n_particles <- length(log_weights_traces[[1]]) + + # If trace is a vector, infer n_clusters from its length + trace_length <- length(first_trace) + if (trace_length %% n_particles == 0) { + n_clusters <- trace_length %/% n_particles + } else { + stop(sprintf("Trace length (%d) is not divisible by n_particles (%d). ", + trace_length, n_particles), + "This indicates inconsistent dimensions in the trace data.") + } + + # Convert all traces to matrices [n_clusters x n_particles] + # The C++ code stores traces as: alpha is [n_clusters x n_particles] matrix per timepoint + # When passed to R as vector, elements are in column-major order: + # cluster1_particle1, cluster2_particle1, cluster1_particle2, cluster2_particle2, ... + traces <- lapply(traces, function(t) { + matrix(t, nrow = n_clusters, ncol = n_particles, byrow = FALSE) + }) + first_trace <- traces[[1]] + } else if (is.matrix(first_trace)) { + n_clusters <- nrow(first_trace) + n_particles <- ncol(first_trace) + } else { + stop("Trace elements must be vectors or matrices") + } + + # Create data frame for plotting + plot_data_list <- vector("list", n_timepoints * n_clusters) + idx <- 1 + + for (t in seq_len(n_timepoints)) { + param_matrix <- traces[[t]] + log_weights <- log_weights_traces[[t]] + + # Normalize weights + weights <- exp(log_weights - max(log_weights)) + weights <- weights / sum(weights) + + for (cluster in seq_len(n_clusters)) { + param_values <- param_matrix[cluster, ] + + # Compute weighted statistics + weighted_mean <- stats::weighted.mean(param_values, weights) + weighted_quantiles <- weighted_quantile(param_values, weights, + probs = c(0.025, 0.975)) + + plot_data_list[[idx]] <- data.frame( + timepoint = t, + mean = weighted_mean, + lower = weighted_quantiles[1], + upper = weighted_quantiles[2], + cluster = if (n_clusters > 1) paste0("Cluster ", cluster) else "All" + ) + idx <- idx + 1 + } + } + + plot_data <- do.call(rbind, plot_data_list) + + # Create line plot with credible interval + p <- ggplot2::ggplot(plot_data, ggplot2::aes(x = timepoint, y = mean)) + + ggplot2::geom_ribbon(ggplot2::aes(ymin = lower, ymax = upper), + alpha = 0.3, fill = "steelblue") + + ggplot2::geom_line(color = "darkblue", linewidth = 1) + + ggplot2::xlab("Timepoint") + + ggplot2::ylab(parameter_label) + + ggplot2::theme_minimal() + + ggplot2::theme( + panel.grid.minor = ggplot2::element_blank() + ) + + # Add faceting if multiple clusters + if (n_clusters > 1) { + p <- p + ggplot2::facet_wrap(~ cluster, scales = "free_y") + } + + p +} + + +# Internal helper function to compute weighted quantiles +# @param x Numeric vector of values +# @param weights Numeric vector of weights +# @param probs Numeric vector of probabilities +# @return Numeric vector of quantiles +weighted_quantile <- function(x, weights, probs) { + # Sort x and weights by x + ord <- order(x) + x_sorted <- x[ord] + weights_sorted <- weights[ord] + + # Compute cumulative weights + cum_weights <- cumsum(weights_sorted) / sum(weights_sorted) + + # Find quantiles + quantiles <- numeric(length(probs)) + for (i in seq_along(probs)) { + # Find first position where cumulative weight exceeds prob + idx <- which(cum_weights >= probs[i])[1] + if (is.na(idx)) { + idx <- length(x_sorted) + } + quantiles[i] <- x_sorted[idx] + } + + quantiles +} diff --git a/man/print.summary.BayesMallowsSMC2.Rd b/man/print.summary.BayesMallowsSMC2.Rd new file mode 100644 index 0000000..c60871c --- /dev/null +++ b/man/print.summary.BayesMallowsSMC2.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/print.R +\name{print.summary.BayesMallowsSMC2} +\alias{print.summary.BayesMallowsSMC2} +\title{Print Method for summary.BayesMallowsSMC2 Objects} +\usage{ +\method{print}{summary.BayesMallowsSMC2}(x, ...) +} +\arguments{ +\item{x}{An object of class \code{summary.BayesMallowsSMC2}.} + +\item{...}{Additional arguments (currently unused).} +} +\value{ +Invisibly returns the input object \code{x}. +} +\description{ +Prints a summary of a BayesMallowsSMC2 model. +} diff --git a/man/summary.BayesMallowsSMC2.Rd b/man/summary.BayesMallowsSMC2.Rd new file mode 100644 index 0000000..93fa81c --- /dev/null +++ b/man/summary.BayesMallowsSMC2.Rd @@ -0,0 +1,48 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/print.R +\name{summary.BayesMallowsSMC2} +\alias{summary.BayesMallowsSMC2} +\title{Summary Method for BayesMallowsSMC2 Objects} +\usage{ +\method{summary}{BayesMallowsSMC2}(object, ...) +} +\arguments{ +\item{object}{An object of class \code{BayesMallowsSMC2}.} + +\item{...}{Additional arguments (currently unused).} +} +\value{ +An object of class \code{summary.BayesMallowsSMC2}, which is a list +containing summary information about the model. +} +\description{ +Creates a summary of a BayesMallowsSMC2 object returned by +\code{\link[=compute_sequentially]{compute_sequentially()}}. +} +\details{ +The summary method creates a summary object that includes: +\itemize{ +\item Number of particles +\item Number of timepoints +\item Number of items +\item Number of clusters +\item Log marginal likelihood +\item Final effective sample size (ESS) +\item Number of resampling events +\item Posterior mean of alpha for each cluster +\item Posterior standard deviation of alpha for each cluster +} +} +\examples{ +# Fit a model with complete rankings +set.seed(123) +mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1) +) + +# Create summary +summary(mod) + +} diff --git a/man/trace_plot.Rd b/man/trace_plot.Rd new file mode 100644 index 0000000..9987945 --- /dev/null +++ b/man/trace_plot.Rd @@ -0,0 +1,64 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/trace_plot.R +\name{trace_plot} +\alias{trace_plot} +\title{Create Trace Plots for BayesMallowsSMC2 Objects} +\usage{ +trace_plot(x, parameter = "alpha", ...) +} +\arguments{ +\item{x}{An object of class \code{BayesMallowsSMC2}, returned from +\code{\link[=compute_sequentially]{compute_sequentially()}} with \code{trace = TRUE} in \code{\link[=set_smc_options]{set_smc_options()}}.} + +\item{parameter}{Character string defining the parameter to plot. Available +options are \code{"alpha"} (default) and \code{"tau"}.} + +\item{...}{Other arguments (currently unused).} +} +\value{ +A ggplot object showing the evolution of the parameter over time. +For each timepoint, the plot shows: +\itemize{ +\item The weighted mean (solid line) +\item The weighted 0.025 and 0.975 quantiles (shaded area representing +the 95\% credible interval) +} +} +\description{ +Visualize the timeseries dynamics of the alpha and tau parameters across +timepoints. This function creates trace plots similar to Figure 4 (left) in +\insertRef{10.1214/25-BA1564}{BayesMallowsSMC2}. +} +\details{ +This function requires that the model was fitted with \code{trace = TRUE} in the +\code{smc_options}. The trace contains the parameter values at each timepoint, +which allows visualization of how the posterior distribution evolves as more +data arrives sequentially. + +For mixture models (multiple clusters), separate trace plots are created for +each cluster using faceting. + +The shaded area represents the 95\% credible interval (from 2.5\% to 97.5\% +quantiles) of the posterior distribution at each timepoint, computed using +the importance weights from the SMC algorithm. +} +\examples{ +\dontrun{ +# Fit a model with trace enabled +mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options( + n_particles = 100, + n_particle_filters = 1, + trace = TRUE + ) +) + +# Create trace plot for alpha (default) +trace_plot(mod) +} +} +\references{ +\insertRef{10.1214/25-BA1564}{BayesMallowsSMC2} +} diff --git a/tests/testthat/test-trace_plot.R b/tests/testthat/test-trace_plot.R new file mode 100644 index 0000000..4e2ec3b --- /dev/null +++ b/tests/testthat/test-trace_plot.R @@ -0,0 +1,100 @@ +test_that("trace_plot works with alpha parameter", { + set.seed(123) + mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options( + n_particles = 100, + n_particle_filters = 1, + trace = TRUE + ) + ) + + # Test that trace_plot function runs without error for alpha + expect_no_error(trace_plot(mod, parameter = "alpha")) + expect_no_error(trace_plot(mod)) # alpha is default + + # Check that it returns a ggplot object + p <- trace_plot(mod, parameter = "alpha") + expect_s3_class(p, "ggplot") +}) + +test_that("trace_plot works with tau parameter", { + set.seed(123) + mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options( + n_particles = 100, + n_particle_filters = 1, + trace = TRUE + ) + ) + + # Test that trace_plot function runs without error for tau + expect_no_error(trace_plot(mod, parameter = "tau")) + + # Check that it returns a ggplot object + p <- trace_plot(mod, parameter = "tau") + expect_s3_class(p, "ggplot") +}) + +test_that("trace_plot validates parameter argument", { + set.seed(123) + mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options( + n_particles = 100, + n_particle_filters = 1, + trace = TRUE + ) + ) + + # Test invalid parameter + expect_error(trace_plot(mod, parameter = "invalid"), "should be one of") +}) + +test_that("trace_plot validates object class", { + # Test with non-BayesMallowsSMC2 object + fake_obj <- list(alpha_traces = list(matrix(1:10, 2, 5))) + expect_error(trace_plot(fake_obj), "must be an object of class") +}) + +test_that("trace_plot requires trace = TRUE", { + set.seed(123) + mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options( + n_particles = 100, + n_particle_filters = 1, + trace = FALSE # trace disabled + ) + ) + + # Test that it gives a helpful error message + expect_error(trace_plot(mod), "Trace data not found") + expect_error(trace_plot(mod), "trace = TRUE") +}) + +test_that("trace_plot works with mixture models", { + # Create mixture data (using partial rankings as a proxy) + set.seed(456) + mod <- compute_sequentially( + partial_rankings[1:20, ], + hyperparameters = set_hyperparameters(n_items = 5, n_clusters = 2), + smc_options = set_smc_options( + n_particles = 50, + n_particle_filters = 1, + trace = TRUE + ) + ) + + # Test that trace_plot works with multiple clusters + expect_no_error(p <- trace_plot(mod, parameter = "alpha")) + expect_s3_class(p, "ggplot") + + expect_no_error(p <- trace_plot(mod, parameter = "tau")) + expect_s3_class(p, "ggplot") +})