Skip to content

Commit 4239c90

Browse files
authored
Merge pull request #82 from osorensen/copilot/sub-pr-81
Fix O(T²) backward trajectory, extend backward sampling to single-cluster models
2 parents 281c43a + 7fb603f commit 4239c90

File tree

3 files changed

+27
-26
lines changed

3 files changed

+27
-26
lines changed

src/particle.cpp

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -111,39 +111,36 @@ void Particle::run_particle_filter(
111111

112112
void Particle::assemble_backward_trajectory(
113113
unsigned int T, const std::unique_ptr<Resampler> &resampler) {
114+
// Pre-allocate storage so each column/element is written in O(1) rather than
115+
// prepending with insert_cols/insert_rows, which would be O(T^2) overall.
116+
unsigned int n_items = parameters.rho.n_rows;
117+
bool multi_cluster = (parameters.tau.size() > 1);
114118

115119
ParticleFilter new_reference;
116120
new_reference.log_weight.resize(T + 1);
121+
new_reference.latent_rankings.set_size(n_items, T + 1);
122+
if (multi_cluster) {
123+
new_reference.cluster_assignments.set_size(T + 1);
124+
new_reference.cluster_probabilities.set_size(parameters.tau.size(), T + 1);
125+
new_reference.index = uvec(T + 1, fill::zeros);
126+
}
117127

118128
for (int t = T; t >= 0; --t) {
119-
arma::vec current_weights = stored_weights[t];
120-
121-
// Sample a single index b_t based on current_weights
122-
arma::ivec counts = resampler->resample(1, current_weights);
129+
// Draw b_t independently from the forward filtering weights W_t.
130+
// Because cross-sectional users are conditionally independent given the
131+
// static parameters, the backward weights reduce to W_t exactly.
132+
arma::ivec counts = resampler->resample(1, stored_weights[t]);
123133
unsigned int b_t = arma::as_scalar(arma::find(counts > 0, 1));
124134

125-
if (new_reference.latent_rankings.is_empty()) {
126-
new_reference.latent_rankings =
127-
particle_filters[b_t].latent_rankings.col(t);
128-
if (parameters.tau.size() > 1) {
129-
new_reference.cluster_assignments =
130-
particle_filters[b_t].cluster_assignments.subvec(t, t);
131-
new_reference.cluster_probabilities =
132-
particle_filters[b_t].cluster_probabilities.cols(t, t);
133-
new_reference.index = uvec(T + 1, fill::zeros);
134-
}
135-
} else {
136-
new_reference.latent_rankings.insert_cols(
137-
0, particle_filters[b_t].latent_rankings.col(t));
138-
if (parameters.tau.size() > 1) {
139-
new_reference.cluster_assignments.insert_rows(
140-
0, particle_filters[b_t].cluster_assignments.subvec(t, t));
141-
new_reference.cluster_probabilities.insert_cols(
142-
0, particle_filters[b_t].cluster_probabilities.cols(t, t));
143-
}
144-
}
145-
135+
new_reference.latent_rankings.col(t) =
136+
particle_filters[b_t].latent_rankings.col(t);
146137
new_reference.log_weight(t) = particle_filters[b_t].log_weight(t);
138+
if (multi_cluster) {
139+
new_reference.cluster_assignments(t) =
140+
particle_filters[b_t].cluster_assignments(t);
141+
new_reference.cluster_probabilities.col(t) =
142+
particle_filters[b_t].cluster_probabilities.col(t);
143+
}
147144
}
148145

149146
this->particle_filters[0] = new_reference;

src/rejuvenate.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ bool Particle::rejuvenate(unsigned int T, const Options &options,
100100
this->log_normalized_particle_filter_weights =
101101
proposal_particle.log_normalized_particle_filter_weights;
102102
this->particle_filters = proposal_particle.particle_filters;
103+
this->stored_weights = proposal_particle.stored_weights;
103104
this->logz = proposal_particle.logz;
104105
accepted = true;
105106
} else {
@@ -143,6 +144,10 @@ bool Particle::rejuvenate(unsigned int T, const Options &options,
143144
} else {
144145
sample_particle_filter();
145146
}
147+
} else if (options.backward_sampling) {
148+
// For single-cluster models, apply backward sampling to update the
149+
// reference trajectory from the current particle filter weights.
150+
this->assemble_backward_trajectory(T, resampler);
146151
}
147152

148153
return accepted;

tests/testthat/test-compute_sequentially_complete.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ test_that("compute_sequentially works with complete data", {
4040
resampler = "systematic")
4141
)
4242
alpha_hat <- weighted.mean(x = as.numeric(mod$alpha), w = mod$importance_weights)
43-
alpha_hat <- weighted.mean(x = as.numeric(mod$alpha), w = mod$importance_weights)
4443
expect_gt(alpha_hat, .02)
4544
expect_lt(alpha_hat, .05)
4645
})

0 commit comments

Comments
 (0)