Skip to content

Commit 711e29d

Browse files
Copilotosorensen
andcommitted
Update NAMESPACE and documentation, fix trace_plot for single-cluster models
Co-authored-by: osorensen <21175639+osorensen@users.noreply.github.com>
1 parent c2c2906 commit 711e29d

File tree

3 files changed

+97
-3
lines changed

3 files changed

+97
-3
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ export(compute_sequentially)
66
export(precompute_topological_sorts)
77
export(set_hyperparameters)
88
export(set_smc_options)
9+
export(trace_plot)
910
importFrom(Rcpp,sourceCpp)
1011
useDynLib(BayesMallowsSMC2, .registration = TRUE)

R/trace_plot.R

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ trace_plot <- function(x, parameter = "alpha", ...) {
7070
traces <- x[[trace_field]]
7171

7272
if (length(traces) == 0) {
73-
stop("No trace data available. Please run compute_sequentially with trace = TRUE in set_smc_options().")
73+
stop("Trace data not found. Please run compute_sequentially with trace = TRUE in set_smc_options().")
7474
}
7575

7676
# Check for importance weights trace
@@ -103,9 +103,38 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
103103
n_timepoints <- length(traces)
104104

105105
# Get dimensions from first trace
106+
# Need to infer n_clusters and n_particles from the trace
106107
first_trace <- traces[[1]]
107-
n_clusters <- nrow(first_trace)
108-
n_particles <- ncol(first_trace)
108+
109+
# If trace is a vector, need to infer dimensions
110+
if (is.vector(first_trace)) {
111+
# Get n_particles from log_weights
112+
n_particles <- length(log_weights_traces[[1]])
113+
114+
# If trace is a vector, infer n_clusters from its length
115+
trace_length <- length(first_trace)
116+
if (trace_length %% n_particles == 0) {
117+
n_clusters <- trace_length %/% n_particles
118+
} else {
119+
stop(sprintf("Trace length (%d) is not divisible by n_particles (%d). ",
120+
trace_length, n_particles),
121+
"This indicates inconsistent dimensions in the trace data.")
122+
}
123+
124+
# Convert all traces to matrices [n_clusters x n_particles]
125+
# The C++ code stores traces as: alpha is [n_clusters x n_particles] matrix per timepoint
126+
# When passed to R as vector, elements are in column-major order:
127+
# cluster1_particle1, cluster2_particle1, cluster1_particle2, cluster2_particle2, ...
128+
traces <- lapply(traces, function(t) {
129+
matrix(t, nrow = n_clusters, ncol = n_particles, byrow = FALSE)
130+
})
131+
first_trace <- traces[[1]]
132+
} else if (is.matrix(first_trace)) {
133+
n_clusters <- nrow(first_trace)
134+
n_particles <- ncol(first_trace)
135+
} else {
136+
stop("Trace elements must be vectors or matrices")
137+
}
109138

110139
# Create data frame for plotting
111140
plot_data_list <- vector("list", n_timepoints * n_clusters)

man/trace_plot.Rd

Lines changed: 64 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)