This vignette explores practical ways to make vrcmort
fits faster, focusing on the no-covariate setting (~ 1) and
small examples that run quickly.
Simulate a tiny example
We use a small dataset that is still large enough for speed differences between model options to be visible:
- 5 regions
- 40 time points
- conflict begins at time 21 (20 pre, 20 post)
- 8 age groups
Switches we will benchmark
We benchmark four expensive/optional components:
- A: region-specific time random-walk deviations (`time = "region"` vs `"national"`)
- B: region-varying conflict slopes (`conflict = "region"` vs `"fixed"`)
- C: pre-conflict reporting estimated vs fixed to 1 (`pre_conflict_reporting = "estimate"` vs `"fixed1"`)
- D: generated quantities enabled vs disabled (`generated_quantities = "full"` vs `"none"`)
These options are all switchable via the high-level interface.
Benchmark: six runs
We run exactly six fits:
- full model (A+B+C+D enabled)
- quickest model (A+B+C+D disabled)
- four fits of the full model with exactly one component disabled each: -A, -B, -C, -D
# Request every model variant once before any timing is taken.
# NOTE(review): presumably this makes vrc_model() compile/cache each Stan
# program up front so that cost is not mixed into the benchmark runs below --
# confirm against vrc_model()'s documentation.
invisible(lapply(
  c(
    "vr_reporting_model",
    "vr_reporting_model_rho1_pre",
    "vr_reporting_model_nogq",
    "vr_reporting_model_rho1_pre_nogq"
  ),
  vrc_model
))

# Default iteration count used by run_one(); a smaller value is used when the
# FAST_DEBUG flag (defined elsewhere) is on.
if (FAST_DEBUG) {
  iter_bench <- 40
} else {
  iter_bench <- 100
}
# Fit one benchmark configuration with vrcm() and summarise its cost.
#
# Arguments:
#   label                  Character label; printed before the fit and stored
#                          in the `model` column of the result.
#   mort, rep              Passed to vrcm() as `mortality` and `reporting`.
#   pre_conflict_reporting Passed through to vrcm() unchanged.
#   generated_quantities   Passed through to vrcm() unchanged.
#   iter, chains, seed     Passed through to vrcm(); `iter` defaults to the
#                          global `iter_bench`.
#
# The data (`df`) and the conflict start time (`sim_fast$meta$t0`) are not
# arguments: they are looked up in the enclosing environment, so both objects
# must exist before run_one() is called.
#
# Side effects: writes a separator and the label with cat(), prints the fit,
# and calls gc() before returning.
#
# Returns a one-row data.frame with columns model, warmup_seconds,
# sampling_seconds, total_seconds, mean_n_leapfrog, sum_n_leapfrog,
# seconds_per_leapfrog, max_treedepth and divergences.
run_one <- function(label,
                    mort,
                    rep,
                    pre_conflict_reporting,
                    generated_quantities,
                    iter = iter_bench,
                    chains = 1,
                    seed = 123) {
  cat("\n\n---\n")
  cat(label, "\n")

  fit <- vrcm(
    mortality = mort,
    reporting = rep,
    data = df,
    t0 = sim_fast$meta$t0,
    pre_conflict_reporting = pre_conflict_reporting,
    generated_quantities = generated_quantities,
    chains = chains,
    iter = iter,
    seed = seed,
    refresh = 0
  )

  # Elapsed time is reported per chain; as.matrix() keeps the column lookup
  # below uniform, and the sums total the time over every chain.
  et <- rstan::get_elapsed_time(fit$stanfit)
  et <- as.matrix(et)
  warmup_seconds <- sum(et[, "warmup"])

  # Accept either spelling of the post-warmup column name.
  sample_col <- if ("sampling" %in% colnames(et)) {
    "sampling"
  } else if ("sample" %in% colnames(et)) {
    "sample"
  } else {
    stop("Unexpected column names from rstan::get_elapsed_time().")
  }
  sampling_seconds <- sum(et[, sample_col])

  # get_sampler_params() returns one matrix per chain. Stack all of them so
  # the diagnostics cover the same chains as the elapsed times above; reading
  # only the first chain would under-count leapfrog steps and divergences and
  # inflate seconds_per_leapfrog whenever chains > 1.
  sp_by_chain <- rstan::get_sampler_params(fit$stanfit, inc_warmup = FALSE)
  sp <- do.call(rbind, lapply(sp_by_chain, as.matrix))
  if (!all(c("n_leapfrog__", "treedepth__", "divergent__") %in% colnames(sp))) {
    stop("Unexpected column names from rstan::get_sampler_params().")
  }

  sum_leapfrog <- sum(sp[, "n_leapfrog__"])
  mean_leapfrog <- mean(sp[, "n_leapfrog__"])
  max_treedepth <- max(sp[, "treedepth__"])
  divergences <- sum(sp[, "divergent__"])

  print(fit)
  gc()

  data.frame(
    model = label,
    warmup_seconds = warmup_seconds,
    sampling_seconds = sampling_seconds,
    total_seconds = warmup_seconds + sampling_seconds,
    mean_n_leapfrog = mean_leapfrog,
    sum_n_leapfrog = sum_leapfrog,
    # Guard against division by zero when no leapfrog steps were recorded.
    seconds_per_leapfrog = if (sum_leapfrog > 0) sampling_seconds / sum_leapfrog else NA_real_,
    max_treedepth = max_treedepth,
    divergences = divergences,
    row.names = NULL,
    stringsAsFactors = FALSE
  )
}
# Component specifications reused by several runs.
# "full" turns on region-level conflict slopes and region-level time terms;
# "fast" uses the fixed / national alternatives.
mort_full <- vrc_mortality(~ 1, conflict = "region", time = "region")
rep_full <- vrc_reporting(~ 1, conflict = "region", time = "region")

