5555trace_plot <- function (x , parameter = " alpha" , ... ) {
5656 # Validate parameter
5757 parameter <- match.arg(parameter , c(" alpha" , " tau" ))
58-
58+
5959 # Basic validation
6060 if (! inherits(x , " BayesMallowsSMC2" )) {
6161 stop(" x must be an object of class 'BayesMallowsSMC2'" )
6262 }
63-
63+
6464 # Check if trace was enabled
6565 trace_field <- paste0(parameter , " _traces" )
6666 if (! trace_field %in% names(x )) {
6767 stop(" Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options()." )
6868 }
69-
69+
7070 traces <- x [[trace_field ]]
71-
71+
7272 if (length(traces ) == 0 ) {
7373 stop(" Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options()." )
7474 }
75-
75+
7676 # Check for importance weights trace
7777 if (! " log_importance_weights_traces" %in% names(x )) {
7878 stop(" Importance weights trace not found. This should not happen if trace = TRUE was used." )
7979 }
80-
80+
8181 log_weights_traces <- x $ log_importance_weights_traces
82-
82+
8383 if (parameter == " alpha" ) {
8484 plot_trace_alpha_tau(traces , log_weights_traces , parameter_name = " alpha" ,
8585 parameter_label = expression(alpha ))
@@ -101,16 +101,16 @@ utils::globalVariables(c("timepoint", "mean", "lower", "upper", "cluster"))
101101plot_trace_alpha_tau <- function (traces , log_weights_traces , parameter_name ,
102102 parameter_label ) {
103103 n_timepoints <- length(traces )
104-
104+
105105 # Get dimensions from first trace
106106 # Need to infer n_clusters and n_particles from the trace
107107 first_trace <- traces [[1 ]]
108-
108+
109109 # If trace is a vector, need to infer dimensions
110110 if (is.vector(first_trace )) {
111111 # Get n_particles from log_weights
112112 n_particles <- length(log_weights_traces [[1 ]])
113-
113+
114114 # If trace is a vector, infer n_clusters from its length
115115 trace_length <- length(first_trace )
116116 if (trace_length %% n_particles == 0 ) {
@@ -120,10 +120,10 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
120120 trace_length , n_particles ),
121121 " This indicates inconsistent dimensions in the trace data." )
122122 }
123-
123+
124124 # Convert all traces to matrices [n_clusters x n_particles]
125125 # The C++ code stores traces as: alpha is [n_clusters x n_particles] matrix per timepoint
126- # When passed to R as vector, elements are in column-major order:
126+ # When passed to R as vector, elements are in column-major order:
127127 # cluster1_particle1, cluster2_particle1, cluster1_particle2, cluster2_particle2, ...
128128 traces <- lapply(traces , function (t ) {
129129 matrix (t , nrow = n_clusters , ncol = n_particles , byrow = FALSE )
@@ -135,27 +135,27 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
135135 } else {
136136 stop(" Trace elements must be vectors or matrices" )
137137 }
138-
138+
139139 # Create data frame for plotting
140140 plot_data_list <- vector(" list" , n_timepoints * n_clusters )
141141 idx <- 1
142-
142+
143143 for (t in seq_len(n_timepoints )) {
144144 param_matrix <- traces [[t ]]
145145 log_weights <- log_weights_traces [[t ]]
146-
146+
147147 # Normalize weights
148148 weights <- exp(log_weights - max(log_weights ))
149149 weights <- weights / sum(weights )
150-
150+
151151 for (cluster in seq_len(n_clusters )) {
152152 param_values <- param_matrix [cluster , ]
153-
153+
154154 # Compute weighted statistics
155- weighted_mean <- weighted.mean(param_values , weights )
156- weighted_quantiles <- weighted_quantile(param_values , weights ,
155+ weighted_mean <- stats :: weighted.mean(param_values , weights )
156+ weighted_quantiles <- weighted_quantile(param_values , weights ,
157157 probs = c(0.025 , 0.975 ))
158-
158+
159159 plot_data_list [[idx ]] <- data.frame (
160160 timepoint = t ,
161161 mean = weighted_mean ,
@@ -166,12 +166,12 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
166166 idx <- idx + 1
167167 }
168168 }
169-
169+
170170 plot_data <- do.call(rbind , plot_data_list )
171-
171+
172172 # Create line plot with credible interval
173173 p <- ggplot2 :: ggplot(plot_data , ggplot2 :: aes(x = timepoint , y = mean )) +
174- ggplot2 :: geom_ribbon(ggplot2 :: aes(ymin = lower , ymax = upper ),
174+ ggplot2 :: geom_ribbon(ggplot2 :: aes(ymin = lower , ymax = upper ),
175175 alpha = 0.3 , fill = " steelblue" ) +
176176 ggplot2 :: geom_line(color = " darkblue" , linewidth = 1 ) +
177177 ggplot2 :: xlab(" Timepoint" ) +
@@ -180,12 +180,12 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
180180 ggplot2 :: theme(
181181 panel.grid.minor = ggplot2 :: element_blank()
182182 )
183-
183+
184184 # Add faceting if multiple clusters
185185 if (n_clusters > 1 ) {
186186 p <- p + ggplot2 :: facet_wrap(~ cluster , scales = " free_y" )
187187 }
188-
188+
189189 p
190190}
191191
@@ -200,10 +200,10 @@ weighted_quantile <- function(x, weights, probs) {
200200 ord <- order(x )
201201 x_sorted <- x [ord ]
202202 weights_sorted <- weights [ord ]
203-
203+
204204 # Compute cumulative weights
205205 cum_weights <- cumsum(weights_sorted ) / sum(weights_sorted )
206-
206+
207207 # Find quantiles
208208 quantiles <- numeric (length(probs ))
209209 for (i in seq_along(probs )) {
@@ -214,6 +214,6 @@ weighted_quantile <- function(x, weights, probs) {
214214 }
215215 quantiles [i ] <- x_sorted [idx ]
216216 }
217-
217+
218218 quantiles
219219}
0 commit comments