@@ -70,7 +70,7 @@ trace_plot <- function(x, parameter = "alpha", ...) {
7070 traces <- x [[trace_field ]]
7171
7272 if (length(traces ) == 0 ) {
73- stop(" No trace data available . Please run compute_sequentially with trace = TRUE in set_smc_options()." )
73+ stop(" Trace data not found . Please run compute_sequentially with trace = TRUE in set_smc_options()." )
7474 }
7575
7676 # Check for importance weights trace
@@ -103,9 +103,38 @@ plot_trace_alpha_tau <- function(traces, log_weights_traces, parameter_name,
103103 n_timepoints <- length(traces )
104104
105105 # Get dimensions from first trace
106+ # Need to infer n_clusters and n_particles from the trace
106107 first_trace <- traces [[1 ]]
107- n_clusters <- nrow(first_trace )
108- n_particles <- ncol(first_trace )
108+
109+ # If trace is a vector, need to infer dimensions
110+ if (is.vector(first_trace )) {
111+ # Get n_particles from log_weights
112+ n_particles <- length(log_weights_traces [[1 ]])
113+
114+ # If trace is a vector, infer n_clusters from its length
115+ trace_length <- length(first_trace )
116+ if (trace_length %% n_particles == 0 ) {
117+ n_clusters <- trace_length %/% n_particles
118+ } else {
119+ stop(sprintf(" Trace length (%d) is not divisible by n_particles (%d). " ,
120+ trace_length , n_particles ),
121+ " This indicates inconsistent dimensions in the trace data." )
122+ }
123+
124+ # Convert all traces to matrices [n_clusters x n_particles]
125+ # The C++ code stores traces as: alpha is [n_clusters x n_particles] matrix per timepoint
126+ # When passed to R as vector, elements are in column-major order:
127+ # cluster1_particle1, cluster2_particle1, cluster1_particle2, cluster2_particle2, ...
128+ traces <- lapply(traces , function (t ) {
129+ matrix (t , nrow = n_clusters , ncol = n_particles , byrow = FALSE )
130+ })
131+ first_trace <- traces [[1 ]]
132+ } else if (is.matrix(first_trace )) {
133+ n_clusters <- nrow(first_trace )
134+ n_particles <- ncol(first_trace )
135+ } else {
136+ stop(" Trace elements must be vectors or matrices" )
137+ }
109138
110139 # Create data frame for plotting
111140 plot_data_list <- vector(" list" , n_timepoints * n_clusters )
0 commit comments