Skip to contents

This vignette explores practical ways to make vrcmort fits faster. It focuses on the no-covariate setting (formula `~ 1`) and on small examples that run quickly.

Simulate a tiny example

We use a small dataset that is still large enough for speed differences between model options to be visible:

  • 5 regions
  • 40 time points
  • conflict begins at time 21 (20 pre-conflict and 20 post-conflict time points)
  • 8 age groups
library(vrcmort)

# Simulate a small, fully observed dataset for benchmarking.
# NOTE(review): R_sim, T_sim, t0_sim and age_breaks_sim are not defined in
# this chunk; presumably they come from a hidden setup chunk matching the
# prose above (5 regions, 40 time points, t0 = 21, 8 age groups) -- confirm.
sim_fast <- vrc_simulate(
  R = R_sim,
  T = T_sim,
  t0 = t0_sim,
  age_breaks = age_breaks_sim,
  seed = 123,
  # No missingness, so timing differences come from the model options only.
  missing = list(type = "none"),
  rho0_true = c(0.99, 0.99)
)

# Observed data used by every fit below.
df <- sim_fast$df_obs

head(df)
# Stan data dimensions, for reference. chains = 0 is used here only to get
# the prepared Stan data and metadata for this dataset.
sd <- vrc_fit(
  data = df,
  t0 = sim_fast$meta$t0,
  chains = 0
)

sd$meta
# Exact-name subset of the size constants, returned as a named vector.
unlist(sd$standata[c("N", "R", "T", "A", "S", "G")])

Switches we will benchmark

We benchmark four optional components, each of which can add computational cost:

  • A: region-specific time random-walk deviations (time = "region" vs "national")
  • B: region-varying conflict slopes (conflict = "region" vs "fixed")
  • C: pre-conflict reporting estimated vs fixed to 1 (pre_conflict_reporting = "estimate" vs "fixed1")
  • D: generated quantities enabled vs disabled (generated_quantities = "full" vs "none")

All four options can be switched on or off through the high-level interface (`vrc_mortality()`, `vrc_reporting()` and `vrcm()`).

Benchmark: six runs

We run exactly six fits, each with a single chain and `iter_bench` iterations:

  • full model (A+B+C+D enabled)
  • quickest model (A+B+C+D disabled)
  • the full model with one component removed at a time: -A, -B, -C, -D
# Load every Stan program variant used below before any timing starts.
# NOTE(review): presumably vrc_model() compiles/caches the model so that
# compilation time is excluded from the benchmarks -- confirm.
invisible(lapply(
  c(
    "vr_reporting_model",
    "vr_reporting_model_rho1_pre",
    "vr_reporting_model_nogq",
    "vr_reporting_model_rho1_pre_nogq"
  ),
  vrc_model
))

# Fewer iterations when FAST_DEBUG is set.
if (FAST_DEBUG) {
  iter_bench <- 40
} else {
  iter_bench <- 100
}

# Sum wall-clock warmup and sampling time (seconds) over all chains of an
# rstan::get_elapsed_time() result. Accepts either "sampling" or "sample"
# as the name of the sampling-time column.
bench_elapsed_seconds <- function(et) {
  et <- as.matrix(et)
  warmup_seconds <- sum(et[, "warmup"])
  sample_col <- if ("sampling" %in% colnames(et)) {
    "sampling"
  } else if ("sample" %in% colnames(et)) {
    "sample"
  } else {
    stop("Unexpected column names from rstan::get_elapsed_time().")
  }
  c(warmup = warmup_seconds, sampling = sum(et[, sample_col]))
}

# Pool post-warmup sampler diagnostics over every chain in an
# rstan::get_sampler_params() list (one matrix per chain).
bench_sampler_summary <- function(sp_list) {
  # Stack all chains so the diagnostics describe the whole fit, matching
  # the elapsed times, which are also summed over all chains.
  sp <- do.call(rbind, lapply(sp_list, as.matrix))
  needed <- c("n_leapfrog__", "treedepth__", "divergent__")
  if (!all(needed %in% colnames(sp))) {
    stop("Unexpected column names from rstan::get_sampler_params().")
  }
  list(
    sum_leapfrog = sum(sp[, "n_leapfrog__"]),
    mean_leapfrog = mean(sp[, "n_leapfrog__"]),
    max_treedepth = max(sp[, "treedepth__"]),
    divergences = sum(sp[, "divergent__"])
  )
}

