Parallel tempering for bimodal posteriors (M2)

This vignette is the first chapter of a worked-case companion to the package. Each chapter takes one model at a time: the question the user faces, the failure of the default sampler, and the call that fixes it. The model here is M2 of the gpumetropolis registered benchmark, a separated bimodal posterior, and the call that fixes it is gpu_metropolis(method = "pt"), introduced in the v0.3.0 release.

The model

Observations \(y_1, \ldots, y_N\) are drawn from \(\mathcal{N}(|\mu|,\, \sigma)\) with \(\sigma = 1\) known. The code below writes each observation's likelihood as the sum of the \(\mathcal{N}(\mu, 1)\) and \(\mathcal{N}(-\mu, 1)\) densities, which is unchanged when \(\mu\) is replaced by \(-\mu\). The likelihood therefore depends on \(\mu\) only through \(|\mu|\), so the posterior of \(\mu\) is symmetric and bimodal, with modes near \(+c\) and \(-c\), where \(c > 0\) is close to \(\mathrm{mean}(|y|)\). A single random-walk chain settles in one basin; crossing the low-density region around \(\mu = 0\) is exponentially rare once the modes separate.

library(gpumetropolis)
set.seed(11)
y <- rnorm(400, mean = 3, sd = 1)     # truth: mu = +3
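model <- gpum_model(
  # per observation: log of the N(mu, 1) plus N(-mu, 1) densities, constants dropped
  loglik = ~ log(exp(-((y - mu)^2) / 2) + exp(-((y + mu)^2) / 2)),
  params = "mu", data = "y"
)
# eight chains started from -6 to 6: one row per chain, one column per parameter
init <- matrix(seq(-6, 6, length.out = 8), nrow = 8, ncol = 1L)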
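To see the shape of this posterior without running a sampler, the log-likelihood can be evaluated on a grid in base R. This is a minimal sketch, not part of gpumetropolis: it assumes \(\sigma = 1\) and a flat prior on \(\mu\), so the log posterior equals the log-likelihood up to a constant, and the name loglik_mu is illustrative. It checks the symmetry in \(\mu\), locates a mode, and measures the barrier, the drop in log density from a mode to \(\mu = 0\).

```r
# Sketch (base R, no gpumetropolis needed): the M2 log-likelihood on a grid.
# Assumes sigma = 1 and a flat prior, so the log posterior equals this
# log-likelihood up to an additive constant.
set.seed(11)
y <- rnorm(400, mean = 3, sd = 1)

# Sum over observations of log(exp(a) + exp(b)), the same expression as the
# gpum_model() formula, evaluated with the max trick to avoid underflow.
loglik_mu <- function(mu, y) {
  a <- -((y - mu)^2) / 2
  b <- -((y + mu)^2) / 2
  m <- pmax(a, b)
  sum(m + log(exp(a - m) + exp(b - m)))
}

grid <- seq(-5, 5, by = 0.01)
ll <- vapply(grid, loglik_mu, numeric(1), y = y)

c_hat   <- abs(grid[which.max(ll)])   # location of either mode
barrier <- max(ll) - loglik_mu(0, y)  # drop in log density at mu = 0
c(c_hat = c_hat, mean_abs_y = mean(abs(y)), barrier_nats = barrier)
```

With 400 observations centred near 3, the barrier should come out well above a thousand nats, which is why a random walk that has to pass through \(\mu = 0\) effectively never does.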

The failure mode of plain random-walk Metropolis

The default method = "rwm" with eight overdispersed chains looks at first like it covers both modes, because the pooled draws across chains include both \(+3\) and \(-3\):

fit_rwm <- gpu_metropolis(model, data = list(y = y), init = init,
                          proposal_sd = 0.15, n_iter = 4000,
                          seed = 1, backend = "cpu")
fit_rwm
#> <gpum_fit>
#>   parameters  : mu
#>   method      : rwm
#>   backend     : cpu
#>   chains      : 8
#>   iterations  : 2000 per chain (4000 raw, 2000 adaptive warmup discarded)
#>   accept_rate : 0.474 to 0.581
#>   mu          : posterior mean -0.0007 (sd 2.9830)

The acceptance rate is healthy and the posterior mean is near zero, the average of the two modes. The trap is that each chain is stuck in the basin nearest its starting point: a chain seeded at \(-6\) converges to \(-3\) and stays there; a chain seeded at \(+6\) converges to \(+3\) and stays there. Reaching the other mode means passing through \(\mu = 0\), and the per-step probability of that crossing scales roughly as \(\exp(-\Delta)\), where \(\Delta\) is the barrier height, the drop in log posterior density from a mode to \(\mu = 0\). For the M2 separation it is vanishingly small.
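The stuck-chain behaviour can be reproduced with a plain random-walk Metropolis sampler in base R. This is a sketch under stated assumptions (\(\sigma = 1\), a flat prior, a fixed proposal sd of 0.15, no adaptive warmup, the first half of each chain discarded), not the package's rwm implementation; loglik_mu and rwm_chain are illustrative names.

```r
# Sketch (base R): eight random-walk Metropolis chains on the M2
# log-likelihood. Assumes sigma = 1 and a flat prior; not gpumetropolis code.
set.seed(11)
y <- rnorm(400, mean = 3, sd = 1)

loglik_mu <- function(mu, y) {
  a <- -((y - mu)^2) / 2
  b <- -((y + mu)^2) / 2
  m <- pmax(a, b)
  sum(m + log(exp(a - m) + exp(b - m)))   # stable log(exp(a) + exp(b))
}

rwm_chain <- function(init, n_iter, proposal_sd, y) {
  draws  <- numeric(n_iter)
  cur    <- init
  cur_ll <- loglik_mu(cur, y)
  for (i in seq_len(n_iter)) {
    prop    <- cur + rnorm(1, sd = proposal_sd)
    prop_ll <- loglik_mu(prop, y)
    # Metropolis acceptance on the log scale (symmetric proposal)
    if (log(runif(1)) < prop_ll - cur_ll) {
      cur    <- prop
      cur_ll <- prop_ll
    }
    draws[i] <- cur
  }
  draws
}

set.seed(1)
inits  <- seq(-6, 6, length.out = 8)
chains <- vapply(inits, rwm_chain, numeric(4000),
                 n_iter = 4000, proposal_sd = 0.15, y = y)
post   <- chains[2001:4000, ]             # drop the first half as warmup

# Fraction of post-warmup draws on the same side of zero as the start
same_side <- vapply(seq_along(inits),
                    function(j) mean(sign(post[, j]) == sign(inits[j])),
                    numeric(1))
round(colMeans(post), 2)
same_side
```

Every chain should stay on the side of zero it started on, with its mean near \(\pm\mathrm{mean}(|y|)\), while the mean of all eight chains pooled together still lands near zero.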

The diagnostic that catches this is the split R-hat, not the pooled KS test against the reference. Because the eight chains split four and four between the basins, the pooled draws mix the two within-mode distributions in equal parts, so the pooled KS statistic looks correct; R-hat compares within-chain and between-chain variances and exposes the mismatch.

rhat(fit_rwm$draws[, , "mu"], warmup = 0)
#> [1] 61.49895

R-hat far above one is the honest verdict on this run: the chains did not mix.
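To show why R-hat catches what the pooled summary misses, here is the classic split R-hat, built from within- and between-half-chain variances, applied to synthetic draws that mimic the stuck run (four chains at \(-3\), four at \(+3\), each with sd 0.05) and to well-mixed draws. This is a sketch; the package's rhat() may differ in detail (rank normalisation, for example), and split_rhat is an illustrative name.

```r
# Sketch (base R): classic split R-hat on synthetic draws.
# Columns are chains, rows are post-warmup iterations.
split_rhat <- function(draws) {
  n <- nrow(draws) %/% 2                       # length of each half-chain
  halves <- cbind(draws[seq_len(n), , drop = FALSE],
                  draws[n + seq_len(n), , drop = FALSE])
  W <- mean(apply(halves, 2, var))             # mean within-half variance
  B <- n * var(colMeans(halves))               # between-half variance
  var_plus <- (n - 1) / n * W + B / n
  sqrt(var_plus / W)
}

set.seed(2)
# Stuck: four chains sit near -3 and four near +3, each with sd 0.05
stuck <- cbind(matrix(rnorm(2000 * 4, -3, 0.05), ncol = 4),
               matrix(rnorm(2000 * 4,  3, 0.05), ncol = 4))
# Mixed: every chain draws from the same distribution
mixed <- matrix(rnorm(2000 * 8), ncol = 8)

mean(stuck)        # pooled mean looks fine: close to 0
split_rhat(stuck)  # far above 1
split_rhat(mixed)  # close to 1
```

The between-chain term dominates for the stuck set because the chain means sit six units apart while each chain's own variance is tiny.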

Parallel tempering

method = "pt" runs the chains as a temperature ladder on the same target. Chain 1 is the cold chain at \(T = 1\) and samples the actual posterior; the other chains run at temperatures \(T_c > 1\) and target the tempered density \(\pi(x)^{1/T_c}\). A hot chain at state \(x_c\) accepts a proposal \(x'\) with probability \(\min\bigl\{1, \exp\bigl((\log\pi(x') - \log\pi(x_c)) / T_c\bigr)\bigr\}\), so it accepts downhill moves more easily than the cold chain does. Between batches, an adjacent-pair swap step proposes exchanging the states of chains \(c\) and \(c+1\), accepted with probability \(\min\bigl\{1, \exp\bigl((\log\pi(x_{c+1}) - \log\pi(x_c)) \cdot (1/T_c - 1/T_{c+1})\bigr)\bigr\}\). The cold chain inherits mode crossings from the hot ones through accepted swaps; the hot chains keep mixing because their tempered targets are substantially flatter than the cold one.
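The mechanism can be sketched in base R on the M2 log-likelihood. The settings are assumptions for illustration, not gpumetropolis internals: ten chains on a geometric ladder from \(T = 1\) to \(T = 1000\) (hot enough that the top chain sees a barrier of only a few nats), a random-walk proposal whose sd grows as \(\sqrt{T}\), and a swap sweep after every update sweep on alternating odd and even adjacent pairs. The swap rule is the one in the paragraph above.

```r
# Sketch only: parallel tempering on the M2 log-likelihood in base R.
# Assumed (not gpumetropolis internals): sigma = 1, flat prior, a geometric
# ladder up to T = 1000, proposal sd scaled by sqrt(T), and one swap sweep
# after every update sweep on alternating odd/even adjacent pairs.
set.seed(11)
y <- rnorm(400, mean = 3, sd = 1)

loglik_mu <- function(mu, y) {
  a <- -((y - mu)^2) / 2
  b <- -((y + mu)^2) / 2
  m <- pmax(a, b)
  sum(m + log(exp(a - m) + exp(b - m)))   # stable log(exp(a) + exp(b))
}

parallel_tempering <- function(y, temps, n_iter, proposal_sd, init) {
  K <- length(temps)
  stopifnot(K >= 3, temps[1] == 1, length(init) == K)  # chain 1 is cold
  x  <- init                                       # current state per chain
  ll <- vapply(x, loglik_mu, numeric(1), y = y)    # untempered log density
  cold <- numeric(n_iter)
  tries <- accepts <- numeric(K - 1)
  for (i in seq_len(n_iter)) {
    # Random-walk Metropolis update of every chain on pi(x)^(1 / T_k)
    for (k in seq_len(K)) {
      prop    <- x[k] + rnorm(1, sd = proposal_sd * sqrt(temps[k]))
      prop_ll <- loglik_mu(prop, y)
      if (log(runif(1)) < (prop_ll - ll[k]) / temps[k]) {
        x[k]  <- prop
        ll[k] <- prop_ll
      }
    }
    # Swap sweep: pair (j, j + 1) accepted with
    # min(1, exp((ll[j + 1] - ll[j]) * (1 / T_j - 1 / T_{j + 1})))
    for (j in seq(1 + i %% 2, K - 1, by = 2)) {
      log_r <- (ll[j + 1] - ll[j]) * (1 / temps[j] - 1 / temps[j + 1])
      tries[j] <- tries[j] + 1
      if (log(runif(1)) < log_r) {
        x[c(j, j + 1)]  <- x[c(j + 1, j)]
        ll[c(j, j + 1)] <- ll[c(j + 1, j)]
        accepts[j] <- accepts[j] + 1
      }
    }
    cold[i] <- x[1]
  }
  list(cold = cold, swap_rate = accepts / tries)
}

set.seed(1)
temps <- exp(seq(0, log(1000), length.out = 10))   # 1, ..., 1000
fit <- parallel_tempering(y, temps, n_iter = 5000, proposal_sd = 0.15,
                          init = seq(-6, 6, length.out = 10))
cold <- fit$cold[-(1:1000)]                        # discard warmup
mean(cold > 0)           # share of cold draws in the positive mode
round(fit$swap_rate, 2)  # per adjacent pair
```

The share of positive cold draws should land well away from 0 and 1, unlike a plain random-walk chain, which keeps the sign it starts with. Scaling the proposal by \(\sqrt{T}\) is one simple way to keep hot chains moving on their wider tempered targets; it is a choice of this sketch, not something the package is claimed to do.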

fit_pt <- gpu_metropolis(model, data = list(y = y), init = init,
                         proposal_sd = 0.15, n_iter = 4000,
                         method = "pt", seed = 1, backend = "cpu")
fit_pt
#> <gpum_fit>
#>   parameters  : mu
#>   method      : pt
#>   backend     : cpu
#>   chains      : 8
#>   iterations  : 2000 per chain (4000 raw, 2000 adaptive warmup discarded)
#>   accept_rate : 0.392 to 0.431
#>   swap accept : pairs (1-7) mean 0.885 to 0.920
#>   mu (T=1)    : posterior mean 0.5944 (sd 2.9210)

The print method labels the run as parallel tempering, shows the band of swap acceptances across pairs, and reports the posterior summary of the cold chain alone. R-hat on the cold chain is near one:

rhat(matrix(fit_pt$draws[, 1L, "mu"], ncol = 1L), warmup = 0)
#> [1] 1.001213

The full diagnostic is one call away:

gpum_diagnose(fit_pt)
#> <gpum_diagnose: Inconclusive>
#>   method pt, cold chain, backend cpu, chains 8, iterations 2000 (warmup 2000, adaptive)
#> 
#>  parameter   mean     sd    q2.5    q50  q97.5   Rhat     ESS   MCSE
#>         mu 0.5944 2.9210 -3.0570 2.9339 3.0524 1.0012 81.4985 0.3236
#> 
#>   Increase n_iter to raise the effective sample size.

The verdict bar at the top of the table reads parallel tempering, the diagnostic table summarises the cold chain only, and the extra row of plots shows the swap acceptance per adjacent pair across the warmup batches, with the asymptotic optimum of \(0.234\) drawn as a reference. Because the M2 posterior is symmetric, the swap acceptance sits well above that optimum, which is harmless: states at mirror-image positions in the two modes have equal log density, so a swap between them is always accepted.
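The swap rule from the Parallel tempering section can be written as a small function to see why a symmetric target pushes swap acceptance up. swap_accept_prob is an illustrative helper, not a package export, and the log-density values below are made-up round numbers.

```r
# Sketch: adjacent-pair swap acceptance probability (hypothetical helper).
# ll_c, ll_next: untempered log densities of the states held by chains c, c + 1.
swap_accept_prob <- function(ll_c, ll_next, T_c, T_next) {
  min(1, exp((ll_next - ll_c) * (1 / T_c - 1 / T_next)))
}

# Mirror-image states (mu = +3 and mu = -3) have equal log density under a
# symmetric target, so the swap is accepted whatever the temperatures.
swap_accept_prob(ll_c = -210, ll_next = -210, T_c = 1, T_next = 2)   # 1
# A hotter chain holding a lower-density state is accepted less often.
swap_accept_prob(ll_c = -210, ll_next = -214, T_c = 1, T_next = 2)   # exp(-2), about 0.135
```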

Inspecting the cold chain directly

Column 1 of the draws array is the cold chain, the chain at temperature \(T = 1\), and it is the only chain that targets the actual posterior; the other columns are auxiliary hot chains that feed it through swaps and are not for inference. The post-warmup posterior draws are therefore the slice fit_pt$draws[, 1L, ]. A histogram makes the two modes plain:

cold <- fit_pt$draws[, 1L, "mu"]
hist(cold, breaks = 60, freq = FALSE,
     main = "Posterior of mu under parallel tempering (cold chain)",
     xlab = "mu", col = "grey85", border = "white")
abline(v = c(-3, 3), col = "red", lty = 2)
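One R detail matters when slicing the cold chain: with a single parameter, draws[, 1L, ] drops both unit dimensions and returns a plain vector, which is why the R-hat call on the cold chain wraps the slice in matrix(). The sketch below uses a stand-in array with the iterations × chains × parameters layout implied by the indexing in this vignette; it does not call the package.

```r
# Sketch with a stand-in array (iterations x chains x parameters), assumed
# from the indexing used in this vignette, e.g. draws[, 1L, "mu"].
set.seed(5)
draws <- array(rnorm(2000 * 8), dim = c(2000, 8, 1),
               dimnames = list(NULL, NULL, "mu"))

v <- draws[, 1L, ]                 # one parameter: both unit dims drop
is.null(dim(v))                    # TRUE: a plain vector

# Keep an iterations x parameters matrix for any number of parameters
cold_matrix <- matrix(draws[, 1L, ], nrow = dim(draws)[1],
                      dimnames = list(NULL, dimnames(draws)[[3]]))
dim(cold_matrix)                   # 2000 1
```

The matrix() call works for one parameter or many, because R stores the sliced values in column-major order.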

A bimodal posterior is also where the credible interval from hdi() must be read with care. A single 95% interval here spans the gap between the two modes, including the low-density region around zero that the cold chain rarely visits, so it is wide and summarises neither mode. The lesson is general: hdi() returns a single interval and is faithful only when the posterior is unimodal.

hdi(cold, ci = 0.95)         # spans both modes: not a summary of either one
#>     lower     upper 
#> -3.039638  3.067423
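A sample-based highest-density interval is commonly computed as the shortest window that contains the requested share of the sorted draws. The sketch below implements that estimator (shortest_interval is an illustrative name; the package's hdi() may differ in detail) and applies it to synthetic draws from two equal, well-separated modes. Summarising each mode separately, here by splitting at zero, gives intervals that do describe the modes.

```r
# Sketch (base R): shortest interval covering a share ci of the draws.
shortest_interval <- function(x, ci = 0.95) {
  x <- sort(as.numeric(x))                   # as.numeric drops any names
  n <- length(x)
  k <- ceiling(ci * n)                       # draws the interval must cover
  widths <- x[k:n] - x[seq_len(n - k + 1)]   # width of every k-draw window
  i <- which.min(widths)
  c(lower = x[i], upper = x[i + k - 1])
}

set.seed(3)
draws <- c(rnorm(2000, -3, 0.05), rnorm(2000, 3, 0.05))  # two equal modes

shortest_interval(draws)              # one interval: spans the gap at zero
shortest_interval(draws[draws < 0])   # left mode only
shortest_interval(draws[draws > 0])   # right mode only
```

Splitting at zero works for M2 because the modes are mirror images around zero; other multimodal targets need a split point chosen from the histogram.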

The trace shows the cold chain repeatedly jumping between the two modes, a clear sign of true mode mixing; the jumps arrive through accepted swaps rather than a walk across zero:

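draws_mu <- fit_pt$draws[, , "mu"]
# hot chains in grey, then the cold chain (column 1) drawn on top in black
matplot(draws_mu[, -1L], type = "l", lty = 1, col = "grey70",
        xlab = "iteration (post-warmup)", ylab = "mu",
        main = "All chains; cold chain in black")
lines(draws_mu[, 1L], col = "black")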
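To put a number on the mode mixing seen in the trace, one can count sign changes along a chain. count_mode_switches is an illustrative helper, not a package export, and it assumes, as in M2, that the modes sit on opposite sides of zero.

```r
# Sketch: count sign changes along a chain as a rough mode-switch tally
# (hypothetical helper; assumes the modes lie on opposite sides of zero).
count_mode_switches <- function(x) {
  s <- sign(x)
  s <- s[s != 0]                  # ignore draws exactly at zero
  sum(s[-1] != s[-length(s)])
}

count_mode_switches(c(3.0, 2.9, -3.1, -3.0, 3.1))   # 2
```

Applied to fit_pt$draws[, 1L, "mu"] it counts the mode switches of the cold chain; applied to any column of fit_rwm$draws[, , "mu"] it should return zero, since those chains never leave their starting basin after warmup.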

Tuning notes for parallel tempering

When the default ladder is too dense, swap acceptance is high and the cold chain still mixes well, just with some efficiency lost. When the ladder is too sparse, swap acceptance drops; the gpum_diagnose plot fires a hint when any adjacent pair averages below \(0.10\).

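The two practical knobs are temperatures and swap_every. A wider ladder lets the hottest chain cross the barrier more easily and can reduce the autocorrelation of the cold chain; the cost is either more chains or, with the chain count fixed, wider gaps between adjacent temperatures and lower swap acceptance: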

fit_pt_wide <- gpu_metropolis(
  model, data = list(y = y), init = init,
  proposal_sd = 0.15, n_iter = 4000,
  method = "pt",
  temperatures = c(1, 2, 4, 8, 16, 32, 64, 128),  # geometric ladder; T = 1 is the cold chain
  swap_every = 5,                                 # swap step every 5 iterations
  seed = 1, backend = "cpu"
)
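Two small helpers make the tuning loop concrete: one builds a geometric ladder like the one passed to temperatures above, the other lists adjacent pairs whose average swap acceptance falls below the 0.10 level at which gpum_diagnose raises its hint. Both are illustrative sketches, not package functions, and nothing is assumed here about where a fit object stores its per-pair swap rates.

```r
# Sketch: ladder and swap-check helpers (hypothetical, not package exports).
geometric_ladder <- function(n_chains, t_max) {
  t_max^(seq(0, 1, length.out = n_chains))   # starts at T = 1, ends at t_max
}

# Indices j of adjacent pairs (j, j + 1) whose mean swap acceptance is low
low_swap_pairs <- function(swap_rate, threshold = 0.10) {
  which(swap_rate < threshold)
}

geometric_ladder(8, 128)   # approximately 1 2 4 8 16 32 64 128
```

A geometric ladder keeps the ratio between neighbouring temperatures constant, which is a common default when nothing is known about where the target changes shape; a flagged pair is a natural place to insert an extra temperature.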

A smaller swap_every increases the rate at which the cold chain inherits new states from the hot chains, often the most cost-effective adjustment on symmetric multimodal targets.

What the focused re-run shows in numbers

The accompanying script benchmark/run_m2_pt.R runs the same cell as above with twenty replications and compares three adapters: gpumetropolis-cpu, gpumetropolis-cpu-pt and nimble. Averaged over the twenty replications, parallel tempering reaches R-hat near \(1.00\), while random-walk gpumetropolis and nimble both stay near \(62\). The trade-off is a \(1.7\)x wall-clock factor for the host-side swap step and a lower nominal effective sample size, since each cold-chain draw now carries the autocorrelation of a chain that genuinely crosses between modes rather than the low within-mode autocorrelation of a stuck one. The full numbers are recorded as the v0.9 amendment of the experiment protocol.

When to reach for parallel tempering

The honest rule of thumb: random-walk Metropolis on a unimodal target with moderate posterior contrast is well served by the default method = "rwm" with adaptive warmup. Parallel tempering pays off when:

  1. The posterior is multimodal and the modes are separated by low-density regions that random walk cannot cross within the iteration budget.
  2. The user can spare a small number of hot chains; the package supports up to dozens with no extra coordination.
  3. The cost of a \(1.5\) to \(2\)x wall-clock factor is acceptable relative to the value of a converged R-hat.

When these conditions do not hold, rwm is the right default. The package leaves the choice to the user through the method argument.