diff --git a/R/set_smc_options.R b/R/set_smc_options.R index 6d64c16..425e7d8 100644 --- a/R/set_smc_options.R +++ b/R/set_smc_options.R @@ -5,6 +5,9 @@ #' particle #' @param max_particle_filters Maximum number of particle filters. #' @param resampling_threshold Effective sample size threshold for resampling +#' @param doubling_threshold Threshold for particle filter doubling. If the +#' acceptance rate of the rejuvenation step falls below this threshold, the +#' number of particle filters is doubled. Defaults to 0.2. #' @param max_rejuvenation_steps Maximum number of rejuvenation steps. If the #' number of unique particles has not exceeded half the number of particles #' after this many steps, the rejuvenation is still stopped. @@ -22,7 +25,8 @@ #' set_smc_options <- function( n_particles = 1000, n_particle_filters = 50, max_particle_filters = 10000, - resampling_threshold = n_particles / 2, max_rejuvenation_steps = 20, + resampling_threshold = n_particles / 2, doubling_threshold = .2, + max_rejuvenation_steps = 20, metric = "footrule", resampler = "multinomial", latent_rank_proposal = "uniform", verbose = FALSE, trace = FALSE, trace_latent = FALSE) { diff --git a/man/set_smc_options.Rd b/man/set_smc_options.Rd index a99625a..9eeb81f 100644 --- a/man/set_smc_options.Rd +++ b/man/set_smc_options.Rd @@ -9,6 +9,7 @@ set_smc_options( n_particle_filters = 50, max_particle_filters = 10000, resampling_threshold = n_particles/2, + doubling_threshold = 0.2, max_rejuvenation_steps = 20, metric = "footrule", resampler = "multinomial", @@ -28,6 +29,10 @@ particle} \item{resampling_threshold}{Effective sample size threshold for resampling} +\item{doubling_threshold}{Threshold for particle filter doubling. If the +acceptance rate of the rejuvenation step falls below this threshold, the +number of particle filters is doubled. Defaults to 0.2.} + \item{max_rejuvenation_steps}{Maximum number of rejuvenation steps. If the number of unique particles has not exceeded half the number of particles after this many steps, the rejuvenation is still stopped.} diff --git a/src/options.cpp b/src/options.cpp index fa08a7a..2405b01 100644 --- a/src/options.cpp +++ b/src/options.cpp @@ -9,6 +9,7 @@ Options::Options(const Rcpp::List& input_options) : max_particle_filters {input_options["max_particle_filters"]}, resampling_threshold{input_options["resampling_threshold"]}, max_rejuvenation_steps{input_options["max_rejuvenation_steps"]}, + doubling_threshold{input_options["doubling_threshold"]}, verbose{input_options["verbose"]}, trace{input_options["trace"]}, trace_latent{input_options["trace_latent"]}{} diff --git a/src/options.h b/src/options.h index f345677..db0ad85 100644 --- a/src/options.h +++ b/src/options.h @@ -14,6 +14,7 @@ struct Options{ unsigned int max_particle_filters; unsigned int resampling_threshold; unsigned int max_rejuvenation_steps; + const double doubling_threshold; const bool verbose; const bool trace; const bool trace_latent; diff --git a/src/run_smc.cpp b/src/run_smc.cpp index 2530dd2..74c2dbb 100644 --- a/src/run_smc.cpp +++ b/src/run_smc.cpp @@ -91,7 +91,7 @@ Rcpp::List run_smc( double acceptance_rate = accepted / particle_vector.size() / iter; reporter.report_acceptance_rate(acceptance_rate); - if(acceptance_rate < 0.2 && options.n_particle_filters < options.max_particle_filters) { + if(acceptance_rate < options.doubling_threshold && options.n_particle_filters < options.max_particle_filters) { for(auto& p : particle_vector) { double log_Z_old = compute_log_Z(p.particle_filters, t);