# Fit one model configuration with vrcm() and return a one-row data.frame
# of timing and sampler diagnostics.
#
# label: text printed before the fit and stored in the `model` column.
# mort, rep: mortality / reporting specifications passed to vrcm().
# pre_conflict_reporting, generated_quantities: passed through to vrcm().
# iter, chains, seed: sampler settings passed to vrcm().
# data, t0: data and conflict start; default to the simulated `df` and
#   `sim_fast$meta$t0` defined earlier in this document.
#
# Timing and leapfrog counts are totals over all chains, so
# seconds_per_leapfrog compares like with like for any `chains`.
run_one <- function(label,
                    mort,
                    rep,
                    pre_conflict_reporting,
                    generated_quantities,
                    iter = iter_bench,
                    chains = 1,
                    seed = 123,
                    data = df,
                    t0 = sim_fast$meta$t0) {
  cat("\n\n---\n")
  cat(label, "\n")
  fit <- vrcm(
    mortality = mort,
    reporting = rep,
    data = data,
    t0 = t0,
    pre_conflict_reporting = pre_conflict_reporting,
    generated_quantities = generated_quantities,
    chains = chains,
    iter = iter,
    seed = seed,
    refresh = 0
  )

  secs <- bench_elapsed_seconds(rstan::get_elapsed_time(fit$stanfit))
  sampler <- bench_sampler_summary(
    rstan::get_sampler_params(fit$stanfit, inc_warmup = FALSE)
  )
  warmup_seconds <- secs[["warmup"]]
  sampling_seconds <- secs[["sampling"]]
  sum_leapfrog <- sampler$sum_leapfrog

  print(fit)
  gc()
  data.frame(
    model = label,
    warmup_seconds = warmup_seconds,
    sampling_seconds = sampling_seconds,
    total_seconds = warmup_seconds + sampling_seconds,
    mean_n_leapfrog = sampler$mean_leapfrog,
    sum_n_leapfrog = sum_leapfrog,
    # Guard against a degenerate run with no leapfrog steps.
    seconds_per_leapfrog = if (sum_leapfrog > 0) {
      sampling_seconds / sum_leapfrog
    } else {
      NA_real_
    },
    max_treedepth = sampler$max_treedepth,
    divergences = sampler$divergences,
    row.names = NULL,
    stringsAsFactors = FALSE
  )
}

# Full specification: region-varying conflict slopes (B) and region-specific
# time deviations (A) in both the mortality and the reporting components.
mort_full <- vrc_mortality(~ 1, conflict = "region", time = "region")
rep_full  <- vrc_reporting(~ 1, conflict = "region", time = "region")

# Fastest specification: fixed conflict slopes and national time trends only.
mort_fast <- vrc_mortality(~ 1, conflict = "fixed", time = "national")
rep_fast  <- vrc_reporting(~ 1, conflict = "fixed", time = "national")

# Reference point: every optional component switched on.
r_full <- run_one(
  label = "full (A+B+C+D)",
  pre_conflict_reporting = "estimate",
  generated_quantities = "full",
  mort = mort_full,
  rep = rep_full
)

# Lower bound: every optional component switched off.
r_quick <- run_one(
  label = "quickest (-A-B-C-D)",
  pre_conflict_reporting = "fixed1",
  generated_quantities = "none",
  mort = mort_fast,
  rep = rep_fast
)

# Drop A: national time trends instead of region-specific deviations.
r_minus_A <- run_one(
  label = "full - A (no region time deviations)",
  pre_conflict_reporting = "estimate",
  generated_quantities = "full",
  mort = vrc_mortality(~ 1, conflict = "region", time = "national"),
  rep = vrc_reporting(~ 1, conflict = "region", time = "national")
)

# Drop B: conflict slopes shared across regions.
r_minus_B <- run_one(
  label = "full - B (fixed conflict slopes)",
  pre_conflict_reporting = "estimate",
  generated_quantities = "full",
  mort = vrc_mortality(~ 1, conflict = "fixed", time = "region"),
  rep = vrc_reporting(~ 1, conflict = "fixed", time = "region")
)

# Drop C: pre-conflict reporting fixed to 1 rather than estimated.
r_minus_C <- run_one(
  label = "full - C (pre-conflict rho fixed to 1)",
  pre_conflict_reporting = "fixed1",
  generated_quantities = "full",
  mort = mort_full,
  rep = rep_full
)

# Drop D: no generated quantities.
r_minus_D <- run_one(
  label = "full - D (no generated quantities)",
  pre_conflict_reporting = "estimate",
  generated_quantities = "none",
  mort = mort_full,
  rep = rep_full
)

# Stack the six one-row summaries; row 1 is the full model.
runtime <- do.call(
  rbind,
  list(r_full, r_quick, r_minus_A, r_minus_B, r_minus_C, r_minus_D)
)
# Speed-up relative to the full model (values > 1 mean faster).
runtime[["speedup_vs_full"]] <-
  runtime[["total_seconds"]][1] / runtime[["total_seconds"]]
runtime