Using Variational Estimates to Initialize the NUTS-HMC Sampler#

In this example we show how to use the parameter estimates returned by Stan’s variational inference algorithm as the initial parameter values for Stan’s NUTS-HMC sampler. By default, the sampler randomly initializes all model parameters uniformly in the interval [-2, 2] on the unconstrained scale. When the true parameter values lie outside of this range, starting from the ADVI estimates will speed up and improve adaptation.

Model and data#

The Stan model and data are taken from the posteriordb package.

We use the blr model, a standard Bayesian linear regression model with noninformative priors, and its corresponding simulated dataset sblri.json, which was generated by the script sblr.R. For convenience, we have copied the posteriordb model and data to this directory, in files blr.stan and sblri.json.
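
Since the model’s data block declares N, D, X, and y, those entries must be present in sblri.json. As a quick sanity check, the file can be inspected directly; a minimal sketch using only the Python standard library:

import json

# Load the Stan data file and report its dimensions.
with open('sblri.json') as f:
    data = json.load(f)
print(data['N'], data['D'])  # number of observations, number of predictors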

[1]:
from cmdstanpy import CmdStanModel

stan_file = 'blr.stan' # basic linear regression
data_file = 'sblri.json' # simulated data

model = CmdStanModel(stan_file=stan_file)

print(model.code())
19:03:27 - cmdstanpy - INFO - compiling stan file /home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr.stan to exe file /home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr
19:03:48 - cmdstanpy - INFO - compiled model executable: /home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr
data {
  int<lower=0> N;
  int<lower=0> D;
  matrix[N, D] X;
  vector[N] y;
}
parameters {
  vector[D] beta;
  real<lower=0> sigma;
}
model {
  // prior
  target += normal_lpdf(beta | 0, 10);
  target += normal_lpdf(sigma | 0, 10);
  // likelihood
  target += normal_lpdf(y | X * beta, sigma);
}


Run Stan’s variational inference algorithm, obtain fitted estimates#

The CmdStanModel method variational runs CmdStan’s ADVI algorithm. Because this algorithm is unstable and may fail to converge, we run it with the argument require_converged set to False. We also specify a seed, both to avoid instabilities and for reproducibility.

[2]:
vb_fit = model.variational(data=data_file, require_converged=False, seed=123)
19:03:48 - cmdstanpy - INFO - Chain [1] start processing
19:03:48 - cmdstanpy - INFO - Chain [1] done processing
19:03:48 - cmdstanpy - WARNING - The algorithm may not have converged.
Proceeding because require_converged is set to False

The ADVI algorithm provides estimates of all model parameters.

The variational method returns a CmdStanVB object, whose method stan_variables returns the approximate estimates of all model parameters as a Python dictionary.

[3]:
print(vb_fit.stan_variables())
{'beta': array([0.997115, 0.993865, 0.991472, 0.993601, 1.0095  ]), 'sigma': 1.67}
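
Besides the stan_variables dictionary, the CmdStanVB object exposes the estimates in other forms as well; for instance (attribute names as of CmdStanPy 1.x — worth verifying against the API docs for your installed version):

# Flat dictionary keyed by CSV column name, e.g. 'beta[1]'.
print(vb_fit.variational_params_dict)
# Draws from the fitted ADVI approximation, as a NumPy array.
print(vb_fit.variational_sample.shape)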

Posteriordb provides reference posteriors for all models. For the blr model, conditioned on the dataset sblri.json, the reference posterior is in file sblri-blr.json.

The reference posterior means for all elements of beta and for sigma are very close to 1.0.

The experiments reported in the paper Pathfinder: Parallel quasi-Newton variational inference by Zhang et al. show that mean-field ADVI provides a better estimate of the posterior, as measured by the 1-Wasserstein distance to the reference posterior, than 75 iterations of the Phase I warmup algorithm used by the NUTS-HMC sampler. Furthermore, ADVI is more computationally efficient, requiring fewer evaluations of the log density and gradient functions. Therefore, using the ADVI estimates to initialize the parameter values for the NUTS-HMC sampler should allow the sampler to do a better job of adapting the step size and metric during warmup, resulting in better performance and estimation.

[4]:
vb_vars = vb_fit.stan_variables()
mcmc_vb_inits_fit = model.sample(
    data=data_file, inits=vb_vars, iter_warmup=75, seed=12345
)
19:03:48 - cmdstanpy - INFO - CmdStan start processing

19:03:48 - cmdstanpy - INFO - CmdStan done processing.
19:03:48 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in '/home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.6/docsrc/users-guide/examples/blr.stan', line 16, column 2 to column 45)
Consider re-running with show_console=True if the above output is unclear!
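
The inits argument can also be given per chain as a list of dictionaries, one per chain. If you would rather not start every chain from exactly the same point, one option is to jitter the ADVI estimates slightly; a minimal sketch (the jitter scale here is arbitrary, not a recommendation):

import numpy as np

rng = np.random.default_rng(12345)
# One inits dict per chain: the ADVI estimates plus small Gaussian noise,
# keeping sigma positive by jittering on the log scale.
chain_inits = [
    {
        'beta': vb_vars['beta'] + rng.normal(scale=0.01, size=len(vb_vars['beta'])),
        'sigma': float(vb_vars['sigma'] * np.exp(rng.normal(scale=0.01))),
    }
    for _ in range(4)
]
# model.sample(data=data_file, inits=chain_inits, iter_warmup=75, seed=12345)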

[5]:
mcmc_vb_inits_fit.summary()
[5]:
                 Mean      MCSE    StdDev           5%          50%         95%       N_Eff      N_Eff/s     R_hat
lp__      -156.867000  0.056908  1.743290  -160.145000  -156.532000  -154.64600   938.42000   994.089000  1.001480
beta[1]      0.999474  0.000014  0.000952     0.997903     0.999461      1.00107  4863.49000  5152.000000  1.000320
beta[2]      1.000230  0.000018  0.001160     0.998355     1.000210      1.00210  4344.92000  4602.670000  0.999725
beta[3]      1.000440  0.000014  0.000938     0.998907     1.000430      1.00200  4385.54000  4645.700000  0.999669
beta[4]      1.001160  0.000016  0.001064     0.999422     1.001150      1.00292  4664.71000  4941.430000  0.999536
beta[5]      1.001540  0.000015  0.001033     0.999865     1.001540      1.00321  4786.98000  5070.960000  0.999197
sigma        0.963840  0.004465  0.071505     0.849600     0.961783      1.09259   256.47117   271.685561  1.011019

The sampler estimates match the reference posterior.
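
To spot-check this against the reference value of 1.0 programmatically, the posterior means can be computed from the draws; a minimal sketch (stan_variables on a CmdStanMCMC object returns the draws for each parameter):

import numpy as np

draws = mcmc_vb_inits_fit.stan_variables()
print(np.mean(draws['beta'], axis=0))  # each entry should be close to 1.0
print(np.mean(draws['sigma']))         # should also be close to 1.0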

[6]:
print(mcmc_vb_inits_fit.diagnose())
Processing csv files: /tmp/tmp9rnwz9nb/blr8dooj120/blr-20220823190348_1.csv, /tmp/tmp9rnwz9nb/blr8dooj120/blr-20220823190348_2.csv, /tmp/tmp9rnwz9nb/blr8dooj120/blr-20220823190348_3.csv, /tmp/tmp9rnwz9nb/blr8dooj120/blr-20220823190348_4.csv

Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
No divergent transitions found.

Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Effective sample size satisfactory.

Split R-hat values satisfactory all parameters.

Processing complete, no problems detected.

Using the default random parameter initializations, we need to run more warmup iterations. If we run only 75 warmup iterations with random inits, the sampler fails to estimate sigma correctly; at least 150 warmup iterations are needed to produce a good set of estimates (a sketch of this check appears at the end of this example).

[7]:
mcmc_random_inits_fit = model.sample(data=data_file, iter_warmup=75, seed=12345)
19:03:49 - cmdstanpy - INFO - CmdStan start processing

19:03:49 - cmdstanpy - INFO - CmdStan done processing.
19:03:49 - cmdstanpy - WARNING - Some chains may have failed to converge.
        Chain 1 had 161 divergent transitions (16.1%)
        Chain 3 had 147 divergent transitions (14.7%)
        Chain 4 had 244 divergent transitions (24.4%)
        Use function "diagnose()" to see further information.

[8]:
mcmc_random_inits_fit.summary()
[8]:
                 Mean       MCSE     StdDev           5%          50%         95%      N_Eff    N_Eff/s     R_hat
lp__      -191.333000  24.678800  35.170300  -231.560000  -165.541000  -154.33900    2.03097    6.15447  11.37150
beta[1]      0.999452   0.000119   0.001816     0.996272     0.999494      1.00252  232.18100  703.57900   1.01286
beta[2]      1.000560   0.000229   0.002416     0.996529     1.000410      1.00459  110.88200  336.00700   1.04571
beta[3]      1.000590   0.000259   0.002043     0.997326     1.000480      1.00442   62.47110  189.30600   1.04607
beta[4]      1.001380   0.000224   0.002279     0.997013     1.001690      1.00512  103.34500  313.16600   1.09049
beta[5]      1.001200   0.000150   0.002013     0.997854     1.001290      1.00443  180.70500  547.59200   1.03165
sigma        1.962000   0.725020   1.034300     0.907470     2.708830      3.17346    2.03514    6.16709  10.50420

[9]:
print(mcmc_random_inits_fit.diagnose())
Processing csv files: /tmp/tmp9rnwz9nb/blrurfptacv/blr-20220823190349_1.csv, /tmp/tmp9rnwz9nb/blrurfptacv/blr-20220823190349_2.csv, /tmp/tmp9rnwz9nb/blrurfptacv/blr-20220823190349_3.csv, /tmp/tmp9rnwz9nb/blrurfptacv/blr-20220823190349_4.csv

Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
552 of 4000 (13.80%) transitions ended with a divergence.
These divergent transitions indicate that HMC is not fully able to explore the posterior distribution.
Try increasing adapt delta closer to 1.
If this doesn't remove all divergences, try to reparameterize the model.

Checking E-BFMI - sampler transitions HMC potential energy.
The E-BFMI, 0.01, is below the nominal threshold of 0.30 which suggests that HMC may have trouble exploring the target distribution.
If possible, try to reparameterize the model.

The following parameters had fewer than 0.001 effective draws per transition:
  sigma
Such low values indicate that the effective sample size estimators may be biased high and actual performance may be substantially lower than quoted.

The following parameters had split R-hat greater than 1.05:
  beta[4], sigma
Such high values indicate incomplete mixing and biased estimation.
You should consider regularizing your model with additional prior information or a more effective parameterization.

Processing complete.
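
As a final check on the claim above, the model can be rerun with the default random initializations and a longer warmup; a minimal sketch (output not shown; exact results will vary with the seed):

# With random inits, roughly doubling the warmup recovers good estimates.
mcmc_longer_warmup_fit = model.sample(data=data_file, iter_warmup=150, seed=12345)
print(mcmc_longer_warmup_fit.summary())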