Skip to content

Commit 7adc9e3

Browse files
committed
getting rid of NOTEs
1 parent fcee400 commit 7adc9e3

File tree

6 files changed

+116
-47
lines changed

6 files changed

+116
-47
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ export(set_hyperparameters)
1010
export(set_smc_options)
1111
export(trace_plot)
1212
importFrom(Rcpp,sourceCpp)
13+
importFrom(Rdpack,reprompt)
1314
useDynLib(BayesMallowsSMC2, .registration = TRUE)

R/BayesMallowsSMC2-package.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
## usethis namespace: start
2+
#' @importFrom Rdpack reprompt
23
#' @importFrom Rcpp sourceCpp
34
## usethis namespace: end
45
NULL

R/print.R

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,37 @@ print.BayesMallowsSMC2 <- function(x, ...) {
4343
if (!inherits(x, "BayesMallowsSMC2")) {
4444
stop("x must be an object of class 'BayesMallowsSMC2'")
4545
}
46-
46+
4747
required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood")
4848
missing_fields <- setdiff(required_fields, names(x))
4949
if (length(missing_fields) > 0) {
5050
stop("Object is missing required fields: ", paste(missing_fields, collapse = ", "))
5151
}
52-
52+
5353
# Extract dimensions
5454
n_particles <- ncol(x$alpha)
5555
n_timepoints <- length(x$ESS)
5656
n_items <- dim(x$rho)[1]
5757
n_clusters <- nrow(x$alpha)
58-
58+
5959
# Count resampling events
6060
n_resampling_events <- sum(x$resampling)
61-
61+
6262
# Create header
6363
cat("BayesMallowsSMC2 Model\n")
6464
cat(strrep("=", nchar("BayesMallowsSMC2 Model")), "\n\n", sep = "")
65-
65+
6666
# Display basic information
6767
cat(sprintf("%-25s %s\n", "Number of particles:", n_particles))
6868
cat(sprintf("%-25s %s\n", "Number of timepoints:", n_timepoints))
6969
cat(sprintf("%-25s %s\n", "Number of items:", n_items))
7070
cat(sprintf("%-25s %s\n\n", "Number of clusters:", n_clusters))
71-
71+
7272
# Display model fit information
7373
cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood))
7474
cat(sprintf("%-25s %.2f\n", "Final ESS:", x$ESS[n_timepoints]))
7575
cat(sprintf("%-25s %d/%d\n", "Resampling events:", n_resampling_events, n_timepoints))
76-
76+
7777
invisible(x)
7878
}
7979

