Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions src/particle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,39 +111,36 @@ void Particle::run_particle_filter(

void Particle::assemble_backward_trajectory(
unsigned int T, const std::unique_ptr<Resampler> &resampler) {
// Pre-allocate storage so each column/element is written in O(1) rather than
// prepending with insert_cols/insert_rows, which would be O(T^2) overall.
unsigned int n_items = parameters.rho.n_rows;
bool multi_cluster = (parameters.tau.size() > 1);

ParticleFilter new_reference;
new_reference.log_weight.resize(T + 1);
new_reference.latent_rankings.set_size(n_items, T + 1);
if (multi_cluster) {
new_reference.cluster_assignments.set_size(T + 1);
new_reference.cluster_probabilities.set_size(parameters.tau.size(), T + 1);
new_reference.index = uvec(T + 1, fill::zeros);
}

for (int t = T; t >= 0; --t) {
arma::vec current_weights = stored_weights[t];

// Sample a single index b_t based on current_weights
arma::ivec counts = resampler->resample(1, current_weights);
// Draw b_t independently from the forward filtering weights W_t.
// Because cross-sectional users are conditionally independent given the
// static parameters, the backward weights reduce to W_t exactly.
arma::ivec counts = resampler->resample(1, stored_weights[t]);
unsigned int b_t = arma::as_scalar(arma::find(counts > 0, 1));

if (new_reference.latent_rankings.is_empty()) {
new_reference.latent_rankings =
particle_filters[b_t].latent_rankings.col(t);
if (parameters.tau.size() > 1) {
new_reference.cluster_assignments =
particle_filters[b_t].cluster_assignments.subvec(t, t);
new_reference.cluster_probabilities =
particle_filters[b_t].cluster_probabilities.cols(t, t);
new_reference.index = uvec(T + 1, fill::zeros);
}
} else {
new_reference.latent_rankings.insert_cols(
0, particle_filters[b_t].latent_rankings.col(t));
if (parameters.tau.size() > 1) {
new_reference.cluster_assignments.insert_rows(
0, particle_filters[b_t].cluster_assignments.subvec(t, t));
new_reference.cluster_probabilities.insert_cols(
0, particle_filters[b_t].cluster_probabilities.cols(t, t));
}
}

new_reference.latent_rankings.col(t) =
particle_filters[b_t].latent_rankings.col(t);
new_reference.log_weight(t) = particle_filters[b_t].log_weight(t);
if (multi_cluster) {
new_reference.cluster_assignments(t) =
particle_filters[b_t].cluster_assignments(t);
new_reference.cluster_probabilities.col(t) =
particle_filters[b_t].cluster_probabilities.col(t);
}
}

this->particle_filters[0] = new_reference;
Expand Down
5 changes: 5 additions & 0 deletions src/rejuvenate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ bool Particle::rejuvenate(unsigned int T, const Options &options,
this->log_normalized_particle_filter_weights =
proposal_particle.log_normalized_particle_filter_weights;
this->particle_filters = proposal_particle.particle_filters;
this->stored_weights = proposal_particle.stored_weights;
this->logz = proposal_particle.logz;
accepted = true;
} else {
Expand Down Expand Up @@ -143,6 +144,10 @@ bool Particle::rejuvenate(unsigned int T, const Options &options,
} else {
sample_particle_filter();
}
} else if (options.backward_sampling) {
// For single-cluster models, apply backward sampling to update the
// reference trajectory from the current particle filter weights.
this->assemble_backward_trajectory(T, resampler);
}

return accepted;
Expand Down
1 change: 0 additions & 1 deletion tests/testthat/test-compute_sequentially_complete.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ test_that("compute_sequentially works with complete data", {
resampler = "systematic")
)
alpha_hat <- weighted.mean(x = as.numeric(mod$alpha), w = mod$importance_weights)
alpha_hat <- weighted.mean(x = as.numeric(mod$alpha), w = mod$importance_weights)
expect_gt(alpha_hat, .02)
expect_lt(alpha_hat, .05)
})
Expand Down