Skip to content

Commit 476ce51

Browse files
Copilotosorensen
andcommitted
Extract weighted sampling logic into helper function and remove unused import
Co-authored-by: osorensen <21175639+osorensen@users.noreply.github.com>
1 parent 9c3a950 commit 476ce51

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

NAMESPACE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,4 @@ importFrom(ggplot2,ggplot)
1515
importFrom(ggplot2,theme_minimal)
1616
importFrom(ggplot2,xlab)
1717
importFrom(ggplot2,ylab)
18-
importFrom(stats,weighted.mean)
1918
useDynLib(BayesMallowsSMC2, .registration = TRUE)

R/plot.R

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#'
3434
#' @export
3535
#' @importFrom ggplot2 ggplot aes geom_histogram geom_col facet_wrap xlab ylab theme_minimal
36-
#' @importFrom stats weighted.mean
3736
#'
3837
#' @examples
3938
#' \dontrun{
@@ -82,6 +81,20 @@ plot.BayesMallowsSMC2 <- function(x, parameter = "alpha", items = NULL, ...) {
8281
}
8382

8483

84+
#' @keywords internal
85+
#' Helper function to create weighted samples from parameter values
86+
#' @param values Numeric vector of parameter values
87+
#' @param weights Numeric vector of importance weights
88+
#' @param n_samples Number of samples to draw (default 10000)
89+
#' @return Numeric vector of weighted samples
90+
create_weighted_samples <- function(values, weights, n_samples = 10000) {
91+
sample_probs <- weights / sum(weights)
92+
sampled_indices <- sample(seq_along(values), size = n_samples,
93+
replace = TRUE, prob = sample_probs)
94+
values[sampled_indices]
95+
}
96+
97+
8598
#' @keywords internal
8699
plot_alpha_smc <- function(x) {
87100
# Extract alpha values and weights
@@ -101,12 +114,7 @@ plot_alpha_smc <- function(x) {
101114
alpha_vals <- alpha_matrix[cluster, ]
102115

103116
# Create weighted samples by replicating values
104-
# Scale weights to get integer counts for sampling
105-
n_samples <- 10000
106-
sample_probs <- weights / sum(weights)
107-
sampled_indices <- sample(seq_len(n_particles), size = n_samples,
108-
replace = TRUE, prob = sample_probs)
109-
sampled_alpha <- alpha_vals[sampled_indices]
117+
sampled_alpha <- create_weighted_samples(alpha_vals, weights)
110118

111119
plot_data_list[[cluster]] <- data.frame(
112120
value = sampled_alpha,
@@ -150,11 +158,7 @@ plot_tau_smc <- function(x) {
150158
tau_vals <- tau_matrix[cluster, ]
151159

152160
# Create weighted samples
153-
n_samples <- 10000
154-
sample_probs <- weights / sum(weights)
155-
sampled_indices <- sample(seq_len(n_particles), size = n_samples,
156-
replace = TRUE, prob = sample_probs)
157-
sampled_tau <- tau_vals[sampled_indices]
161+
sampled_tau <- create_weighted_samples(tau_vals, weights)
158162

159163
plot_data_list[[cluster]] <- data.frame(
160164
value = sampled_tau,

0 commit comments

Comments
 (0)