Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method(plot,BayesMallowsSMC2)
S3method(print,BayesMallowsSMC2)
S3method(print,summary.BayesMallowsSMC2)
S3method(summary,BayesMallowsSMC2)
export(compute_sequentially)
export(precompute_topological_sorts)
export(set_hyperparameters)
Expand Down
120 changes: 120 additions & 0 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,123 @@ print.BayesMallowsSMC2 <- function(x, ...) {

invisible(x)
}

#' Summary Method for BayesMallowsSMC2 Objects
#'
#' Creates a summary of a BayesMallowsSMC2 object returned by
#' [compute_sequentially()].
#'
#' @param object An object of class \code{BayesMallowsSMC2}.
#' @param ... Additional arguments (currently unused).
#'
#' @return An object of class \code{summary.BayesMallowsSMC2}, which is a list
#' containing summary information about the model.
#'
#' @details
#' The summary method creates a summary object that includes:
#' \itemize{
#' \item Number of particles
#' \item Number of timepoints
#' \item Number of items
#' \item Number of clusters
#' \item Log marginal likelihood
#' \item Final effective sample size (ESS)
#' \item Number of resampling events
#' \item Posterior mean of alpha for each cluster
#' \item Posterior standard deviation of alpha for each cluster
#' }
#'
#' @export
#'
#' @examples
#' # Fit a model with complete rankings
#' set.seed(123)
#' mod <- compute_sequentially(
#' complete_rankings,
#' hyperparameters = set_hyperparameters(n_items = 5),
#' smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1)
#' )
#'
#' # Create summary
#' summary(mod)
#'
summary.BayesMallowsSMC2 <- function(object, ...) {
# Basic validation
if (!inherits(object, "BayesMallowsSMC2")) {
stop("object must be an object of class 'BayesMallowsSMC2'")
}

required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood")
missing_fields <- setdiff(required_fields, names(object))
if (length(missing_fields) > 0) {
stop("Object is missing required fields: ", paste(missing_fields, collapse = ", "))
}

# Extract dimensions
n_particles <- ncol(object$alpha)
n_timepoints <- length(object$ESS)
n_items <- dim(object$rho)[1]
n_clusters <- nrow(object$alpha)

# Count resampling events
n_resampling_events <- sum(object$resampling)

# Compute posterior mean and standard deviation of alpha
# alpha is a matrix where rows are clusters and columns are particles
alpha_mean <- rowMeans(object$alpha)
alpha_sd <- apply(object$alpha, 1, sd)

# Create summary object
summary_obj <- list(
n_particles = n_particles,
n_timepoints = n_timepoints,
n_items = n_items,
n_clusters = n_clusters,
log_marginal_likelihood = object$log_marginal_likelihood,
final_ess = object$ESS[n_timepoints],
n_resampling_events = n_resampling_events,
alpha_mean = alpha_mean,
alpha_sd = alpha_sd
)

class(summary_obj) <- "summary.BayesMallowsSMC2"
summary_obj
}

#' Print Method for summary.BayesMallowsSMC2 Objects
#'
#' Prints a summary of a BayesMallowsSMC2 model.
#'
#' @param x An object of class \code{summary.BayesMallowsSMC2}.
#' @param ... Additional arguments (currently unused).
#'
#' @return Invisibly returns the input object \code{x}.
#'
#' @export
#'
print.summary.BayesMallowsSMC2 <- function(x, ...) {
# Create header
cat("BayesMallowsSMC2 Model Summary\n")
cat(strrep("=", nchar("BayesMallowsSMC2 Model Summary")), "\n\n", sep = "")

# Display basic information
cat(sprintf("%-25s %s\n", "Number of particles:", x$n_particles))
cat(sprintf("%-25s %s\n", "Number of timepoints:", x$n_timepoints))
cat(sprintf("%-25s %s\n", "Number of items:", x$n_items))
cat(sprintf("%-25s %s\n\n", "Number of clusters:", x$n_clusters))

# Display model fit information
cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood))
cat(sprintf("%-25s %.2f\n", "Final ESS:", x$final_ess))
cat(sprintf("%-25s %d/%d\n\n", "Resampling events:", x$n_resampling_events, x$n_timepoints))

# Display posterior statistics for alpha
cat("Posterior Statistics for Alpha:\n")
cat(strrep("-", nchar("Posterior Statistics for Alpha:")), "\n", sep = "")
for (i in seq_along(x$alpha_mean)) {
cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n",
i, x$alpha_mean[i], x$alpha_sd[i]))
}

invisible(x)
}
73 changes: 73 additions & 0 deletions tests/testthat/test-print.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,76 @@ test_that("print method works with partial rankings", {
output <- capture.output(print(mod))
expect_true(any(grepl("BayesMallowsSMC2 Model", output)))
})

test_that("summary method works for BayesMallowsSMC2 objects", {
set.seed(123)
mod <- compute_sequentially(
complete_rankings,
hyperparameters = set_hyperparameters(n_items = 5),
smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1)
)

# Test that summary method runs without error
expect_error(summary(mod), NA)

# Test that summary returns an object of the correct class
summ <- summary(mod)
expect_s3_class(summ, "summary.BayesMallowsSMC2")

# Test that summary object contains expected fields
expect_true("n_particles" %in% names(summ))
expect_true("n_timepoints" %in% names(summ))
expect_true("n_items" %in% names(summ))
expect_true("n_clusters" %in% names(summ))
expect_true("log_marginal_likelihood" %in% names(summ))
expect_true("final_ess" %in% names(summ))
expect_true("n_resampling_events" %in% names(summ))
expect_true("alpha_mean" %in% names(summ))
expect_true("alpha_sd" %in% names(summ))

# Test that alpha statistics are numeric
expect_type(summ$alpha_mean, "double")
expect_type(summ$alpha_sd, "double")

# Test that alpha statistics have correct length (equal to number of clusters)
expect_equal(length(summ$alpha_mean), summ$n_clusters)
expect_equal(length(summ$alpha_sd), summ$n_clusters)

# Test that print method for summary works
expect_error(print(summ), NA)

# Capture output and verify it contains expected content
output <- capture.output(print(summ))
expect_true(any(grepl("BayesMallowsSMC2 Model Summary", output)))
expect_true(any(grepl("Number of particles:", output)))
expect_true(any(grepl("Number of timepoints:", output)))
expect_true(any(grepl("Number of items:", output)))
expect_true(any(grepl("Number of clusters:", output)))
expect_true(any(grepl("Log marginal likelihood:", output)))
expect_true(any(grepl("Final ESS:", output)))
expect_true(any(grepl("Resampling events:", output)))
expect_true(any(grepl("Posterior Statistics for Alpha:", output)))
expect_true(any(grepl("Mean =", output)))
expect_true(any(grepl("SD =", output)))
})

test_that("summary method works with partial rankings", {
set.seed(456)
mod <- compute_sequentially(
partial_rankings,
hyperparameters = set_hyperparameters(n_items = 5),
smc_options = set_smc_options(n_particles = 50, n_particle_filters = 1)
)

# Test that summary method runs without error
expect_error(summary(mod), NA)

# Test summary object has correct structure
summ <- summary(mod)
expect_s3_class(summ, "summary.BayesMallowsSMC2")

# Capture output and verify it contains expected content
output <- capture.output(print(summ))
expect_true(any(grepl("BayesMallowsSMC2 Model Summary", output)))
expect_true(any(grepl("Posterior Statistics for Alpha:", output)))
})