@@ -121,27 +121,27 @@ summary.BayesMallowsSMC2 <- function(object, ...) {
121121
if (!inherits(object, "BayesMallowsSMC2")) {
122122
stop("object must be an object of class 'BayesMallowsSMC2'")
123123
}
124-
124+
125125
required_fields <- c("alpha", "rho", "ESS", "resampling", "log_marginal_likelihood")
126126
missing_fields <- setdiff(required_fields, names(object))
127127
if (length(missing_fields) > 0) {
128128
stop("Object is missing required fields: ", paste(missing_fields, collapse = ", "))
129129
}
130-
130+
131131
# Extract dimensions
132132
n_particles <- ncol(object$alpha)
133133
n_timepoints <- length(object$ESS)
134134
n_items <- dim(object$rho)[1]
135135
n_clusters <- nrow(object$alpha)
136-
136+
137137
# Count resampling events
138138
n_resampling_events <- sum(object$resampling)
139-
139+
140140
# Compute posterior mean and standard deviation of alpha
141141
# alpha is a matrix where rows are clusters and columns are particles
142142
alpha_mean <- rowMeans(object$alpha)
143-
alpha_sd <- apply(object$alpha, 1, sd)
144-
143+
alpha_sd <- apply(object$alpha, 1, stats::sd)
144+
145145
# Create summary object
146146
summary_obj <- list(
147147
n_particles = n_particles,
@@ -154,7 +154,7 @@ summary.BayesMallowsSMC2 <- function(object, ...) {
154154
alpha_mean = alpha_mean,
155155
alpha_sd = alpha_sd
156156
)
157-
157+
158158
class(summary_obj) <- "summary.BayesMallowsSMC2"
159159
summary_obj
160160
}
@@ -174,25 +174,25 @@ print.summary.BayesMallowsSMC2 <- function(x, ...) {
174174
# Create header
175175
cat("BayesMallowsSMC2 Model Summary\n")
176176
cat(strrep("=", nchar("BayesMallowsSMC2 Model Summary")), "\n\n", sep = "")
177-
177+
178178
# Display basic information
179179
cat(sprintf("%-25s %s\n", "Number of particles:", x$n_particles))
180180
cat(sprintf("%-25s %s\n", "Number of timepoints:", x$n_timepoints))
181181
cat(sprintf("%-25s %s\n", "Number of items:", x$n_items))
182182
cat(sprintf("%-25s %s\n\n", "Number of clusters:", x$n_clusters))
183-
183+
184184
# Display model fit information
185185
cat(sprintf("%-25s %.2f\n", "Log marginal likelihood:", x$log_marginal_likelihood))
186186
cat(sprintf("%-25s %.2f\n", "Final ESS:", x$final_ess))
187187
cat(sprintf("%-25s %d/%d\n\n", "Resampling events:", x$n_resampling_events, x$n_timepoints))
188-
188+
189189
# Display posterior statistics for alpha
190190
cat("Posterior Statistics for Alpha:\n")
191191
cat(strrep("-", nchar("Posterior Statistics for Alpha:")), "\n", sep = "")
192192
for (i in seq_along(x$alpha_mean)) {
193-
cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n",
193+
cat(sprintf("Cluster %d: Mean = %.4f, SD = %.4f\n",
194194
i, x$alpha_mean[i], x$alpha_sd[i]))
195195
}
196-
196+
197197
invisible(x)
198198
}

R/trace_plot.R

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -55,31 +55,31 @@
5555
trace_plot <- function(x, parameter = "alpha", ...) {
5656
# Validate parameter
5757
parameter <- match.arg(parameter, c("alpha", "tau"))
58-
58+
5959
# Basic validation
6060
if (!inherits(x, "BayesMallowsSMC2")) {
6161
stop("x must be an object of class 'BayesMallowsSMC2'")
6262
}
63-
63+
6464
# Check if trace was enabled
6565
trace_field <- paste0(parameter, "_traces")
6666
if (!trace_field %in% names(x)) {
6767
stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().")
6868
}
69-
69+
7070
traces <- x[[trace_field]]
71-
71+
7272
if (length(traces) == 0) {
7373
stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().")
7474
}
75-
75+
7676
# Check for importance weights trace
7777
if (!"log_importance_weights_traces" %in% names(x)) {
7878
stop("Importance weights trace not found. This should not happen if trace = TRUE was used.")
7979
}
80-
80+
8181
log_weights_traces <- x$log_importance_weights_traces
82-
82+
8383
if (parameter == "alpha") {
8484
plot_trace_alpha_tau(traces, log_weights_traces, parameter_name = "alpha",
8585
parameter_label = expression(alpha))
@@ -101,16 +101,16 @@ utils::globalVariables(c("timepoint", "mean", "lower", "upper", "cluster"))
101101
plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
102102
parameter_label) {
103103
n_timepoints <- length(traces)
104-
104+
105105
# Get dimensions from first trace
106106
# Need to infer n_clusters and n_particles from the trace
107107
first_trace <- traces[[1]]
108-
108+
109109
# If trace is a vector, need to infer dimensions
110110
if (is.vector(first_trace)) {
111111
# Get n_particles from log_weights
112112
n_particles <- length(log_weights_traces[[1]])
113-
113+
114114
# If trace is a vector, infer n_clusters from its length
115115
trace_length <- length(first_trace)
116116
if (trace_length %% n_particles == 0) {
@@ -120,10 +120,10 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
120120
trace_length, n_particles),
121121
"This indicates inconsistent dimensions in the trace data.")
122122
}
123-
123+
124124
# Convert all traces to matrices [n_clusters x n_particles]
125125
# The C++ code stores traces as: alpha is [n_clusters x n_particles] matrix per timepoint
126-
# When passed to R as vector, elements are in column-major order:
126+
# When passed to R as vector, elements are in column-major order:
127127
# cluster1_particle1, cluster2_particle1, cluster1_particle2, cluster2_particle2, ...
128128
traces <- lapply(traces, function(t) {
129129
matrix(t, nrow = n_clusters, ncol = n_particles, byrow = FALSE)
@@ -135,27 +135,27 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
135135
} else {
136136
stop("Trace elements must be vectors or matrices")
137137
}
138-
138+
139139
# Create data frame for plotting
140140
plot_data_list <- vector("list", n_timepoints * n_clusters)
141141
idx <- 1
142-
142+
143143
for (t in seq_len(n_timepoints)) {
144144
param_matrix <- traces[[t]]
145145
log_weights <- log_weights_traces[[t]]
146-
146+
147147
# Normalize weights
148148
weights <- exp(log_weights - max(log_weights))
149149
weights <- weights / sum(weights)
150-
150+
151151
for (cluster in seq_len(n_clusters)) {
152152
param_values <- param_matrix[cluster, ]
153-
153+
154154
# Compute weighted statistics
155-
weighted_mean <- weighted.mean(param_values, weights)
156-
weighted_quantiles <- weighted_quantile(param_values, weights,
155+
weighted_mean <- stats::weighted.mean(param_values, weights)
156+
weighted_quantiles <- weighted_quantile(param_values, weights,
157157
probs = c(0.025, 0.975))
158-
158+
159159
plot_data_list[[idx]] <- data.frame(
160160
timepoint = t,
161161
mean = weighted_mean,
@@ -166,12 +166,12 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
166166
idx <- idx + 1
167167
}
168168
}
169-
169+
170170
plot_data <- do.call(rbind, plot_data_list)
171-
171+
172172
# Create line plot with credible interval
173173
p <- ggplot2::ggplot(plot_data, ggplot2::aes(x = timepoint, y = mean)) +
174-
ggplot2::geom_ribbon(ggplot2::aes(ymin = lower, ymax = upper),
174+
ggplot2::geom_ribbon(ggplot2::aes(ymin = lower, ymax = upper),
175175
alpha = 0.3, fill = "steelblue") +
176176
ggplot2::geom_line(color = "darkblue", linewidth = 1) +
177177
ggplot2::xlab("Timepoint") +
@@ -180,12 +180,12 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
180180
ggplot2::theme(
181181
panel.grid.minor = ggplot2::element_blank()
182182
)
183-
183+
184184
# Add faceting if multiple clusters
185185
if (n_clusters > 1) {
186186
p <- p + ggplot2::facet_wrap(~ cluster, scales = "free_y")
187187
}
188-
188+
189189
p
190190
}
191191

@@ -200,10 +200,10 @@ weighted_quantile <- function(x, weights, probs) {
200200
ord <- order(x)
201201
x_sorted <- x[ord]
202202
weights_sorted <- weights[ord]
203-
203+
204204
# Compute cumulative weights
205205
cum_weights <- cumsum(weights_sorted) / sum(weights_sorted)
206-
206+
207207
# Find quantiles
208208
quantiles <- numeric(length(probs))
209209
for (i in seq_along(probs)) {
@@ -214,6 +214,6 @@ weighted_quantile <- function(x, weights, probs) {
214214
}
215215
quantiles[i] <- x_sorted[idx]
216216
}
217-
217+
218218
quantiles
219219
}

man/print.summary.BayesMallowsSMC2.Rd

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/summary.BayesMallowsSMC2.Rd

Lines changed: 48 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)