Skip to content

Commit 3b5ff55

Browse files
authored
Merge pull request #61 from osorensen/copilot/create-summary-method
[WIP] Add summary method for BayesMallowsSMC2 objects
2 parents 358ca28 + edcb198 commit 3b5ff55

File tree

3 files changed

+195
-0
lines changed

3 files changed

+195
-0
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
S3method(plot,BayesMallowsSMC2)
44
S3method(print,BayesMallowsSMC2)
5+
S3method(print,summary.BayesMallowsSMC2)
6+
S3method(summary,BayesMallowsSMC2)
57
export(compute_sequentially)
68
export(precompute_topological_sorts)
79
export(set_hyperparameters)

R/print.R

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,123 @@ print.BayesMallowsSMC2 <- function(x, ...) {
7676

7777
invisible(x)
7878
}
79+
80+
#' Summary Method for BayesMallowsSMC2 Objects
81+
#'
82+
#' Creates a summary of a BayesMallowsSMC2 object returned by
83+
#' [compute_sequentially()].
84+
#'
85+
#' @param object An object of class \code{BayesMallowsSMC2}.
86+
#' @param ... Additional arguments (currently unused).
87+
#'
88+
#' @return An object of class \code{summary.BayesMallowsSMC2}, which is a list
89+
#' containing summary information about the model.
90+
#'
91+
#' @details
92+
#' The summary method creates a summary object that includes:
93+
#' \itemize{
94+
#' \item Number of particles
95+
#' \item Number of timepoints
96+
#' \item Number of items
97+
#' \item Number of clusters
98+
#' \item Log marginal likelihood
99+
#' \item Final effective sample size (ESS)
100+
#' \item Number of resampling events
101+
#' \item Posterior mean of alpha for each cluster
102+
#' \item Posterior standard deviation of alpha for each cluster
103+
#' }
104+
#'
105+
#' @export
106+
#'
107+
#' @examples
108+
#' # Fit a model with complete rankings
109+
#' set.seed(123)
110+
#' mod <- compute_sequentially(
111+
#' complete_rankings,
112+
#' hyperparameters = set_hyperparameters(n_items = 5),
113+
#' smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1)
114+
#' )
115+
#'
116+
#' # Create summary
117+
#' summary(mod)
118+
#'
119+
summary.BayesMallowsSMC2 <- function(object, ...) {
120+
# Basic validation
121+
if (!inherits(object, "BayesMallowsSMC2")) {
122+
stop("object must be an object of class 'BayesMallowsSMC2'")
123+
}
124+
125+
required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood")
126+
missing_fields <- setdiff(required_fields, names(object))
127+
if (length(missing_fields) > 0) {
128+
stop("Object is missing required fields: ", paste(missing_fields, collapse = ", "))
129+
}
130+
131+
# Extract dimensions
132+
n_particles <- ncol(object$alpha)
133+
n_timepoints <- length(object$ESS)
134+
n_items <- dim(object$rho)[1]
135+
n_clusters <- nrow(object$alpha)
136+
137+
# Count resampling events
138+
n_resampling_events <- sum(object$resampling)
139+
140+
# Compute posterior mean and standard deviation of alpha
141+
# alpha is a matrix where rows are clusters and columns are particles
142+
alpha_mean <- rowMeans(object$alpha)
143+
alpha_sd <- apply(object$alpha, 1, sd)
144+
145+
# Create summary object
146+
summary_obj <- list(
147+
n_particles = n_particles,
148+
n_timepoints = n_timepoints,
149+
n_items = n_items,
150+
n_clusters = n_clusters,
151+
log_marginal_likelihood = object$log_marginal_likelihood,
152+
final_ess = object$ESS[n_timepoints],
153+
n_resampling_events = n_resampling_events,
154+
alpha_mean = alpha_mean,
155+
alpha_sd = alpha_sd
156+
)
157+
158+
class(summary_obj) <- "summary.BayesMallowsSMC2"
159+
summary_obj
160+
}
161+
162+
#' Print Method for summary.BayesMallowsSMC2 Objects
163+
#'
164+
#' Prints a summary of a BayesMallowsSMC2 model.
165+
#'
166+
#' @param x An object of class \code{summary.BayesMallowsSMC2}.
167+
#' @param ... Additional arguments (currently unused).
168+
#'
169+
#' @return Invisibly returns the input object \code{x}.
170+
#'
171+
#' @export
172+
#'
173+
print.summary.BayesMallowsSMC2 <- function(x, ...) {
174+
# Create header
175+
cat("BayesMallowsSMC2 Model Summary\n")
176+
cat(strrep("=", nchar("BayesMallowsSMC2 Model Summary")), "\n\n", sep = "")
177+
178+
# Display basic information
179+
cat(sprintf("%-25s %s\n", "Number of particles:", x$n_particles))
180+
cat(sprintf("%-25s %s\n", "Number of timepoints:", x$n_timepoints))
181+
cat(sprintf("%-25s %s\n", "Number of items:", x$n_items))
182+
cat(sprintf("%-25s %s\n\n", "Number of clusters:", x$n_clusters))
183+
184+
# Display model fit information
185+
cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood))
186+
cat(sprintf("%-25s %.2f\n", "Final ESS:", x$final_ess))
187+
cat(sprintf("%-25s %d/%d\n\n", "Resampling events:", x$n_resampling_events, x$n_timepoints))
188+
189+
# Display posterior statistics for alpha
190+
cat("Posterior Statistics for Alpha:\n")
191+
cat(strrep("-", nchar("Posterior Statistics for Alpha:")), "\n", sep = "")
192+
for (i in seq_along(x$alpha_mean)) {
193+
cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n",
194+
i, x$alpha_mean[i], x$alpha_sd[i]))
195+
}
196+
197+
invisible(x)
198+
}

