diff --git a/NAMESPACE b/NAMESPACE index 94b0ff5..8595ed9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,8 @@ S3method(plot,BayesMallowsSMC2) S3method(print,BayesMallowsSMC2) +S3method(print,summary.BayesMallowsSMC2) +S3method(summary,BayesMallowsSMC2) export(compute_sequentially) export(precompute_topological_sorts) export(set_hyperparameters) diff --git a/R/print.R b/R/print.R index b827181..db8d1e3 100644 --- a/R/print.R +++ b/R/print.R @@ -76,3 +76,123 @@ print.BayesMallowsSMC2 <- function(x, ...) { invisible(x) } + +#' Summary Method for BayesMallowsSMC2 Objects +#' +#' Creates a summary of a BayesMallowsSMC2 object returned by +#' [compute_sequentially()]. +#' +#' @param object An object of class \code{BayesMallowsSMC2}. +#' @param ... Additional arguments (currently unused). +#' +#' @return An object of class \code{summary.BayesMallowsSMC2}, which is a list +#' containing summary information about the model. +#' +#' @details +#' The summary method creates a summary object that includes: +#' \itemize{ +#' \item Number of particles +#' \item Number of timepoints +#' \item Number of items +#' \item Number of clusters +#' \item Log marginal likelihood +#' \item Final effective sample size (ESS) +#' \item Number of resampling events +#' \item Posterior mean of alpha for each cluster +#' \item Posterior standard deviation of alpha for each cluster +#' } +#' +#' @export +#' +#' @examples +#' # Fit a model with complete rankings +#' set.seed(123) +#' mod <- compute_sequentially( +#' complete_rankings, +#' hyperparameters = set_hyperparameters(n_items = 5), +#' smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1) +#' ) +#' +#' # Create summary +#' summary(mod) +#' +summary.BayesMallowsSMC2 <- function(object, ...) { + # Basic validation + if (!inherits(object, "BayesMallowsSMC2")) { + stop("object must be an object of class 'BayesMallowsSMC2'") + } + + required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood") + missing_fields <- setdiff(required_fields, names(object)) + if (length(missing_fields) > 0) { + stop("Object is missing required fields: ", paste(missing_fields, collapse = ", ")) + } + + # Extract dimensions + n_particles <- ncol(object$alpha) + n_timepoints <- length(object$ESS) + n_items <- dim(object$rho)[1] + n_clusters <- nrow(object$alpha) + + # Count resampling events + n_resampling_events <- sum(object$resampling) + + # Compute posterior mean and standard deviation of alpha + # alpha is a matrix where rows are clusters and columns are particles + alpha_mean <- rowMeans(object$alpha) + alpha_sd <- apply(object$alpha, 1, sd) + + # Create summary object + summary_obj <- list( + n_particles = n_particles, + n_timepoints = n_timepoints, + n_items = n_items, + n_clusters = n_clusters, + log_marginal_likelihood = object$log_marginal_likelihood, + final_ess = object$ESS[n_timepoints], + n_resampling_events = n_resampling_events, + alpha_mean = alpha_mean, + alpha_sd = alpha_sd + ) + + class(summary_obj) <- "summary.BayesMallowsSMC2" + summary_obj +} + +#' Print Method for summary.BayesMallowsSMC2 Objects +#' +#' Prints a summary of a BayesMallowsSMC2 model. +#' +#' @param x An object of class \code{summary.BayesMallowsSMC2}. +#' @param ... Additional arguments (currently unused). +#' +#' @return Invisibly returns the input object \code{x}. +#' +#' @export +#' +print.summary.BayesMallowsSMC2 <- function(x, ...) { + # Create header + cat("BayesMallowsSMC2 Model Summary\n") + cat(strrep("=", nchar("BayesMallowsSMC2 Model Summary")), "\n\n", sep = "") + + # Display basic information + cat(sprintf("%-25s %s\n", "Number of particles:", x$n_particles)) + cat(sprintf("%-25s %s\n", "Number of timepoints:", x$n_timepoints)) + cat(sprintf("%-25s %s\n", "Number of items:", x$n_items)) + cat(sprintf("%-25s %s\n\n", "Number of clusters:", x$n_clusters)) + + # Display model fit information + cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood)) + cat(sprintf("%-25s %.2f\n", "Final ESS:", x$final_ess)) + cat(sprintf("%-25s %d/%d\n\n", "Resampling events:", x$n_resampling_events, x$n_timepoints)) + + # Display posterior statistics for alpha + cat("Posterior Statistics for Alpha:\n") + cat(strrep("-", nchar("Posterior Statistics for Alpha:")), "\n", sep = "") + for (i in seq_along(x$alpha_mean)) { + cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n", + i, x$alpha_mean[i], x$alpha_sd[i])) + } + + invisible(x) +} diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R index 71f3bc2..1292afa 100644 --- a/tests/testthat/test-print.R +++ b/tests/testthat/test-print.R @@ -39,3 +39,76 @@ test_that("print method works with partial rankings", { output <- capture.output(print(mod)) expect_true(any(grepl("BayesMallowsSMC2 Model", output))) }) + +test_that("summary method works for BayesMallowsSMC2 objects", { + set.seed(123) + mod <- compute_sequentially( + complete_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1) + ) + + # Test that summary method runs without error + expect_error(summary(mod), NA) + + # Test that summary returns an object of the correct class + summ <- summary(mod) + expect_s3_class(summ, "summary.BayesMallowsSMC2") + + # Test that summary object contains expected fields + expect_true("n_particles" %in% names(summ)) + expect_true("n_timepoints" %in% names(summ)) + expect_true("n_items" %in% names(summ)) + expect_true("n_clusters" %in% names(summ)) + expect_true("log_marginal_likelihood" %in% names(summ)) + expect_true("final_ess" %in% names(summ)) + expect_true("n_resampling_events" %in% names(summ)) + expect_true("alpha_mean" %in% names(summ)) + expect_true("alpha_sd" %in% names(summ)) + + # Test that alpha statistics are numeric + expect_type(summ$alpha_mean, "double") + expect_type(summ$alpha_sd, "double") + + # Test that alpha statistics have correct length (equal to number of clusters) + expect_equal(length(summ$alpha_mean), summ$n_clusters) + expect_equal(length(summ$alpha_sd), summ$n_clusters) + + # Test that print method for summary works + expect_error(print(summ), NA) + + # Capture output and verify it contains expected content + output <- capture.output(print(summ)) + expect_true(any(grepl("BayesMallowsSMC2 Model Summary", output))) + expect_true(any(grepl("Number of particles:", output))) + expect_true(any(grepl("Number of timepoints:", output))) + expect_true(any(grepl("Number of items:", output))) + expect_true(any(grepl("Number of clusters:", output))) + expect_true(any(grepl("Log marginal likelihood:", output))) + expect_true(any(grepl("Final ESS:", output))) + expect_true(any(grepl("Resampling events:", output))) + expect_true(any(grepl("Posterior Statistics for Alpha:", output))) + expect_true(any(grepl("Mean =", output))) + expect_true(any(grepl("SD =", output))) +}) + +test_that("summary method works with partial rankings", { + set.seed(456) + mod <- compute_sequentially( + partial_rankings, + hyperparameters = set_hyperparameters(n_items = 5), + smc_options = set_smc_options(n_particles = 50, n_particle_filters = 1) + ) + + # Test that summary method runs without error + expect_error(summary(mod), NA) + + # Test summary object has correct structure + summ <- summary(mod) + expect_s3_class(summ, "summary.BayesMallowsSMC2") + + # Capture output and verify it contains expected content + output <- capture.output(print(summ)) + expect_true(any(grepl("BayesMallowsSMC2 Model Summary", output))) + expect_true(any(grepl("Posterior Statistics for Alpha:", output))) +})