# Variational Inference in Stan

Variational inference is a scalable technique for approximate Bayesian inference. Stan implements an automatic variational inference algorithm, called Automatic Differentiation Variational Inference (ADVI), which searches over a family of simple densities to find the best approximate posterior density. ADVI produces an estimate of the parameter means together with a sample from the approximate posterior density.
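Here "best" means closest in Kullback-Leibler (KL) divergence. Write $q_\phi$ for a member of the approximating family with variational parameters $\phi$. For ADVI's default mean-field family, this is a Gaussian with diagonal covariance on Stan's unconstrained parameter space, where the log density includes the Jacobian of the constraining transform. The log evidence then decomposes as:

$$
\log p(y) \;=\; \underbrace{\mathbb{E}_{q_\phi}\!\left[\log p(y,\theta)\right] \;-\; \mathbb{E}_{q_\phi}\!\left[\log q_\phi(\theta)\right]}_{\mathrm{ELBO}(\phi)} \;+\; \mathrm{KL}\!\left(q_\phi(\theta)\,\middle\|\,p(\theta \mid y)\right)
$$

The KL term is non-negative, and $\log p(y)$ does not depend on $\phi$. Maximizing the ELBO over $\phi$ is therefore equivalent to minimizing the KL divergence from $q_\phi$ to the posterior.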

ADVI maximizes the variational objective function, the evidence lower bound (ELBO), using stochastic gradient ascent. The gradients are Monte Carlo estimates computed from `grad_samples` draws. The algorithm ascends these gradients using an adaptive step-size sequence that has one parameter, `eta`, which is adjusted during warmup. The number of draws used to approximate the ELBO itself is set by `elbo_samples`. ADVI heuristically determines a rolling window over which it computes the average and the median relative change of the ELBO. When either statistic falls below a threshold, `tol_rel_obj`, the algorithm is considered to have converged.
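To make the loop above concrete, here is a toy, pure-Python sketch. It is not Stan's implementation. It fits a one-dimensional Gaussian $q(\theta) = \mathrm{Normal}(\mu, e^{\omega})$ to a target by stochastic gradient ascent on the ELBO, using the following pieces:

- reparameterization gradients from `grad_samples` draws;
- a per-coordinate adaptive step size $\eta\, k^{-1/2} / (\tau + \sqrt{s_k})$, with $s_k = \alpha g_k^2 + (1-\alpha) s_{k-1}$ as described in the ADVI paper (the constants $\alpha = 0.1$ and $\tau = 1$ are assumptions here);
- a simplified mean/median rolling-window stopping rule.

The `eta` adaptation phase is omitted. The function `advi_sketch`, the window size, the relative-change formula, and the example target are all hypothetical choices for illustration.

```python
import math
import random
import statistics


def advi_sketch(log_p, grad_log_p, eta=1.0, grad_samples=1, elbo_samples=100,
                eval_elbo=100, tol_rel_obj=0.01, max_iter=10_000, seed=0):
    """Toy mean-field ADVI for a 1-D target: q(theta) = Normal(mu, exp(omega)).

    Hypothetical sketch, not Stan's code. Returns (mu, sigma, iters, converged).
    """
    rng = random.Random(seed)
    mu, omega = 0.0, 0.0
    s_mu = s_omega = None      # running averages of squared gradients
    alpha, tau = 0.1, 1.0      # step-size constants (assumed)
    window = max(int(0.1 * max_iter / eval_elbo), 2)  # rolling-window size (assumed)
    elbo_prev, rel_changes = None, []

    def elbo_estimate():
        # Monte Carlo estimate of E_q[log p] plus the Gaussian entropy in closed form
        sigma = math.exp(omega)
        total = sum(log_p(mu + sigma * rng.gauss(0.0, 1.0)) for _ in range(elbo_samples))
        return total / elbo_samples + omega + 0.5 * math.log(2 * math.pi * math.e)

    for k in range(1, max_iter + 1):
        sigma = math.exp(omega)
        g_mu = g_omega = 0.0
        for _ in range(grad_samples):
            z = rng.gauss(0.0, 1.0)
            d = grad_log_p(mu + sigma * z)   # reparameterization: theta = mu + sigma*z
            g_mu += d
            g_omega += d * sigma * z
        g_mu /= grad_samples
        g_omega = g_omega / grad_samples + 1.0   # +1 is d(entropy)/d(omega)

        # Adaptive step size: eta * k^(-1/2) / (tau + sqrt(s_k)), per coordinate
        s_mu = g_mu ** 2 if s_mu is None else alpha * g_mu ** 2 + (1 - alpha) * s_mu
        s_omega = (g_omega ** 2 if s_omega is None
                   else alpha * g_omega ** 2 + (1 - alpha) * s_omega)
        mu += eta * k ** -0.5 / (tau + math.sqrt(s_mu)) * g_mu
        omega += eta * k ** -0.5 / (tau + math.sqrt(s_omega)) * g_omega

        if k % eval_elbo == 0:
            elbo = elbo_estimate()
            if elbo_prev is not None:
                # relative ELBO change between evaluations (assumed formula)
                rel_changes.append(abs((elbo - elbo_prev) / elbo_prev))
                recent = rel_changes[-window:]
                if (statistics.mean(recent) < tol_rel_obj
                        or statistics.median(recent) < tol_rel_obj):
                    return mu, math.exp(omega), k, True
            elbo_prev = elbo
    return mu, math.exp(omega), max_iter, False


# Example target: posterior of a normal mean (known sd 1, flat prior), which is
# exactly Normal(ybar, 1/sqrt(n)), so the fitted values can be checked.
data_rng = random.Random(1)
y = [data_rng.gauss(1.0, 1.0) for _ in range(20)]
n, ybar = len(y), sum(y) / len(y)


def log_p(theta):
    return -0.5 * sum((yi - theta) ** 2 for yi in y) - 0.5 * n * math.log(2 * math.pi)


def grad_log_p(theta):
    return sum(yi - theta for yi in y)


mu, sigma, iters, converged = advi_sketch(log_p, grad_log_p)
print(f"mu={mu:.3f} vs {ybar:.3f}, sigma={sigma:.3f} vs {1 / math.sqrt(n):.3f}, "
      f"iters={iters}, converged={converged}")
```

Because the ELBO is only ever estimated from random draws, the stopping rule compares noisy numbers. That noise is one reason a run can stop early or fail to reach `tol_rel_obj`.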

## Example: variational inference for model `bernoulli.stan`

In CmdStanPy, the `CmdStanModel` method `variational` runs CmdStan with `method=variational`. It returns a `CmdStanVB` object that holds an estimate of the approximate posterior mean of all model parameters, as well as a set of draws from this approximate posterior.

```python
import os
from cmdstanpy import CmdStanModel, cmdstan_path

bernoulli_dir = os.path.join(cmdstan_path(), 'examples', 'bernoulli')
stan_file = os.path.join(bernoulli_dir, 'bernoulli.stan')
data_file = os.path.join(bernoulli_dir, 'bernoulli.data.json')
# instantiate and compile the bernoulli model
model = CmdStanModel(stan_file=stan_file)
# run CmdStan's variational inference method; returns a CmdStanVB object
vi = model.variational(data=data_file)
```

```
20:39:03 - cmdstanpy - INFO - Chain [1] start processing
20:39:03 - cmdstanpy - INFO - Chain [1] done processing
```


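The class [`CmdStanVB`](https://cmdstanpy.readthedocs.io/en/latest/api.html#stanvariational) provides the following properties for accessing the parameter names, the estimated means, and the sample:

- `column_names`: the names of all output columns
- `variational_params_dict`: the estimated means, keyed by column name
- `variational_params_np`: the estimated means as a NumPy array
- `variational_params_pd`: the estimated means as a pandas DataFrame
- `variational_sample`: the draws from the approximate posterior, as a NumPy array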

```python
print(vi.column_names)
```

```
('lp__', 'log_p__', 'log_g__', 'theta')
```

```python
print(vi.variational_params_dict['theta'])
```

```
0.245642
```
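As a rough sanity check, `bernoulli.stan` has a conjugate posterior. With the `beta(1, 1)` prior on `theta`, and assuming the example data file contains N = 10 observations with 2 successes, the exact posterior is Beta(3, 9). The short sketch below computes its mean and standard deviation for comparison with the ADVI estimate printed above. The data values are copied by assumption from CmdStan's `bernoulli.data.json`.

```python
import math

# Assumed contents of CmdStan's examples/bernoulli/bernoulli.data.json
y = [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]

# bernoulli.stan: theta ~ beta(1, 1); y ~ bernoulli(theta)  => conjugate Beta update
a_post = 1 + sum(y)              # prior alpha + number of successes
b_post = 1 + len(y) - sum(y)     # prior beta + number of failures
post_mean = a_post / (a_post + b_post)
post_sd = math.sqrt(a_post * b_post / ((a_post + b_post) ** 2 * (a_post + b_post + 1)))

advi_mean = 0.245642             # ADVI estimate printed above
print(f"exact mean {post_mean:.4f}, sd {post_sd:.4f}; ADVI mean {advi_mean}")
# exact mean 0.2500, sd 0.1201; ADVI mean 0.245642
```

Here the ADVI value lies within about 0.005 of the exact posterior mean. Most models have no closed-form posterior to compare against, so this kind of check is only available for simple cases like this one.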

```python
print(vi.variational_sample.shape)
```

```
(1000, 4)
```
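Each row of `variational_sample` is one draw, and its columns follow the order of `column_names`. A single parameter's draws can therefore be pulled out by name. The following is a minimal sketch using a hypothetical helper, `column`, and a few made-up rows. With a real fit, you would pass `vi.column_names` and `vi.variational_sample` instead.

```python
import statistics


def column(column_names, sample, name):
    """Hypothetical helper: draws for one named column of a (draws x columns) sample."""
    j = list(column_names).index(name)
    return [row[j] for row in sample]


# Made-up rows purely for illustration (not output from a real run).
names = ('lp__', 'log_p__', 'log_g__', 'theta')
rows = [[0.0, -7.0, -1.0, 0.2],
        [0.0, -7.5, -0.5, 0.3],
        [0.0, -6.5, -1.5, 0.25]]

theta_draws = column(names, rows, 'theta')
print(theta_draws, statistics.fmean(theta_draws))
```

With NumPy, `vi.variational_sample[:, vi.column_names.index('theta')]` selects the same column directly. Its mean can then be compared with `vi.variational_params_dict['theta']`.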


These estimates are only trustworthy if the algorithm has converged to a good approximation. By default, if CmdStan reports that the algorithm may not have converged, the `variational` method raises a `RuntimeError`.

```python
model_fail = CmdStanModel(stan_file='eta_should_fail.stan')
vi_fail = model_fail.variational()
```

```
20:39:03 - cmdstanpy - INFO - compiling stan file /home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.4/docsrc/examples/eta_should_fail.stan to exe file /home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.4/docsrc/examples/eta_should_fail
20:39:26 - cmdstanpy - INFO - compiled model executable: /home/docs/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.4/docsrc/examples/eta_should_fail
20:39:26 - cmdstanpy - INFO - Chain [1] start processing
20:39:26 - cmdstanpy - INFO - Chain [1] done processing
```

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [5], in <cell line: 2>()
      1 model_fail = CmdStanModel(stan_file='eta_should_fail.stan')
----> 2 vi_fail = model_fail.variational()

File ~/checkouts/readthedocs.org/user_builds/cmdstanpy/checkouts/v1.0.4/cmdstanpy/model.py:1501, in CmdStanModel.variational(self, data, seed, inits, output_dir, sig_figs, save_latent_dynamics, save_profile, algorithm, iter, grad_samples, elbo_samples, eta, adapt_engaged, adapt_iter, tol_rel_obj, eval_elbo, output_samples, require_converged, show_console, refresh, time_fmt)
   1499 if len(re.findall(pat, contents)) > 0:
   1500     if require_converged:
-> 1501         raise RuntimeError(
   1502             'The algorithm may not have converged.\n'
   1503             'If you would like to inspect the output, '
   1504             're-call with require_converged=False'
   1505         )
   1506     # else:
   1507     get_logger().warning(
   1508         '%s\n%s',
   1509         'The algorithm may not have converged.',
   1510         'Proceeding because require_converged is set to False',
   1511     )

RuntimeError: The algorithm may not have converged.
If you would like to inspect the output, re-call with require_converged=False
```


To proceed anyway, set `require_converged=False`. CmdStanPy then logs a warning instead of raising an error:

```python
vi_fail = model_fail.variational(require_converged=False)
```

```
20:39:27 - cmdstanpy - INFO - Chain [1] start processing
20:39:27 - cmdstanpy - INFO - Chain [1] done processing
20:39:27 - cmdstanpy - WARNING - The algorithm may not have converged.
Proceeding because require_converged is set to False
```


This lets you inspect the output and diagnose the problem with the model.

```python
vi_fail.variational_params_dict
```

```
OrderedDict([('lp__', 0.0),
             ('log_p__', 0.0),
             ('log_g__', 0.0),
             ('mu[1]', 0.0402514),
             ('mu[2]', 0.0152227)])
```
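The strict call and the `require_converged=False` fallback can be wrapped in one helper. The following is a minimal sketch. The name `variational_with_fallback` is hypothetical, and `_StubModel` is a stand-in for `CmdStanModel` so that the sketch runs without CmdStan. The helper only relies on the `require_converged` argument and on the error message shown in the traceback above.

```python
def variational_with_fallback(model, **kwargs):
    """Hypothetical helper: run model.variational(**kwargs); if CmdStanPy reports
    that the algorithm may not have converged, re-run with require_converged=False.

    Returns (fit, strict_run_ok). The fallback is a new ADVI run, so pass a
    fixed seed if the re-run should reproduce the same output.
    """
    try:
        return model.variational(**kwargs), True
    except RuntimeError as err:
        # RuntimeError can signal other failures too; only handle this one.
        if 'may not have converged' not in str(err):
            raise
        return model.variational(**{**kwargs, 'require_converged': False}), False


class _StubModel:
    """Stand-in for CmdStanModel, for illustration only: it 'fails to converge'
    unless require_converged=False is passed."""

    def __init__(self):
        self.calls = []

    def variational(self, require_converged=True, **kwargs):
        self.calls.append(dict(kwargs, require_converged=require_converged))
        if require_converged:
            raise RuntimeError('The algorithm may not have converged.')
        return 'fit-object'


stub = _StubModel()
fit, strict_ok = variational_with_fallback(stub, seed=123)
print(fit, strict_ok, stub.calls)
```

With a real model, the call would look like `variational_with_fallback(model_fail, seed=42)`. The returned flag records whether the strict run succeeded, so the result can be treated with appropriate caution.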


See the API documentation for [`CmdStanModel.variational`](https://cmdstanpy.readthedocs.io/en/latest/api.html#cmdstanpy.CmdStanModel.variational) for a full description of all arguments.