tests/testthat/test-print.R

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,76 @@ test_that("print method works with partial rankings", {
3939
output <- capture.output(print(mod))
4040
expect_true(any(grepl("BayesMallowsSMC2 Model", output)))
4141
})
42+
43+
test_that("summary method works for BayesMallowsSMC2 objects", {
44+
set.seed(123)
45+
mod <- compute_sequentially(
46+
complete_rankings,
47+
hyperparameters = set_hyperparameters(n_items = 5),
48+
smc_options = set_smc_options(n_particles = 100, n_particle_filters = 1)
49+
)
50+
51+
# Test that summary method runs without error
52+
expect_error(summary(mod), NA)
53+
54+
# Test that summary returns an object of the correct class
55+
summ <- summary(mod)
56+
expect_s3_class(summ, "summary.BayesMallowsSMC2")
57+
58+
# Test that summary object contains expected fields
59+
expect_true("n_particles" %in% names(summ))
60+
expect_true("n_timepoints" %in% names(summ))
61+
expect_true("n_items" %in% names(summ))
62+
expect_true("n_clusters" %in% names(summ))
63+
expect_true("log_marginal_likelihood" %in% names(summ))
64+
expect_true("final_ess" %in% names(summ))
65+
expect_true("n_resampling_events" %in% names(summ))
66+
expect_true("alpha_mean" %in% names(summ))
67+
expect_true("alpha_sd" %in% names(summ))
68+
69+
# Test that alpha statistics are numeric
70+
expect_type(summ$alpha_mean, "double")
71+
expect_type(summ$alpha_sd, "double")
72+
73+
# Test that alpha statistics have correct length (equal to number of clusters)
74+
expect_equal(length(summ$alpha_mean), summ$n_clusters)
75+
expect_equal(length(summ$alpha_sd), summ$n_clusters)
76+
77+
# Test that print method for summary works
78+
expect_error(print(summ), NA)
79+
80+
# Capture output and verify it contains expected content
81+
output <- capture.output(print(summ))
82+
expect_true(any(grepl("BayesMallowsSMC2 Model Summary", output)))
83+
expect_true(any(grepl("Number of particles:", output)))
84+
expect_true(any(grepl("Number of timepoints:", output)))
85+
expect_true(any(grepl("Number of items:", output)))
86+
expect_true(any(grepl("Number of clusters:", output)))
87+
expect_true(any(grepl("Log marginal likelihood:", output)))
88+
expect_true(any(grepl("Final ESS:", output)))
89+
expect_true(any(grepl("Resampling events:", output)))
90+
expect_true(any(grepl("Posterior Statistics for Alpha:", output)))
91+
expect_true(any(grepl("Mean =", output)))
92+
expect_true(any(grepl("SD =", output)))
93+
})
94+
95+
test_that("summary method works with partial rankings", {
96+
set.seed(456)
97+
mod <- compute_sequentially(
98+
partial_rankings,
99+
hyperparameters = set_hyperparameters(n_items = 5),
100+
smc_options = set_smc_options(n_particles = 50, n_particle_filters = 1)
101+
)
102+
103+
# Test that summary method runs without error
104+
expect_error(summary(mod), NA)
105+
106+
# Test summary object has correct structure
107+
summ <- summary(mod)
108+
expect_s3_class(summ, "summary.BayesMallowsSMC2")
109+
110+
# Capture output and verify it contains expected content
111+
output <- capture.output(print(summ))
112+
expect_true(any(grepl("BayesMallowsSMC2 Model Summary", output)))
113+
expect_true(any(grepl("Posterior Statistics for Alpha:", output)))
114+
})

0 commit comments

Comments
 (0)