diff --git a/src/particle.cpp b/src/particle.cpp index 016a2d7..fb561fe 100644 --- a/src/particle.cpp +++ b/src/particle.cpp @@ -111,39 +111,36 @@ void Particle::run_particle_filter( void Particle::assemble_backward_trajectory( unsigned int T, const std::unique_ptr &resampler) { + // Pre-allocate storage so each column/element is written in O(1) rather than + // prepending with insert_cols/insert_rows, which would be O(T^2) overall. + unsigned int n_items = parameters.rho.n_rows; + bool multi_cluster = (parameters.tau.size() > 1); ParticleFilter new_reference; new_reference.log_weight.resize(T + 1); + new_reference.latent_rankings.set_size(n_items, T + 1); + if (multi_cluster) { + new_reference.cluster_assignments.set_size(T + 1); + new_reference.cluster_probabilities.set_size(parameters.tau.size(), T + 1); + new_reference.index = uvec(T + 1, fill::zeros); + } for (int t = T; t >= 0; --t) { - arma::vec current_weights = stored_weights[t]; - - // Sample a single index b_t based on current_weights - arma::ivec counts = resampler->resample(1, current_weights); + // Draw b_t independently from the forward filtering weights W_t. + // Because cross-sectional users are conditionally independent given the + // static parameters, the backward weights reduce to W_t exactly. + arma::ivec counts = resampler->resample(1, stored_weights[t]); unsigned int b_t = arma::as_scalar(arma::find(counts > 0, 1)); - if (new_reference.latent_rankings.is_empty()) { - new_reference.latent_rankings = - particle_filters[b_t].latent_rankings.col(t); - if (parameters.tau.size() > 1) { - new_reference.cluster_assignments = - particle_filters[b_t].cluster_assignments.subvec(t, t); - new_reference.cluster_probabilities = - particle_filters[b_t].cluster_probabilities.cols(t, t); - new_reference.index = uvec(T + 1, fill::zeros); - } - } else { - new_reference.latent_rankings.insert_cols( - 0, particle_filters[b_t].latent_rankings.col(t)); - if (parameters.tau.size() > 1) { - new_reference.cluster_assignments.insert_rows( - 0, particle_filters[b_t].cluster_assignments.subvec(t, t)); - new_reference.cluster_probabilities.insert_cols( - 0, particle_filters[b_t].cluster_probabilities.cols(t, t)); - } - } - + new_reference.latent_rankings.col(t) = + particle_filters[b_t].latent_rankings.col(t); new_reference.log_weight(t) = particle_filters[b_t].log_weight(t); + if (multi_cluster) { + new_reference.cluster_assignments(t) = + particle_filters[b_t].cluster_assignments(t); + new_reference.cluster_probabilities.col(t) = + particle_filters[b_t].cluster_probabilities.col(t); + } } this->particle_filters[0] = new_reference; diff --git a/src/rejuvenate.cpp b/src/rejuvenate.cpp index 8bcca38..5073cbe 100644 --- a/src/rejuvenate.cpp +++ b/src/rejuvenate.cpp @@ -100,6 +100,7 @@ bool Particle::rejuvenate(unsigned int T, const Options &options, this->log_normalized_particle_filter_weights = proposal_particle.log_normalized_particle_filter_weights; this->particle_filters = proposal_particle.particle_filters; + this->stored_weights = proposal_particle.stored_weights; this->logz = proposal_particle.logz; accepted = true; } else { @@ -143,6 +144,10 @@ bool Particle::rejuvenate(unsigned int T, const Options &options, } else { sample_particle_filter(); } + } else if (options.backward_sampling) { + // For single-cluster models, apply backward sampling to update the + // reference trajectory from the current particle filter weights. + this->assemble_backward_trajectory(T, resampler); } return accepted; diff --git a/tests/testthat/test-compute_sequentially_complete.R b/tests/testthat/test-compute_sequentially_complete.R index efb715f..8ea5da7 100644 --- a/tests/testthat/test-compute_sequentially_complete.R +++ b/tests/testthat/test-compute_sequentially_complete.R @@ -40,7 +40,6 @@ test_that("compute_sequentially works with complete data", { resampler = "systematic") ) alpha_hat <- weighted.mean(x = as.numeric(mod$alpha), w = mod$importance_weights) - alpha_hat <- weighted.mean(x = as.numeric(mod$alpha), w = mod$importance_weights) expect_gt(alpha_hat, .02) expect_lt(alpha_hat, .05) })