mort_fast <- vrc_mortality(~ 1, conflict = "fixed", time = "national")
rep_fast <- vrc_reporting(~ 1, conflict = "fixed", time = "national")

# Run 1: everything enabled (reference for the speed-up column).
r_full <- run_one(
  label = "full (A+B+C+D)",
  mort = mort_full,
  rep = rep_full,
  pre_conflict_reporting = "estimate",
  generated_quantities = "full"
)

# Run 2: everything disabled.
r_quick <- run_one(
  label = "quickest (-A-B-C-D)",
  mort = mort_fast,
  rep = rep_fast,
  pre_conflict_reporting = "fixed1",
  generated_quantities = "none"
)

# Runs 3-6: the full model with exactly one component switched off.
r_minus_A <- run_one(
  label = "full - A (no region time deviations)",
  mort = vrc_mortality(~ 1, conflict = "region", time = "national"),
  rep = vrc_reporting(~ 1, conflict = "region", time = "national"),
  pre_conflict_reporting = "estimate",
  generated_quantities = "full"
)

r_minus_B <- run_one(
  label = "full - B (fixed conflict slopes)",
  mort = vrc_mortality(~ 1, conflict = "fixed", time = "region"),
  rep = vrc_reporting(~ 1, conflict = "fixed", time = "region"),
  pre_conflict_reporting = "estimate",
  generated_quantities = "full"
)

r_minus_C <- run_one(
  label = "full - C (pre-conflict rho fixed to 1)",
  mort = mort_full,
  rep = rep_full,
  pre_conflict_reporting = "fixed1",
  generated_quantities = "full"
)

r_minus_D <- run_one(
  label = "full - D (no generated quantities)",
  mort = mort_full,
  rep = rep_full,
  pre_conflict_reporting = "estimate",
  generated_quantities = "none"
)

# Stack the six one-row summaries, full model first.
runtime <- do.call(
  rbind,
  list(r_full, r_quick, r_minus_A, r_minus_B, r_minus_C, r_minus_D)
)

# Speed-up relative to the full model's total time (values > 1 mean faster).
runtime[["speedup_vs_full"]] <-
  r_full[["total_seconds"]] / runtime[["total_seconds"]]

runtime