Variational Bayes for Non-Conjugate Models: A Laplace-Delta Tutorial
20 Mar 2025
Introduction
This tutorial presents a comprehensive guide to the Variational Bayes with Laplace-Delta (VB-LD) method, a framework for approximate Bayesian inference in models containing non-conjugate parameter relationships. VB-LD addresses the challenge of efficiently approximating posterior distributions when conditional conjugacy is absent without resorting to computationally intensive or unstable alternatives, such as Importance Sampling.
Methodology Overview
The VB-LD approach combines two techniques:
- Laplace Approximation for Non-Conjugate Updates:
  - Provides a Gaussian approximation to an intractable variational distribution by matching the mode and curvature of its log-density
  - Particularly effective when the true posterior is unimodal and approximately Gaussian near its peak
  - Avoids sampling-based approaches that may suffer from high variance or particle degeneracy
- Delta Method for Expectation Approximation:
  - Enables closed-form approximation of the necessary moments through a second-order Taylor expansion
  - Crucial for maintaining the mean-field factorization when updating conjugate parameters
1. Preliminaries
1.1 Mean-Field Variational Bayes (MFVB)
Consider the observed data \( D \in \mathcal{D} \) and a parameter vector \( \boldsymbol{\theta} = (\boldsymbol{\eta}, \boldsymbol{\psi}) \in \Theta \) where:
- \( \boldsymbol{\eta} \): non-conjugate parameters
- \( \boldsymbol{\psi} \): conjugate parameters
MFVB posits a variational family \( \mathcal{Q} \) of distributions that factorize as:
\[\mathcal{Q} = \{ q : q(\boldsymbol{\theta}) = q_{\boldsymbol{\eta}}(\boldsymbol{\eta})q_{\boldsymbol{\psi}}(\boldsymbol{\psi}), q_{\boldsymbol{\eta}} \in \mathcal{Q}_{\boldsymbol{\eta}}, q_{\boldsymbol{\psi}} \in \mathcal{Q}_{\boldsymbol{\psi}} \}\]
The optimal approximation \( q^* \in \mathcal{Q} \) minimizes the Kullback-Leibler divergence from the posterior \( p(\boldsymbol{\theta}|D) \):
\[q^* = \arg\min_{q \in \mathcal{Q}} \text{KL}(q\|p) = \arg\min_{q \in \mathcal{Q}} \int q(\boldsymbol{\theta}) \log \frac{q(\boldsymbol{\theta})}{p(\boldsymbol{\theta}|D)} d\boldsymbol{\theta}\]
Because \( \log p(D) = \text{ELBO}(q) + \text{KL}(q\|p) \) and \( \log p(D) \) does not depend on \( q \), this is equivalent to maximizing the Evidence Lower Bound (ELBO):
\[ \text{ELBO}(q) = \mathbb{E}_q[\log p(D, \boldsymbol{\theta})] - \mathbb{E}_q[\log q(\boldsymbol{\theta})] \]
where:
\[ \mathbb{E}_q[\log p(D, \boldsymbol{\theta})] := \text{Expected complete-data log-likelihood} \]
\[ -\mathbb{E}_q[\log q(\boldsymbol{\theta})] := \text{Entropy of } q \]
Optimality Conditions
The stationary points of the ELBO satisfy:
$$ \log q^*(\boldsymbol{\eta}) = \mathbb{E}_{q^*(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})] + \text{const.} $$
$$ \log q^*(\boldsymbol{\psi}) = \mathbb{E}_{q^*(\boldsymbol{\eta})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})] + \text{const.} $$
1.2 Coordinate Ascent Variational Inference (CAVI)
The CAVI algorithm iteratively optimizes each variational factor while holding others fixed:
Algorithm 1: CAVI for MFVB
- Initialize $ q^{(0)}(\boldsymbol{\eta}) $, $ q^{(0)}(\boldsymbol{\psi}) $ by setting their hyperparameters
- While stopping criteria not met:
  - Update non-conjugate parameters: $$ q^{(k)}(\boldsymbol{\eta}) \propto \exp\left(\mathbb{E}_{q^{(k-1)}(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})]\right) $$
  - Update conjugate parameters: $$ q^{(k)}(\boldsymbol{\psi}) \propto \exp\left(\mathbb{E}_{q^{(k)}(\boldsymbol{\eta})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})]\right) $$
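To make Algorithm 1 concrete, here is a minimal sketch of CAVI on a fully conjugate toy model, where both updates are available in closed form. The model is an assumption chosen for illustration and is not part of the tutorial: \( x_i \sim \mathcal{N}(\mu, \tau^{-1}) \) with priors \( \mu \mid \tau \sim \mathcal{N}(\mu_0, (\lambda_0 \tau)^{-1}) \) and \( \tau \sim \text{Gamma}(a_0, b_0) \), so that \( q(\mu) \) is Gaussian and \( q(\tau) \) is Gamma. The function name `cavi_normal_gamma` and its arguments are hypothetical.

```python
def cavi_normal_gamma(x, mu0=0.0, lam0=1.0, a0=1.0, b0=1.0,
                      tol=1e-10, max_iter=200):
    """CAVI for x_i ~ N(mu, 1/tau) with a Normal-Gamma prior (toy example).

    Returns the parameters of q(mu) = N(mu_n, 1/lam_n) and
    q(tau) = Gamma(a_n, b_n) (shape/rate parameterization).
    """
    n = len(x)
    xbar = sum(x) / n

    # In this model, these two quantities do not depend on the other factor.
    mu_n = (lam0 * mu0 + n * xbar) / (lam0 + n)
    a_n = a0 + 0.5 * (n + 1)

    e_tau = a0 / b0  # initialize E[tau] at the prior mean
    b_n = b0
    for _ in range(max_iter):
        # Update q(mu): its precision depends on the current E[tau].
        lam_n = (lam0 + n) * e_tau

        # Update q(tau): needs E[(c - mu)^2] = (c - mu_n)^2 + 1/lam_n under q(mu).
        e_data = sum((xi - mu_n) ** 2 + 1.0 / lam_n for xi in x)
        e_prior = lam0 * ((mu_n - mu0) ** 2 + 1.0 / lam_n)
        b_n = b0 + 0.5 * (e_data + e_prior)

        new_e_tau = a_n / b_n
        converged = abs(new_e_tau - e_tau) < tol
        e_tau = new_e_tau
        if converged:
            break

    lam_n = (lam0 + n) * e_tau
    return mu_n, lam_n, a_n, b_n


x = [1.2, 0.8, 1.9, 2.4, 1.1, 1.6, 2.0, 0.9]
mu_n, lam_n, a_n, b_n = cavi_normal_gamma(x)
print("q(mu) mean:", mu_n, " E[tau]:", a_n / b_n)
```

In this toy model only the coupling through \( \mathbb{E}[\tau] \) requires iteration; the non-conjugate setting discussed next is harder precisely because no such closed-form updates exist for \( \boldsymbol{\eta} \).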
1.3 Non-Conjugate Challenges
When \( \boldsymbol{\eta} \) is non-conjugate, the update:
\[ q(\boldsymbol{\eta}) \propto \exp\left(\mathbb{E}_{q(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})]\right) \equiv \exp(f(\boldsymbol{\eta})) \]
faces three fundamental difficulties:
- The normalizing constant \( Z = \int \exp(f(\boldsymbol{\eta})) d\boldsymbol{\eta} \) is typically intractable
- Required expectations \( \mathbb{E}_{q(\boldsymbol{\eta})}[g(\boldsymbol{\eta})] \) lack closed forms
- The ELBO cannot be evaluated exactly
Some Solutions
- Importance sampling: suffers from high variance in high dimensions and is prone to particle collapse
- Quadrature methods: cost grows exponentially in $ d = \dim(\boldsymbol{\eta}) $
- Conjugate modifications: impose restrictive modeling assumptions
2. VB-LD Algorithm
2.1 Laplace Approximation Step
Given the variational distribution:
\[ q(\boldsymbol{\eta}) \propto \exp\left(\mathbb{E}_{q(\boldsymbol{\psi})}[\log p(D,\boldsymbol{\eta},\boldsymbol{\psi})]\right) \equiv \exp(f(\boldsymbol{\eta})) \]
- Mode Finding:
  - Compute the gradient \( \nabla f(\boldsymbol{\eta}) \)
  - Solve \( \nabla f(\hat{\boldsymbol{\eta}}) = 0 \) via Newton-Raphson, where \( t \) indexes the Newton iterations: \[ \boldsymbol{\eta}^{(t+1)} = \boldsymbol{\eta}^{(t)} - [\nabla^2 f(\boldsymbol{\eta}^{(t)})]^{-1}\nabla f(\boldsymbol{\eta}^{(t)}) \]
  - Terminate when \( \|\boldsymbol{\eta}^{(t+1)} - \boldsymbol{\eta}^{(t)}\| < \epsilon \)
- Covariance Approximation: \[ \Sigma = -\left[\nabla^2 f(\hat{\boldsymbol{\eta}})\right]^{-1} \] where the Hessian has entries \[ [\nabla^2 f(\boldsymbol{\eta})]_{ij} = \frac{\partial^2 f}{\partial \eta_i \partial \eta_j} \] and must be negative definite at \( \hat{\boldsymbol{\eta}} \) for \( \Sigma \) to be a valid covariance matrix.
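The mode-finding and covariance steps above can be sketched in the scalar case (one-dimensional \( \eta \)), where inverting the Hessian reduces to a division. The target used here is an assumed example, not taken from the tutorial: a Poisson log-rate with a Gaussian prior, \( f(\eta) = \eta \sum_i y_i - n e^{\eta} - \eta^2 / (2 s^2) + \text{const} \). The helper name `laplace_1d` is hypothetical.

```python
import math


def laplace_1d(grad, hess, eta0, tol=1e-10, max_iter=100):
    """Newton-Raphson mode search, then variance from the curvature (scalar eta)."""
    eta = eta0
    for _ in range(max_iter):
        eta_new = eta - grad(eta) / hess(eta)  # Newton update
        done = abs(eta_new - eta) < tol
        eta = eta_new
        if done:
            break
    h = hess(eta)
    if h >= 0:
        # The Laplace variance -1/f''(eta_hat) is only valid at a local maximum.
        raise ValueError("f is not locally concave at the stationary point")
    return eta, -1.0 / h


# Assumed toy target: Poisson counts y with log-rate eta and a N(0, prior_var) prior.
y = [3, 5, 4, 6, 2]
n, s, prior_var = len(y), sum(y), 10.0


def grad(eta):
    return s - n * math.exp(eta) - eta / prior_var


def hess(eta):
    return -n * math.exp(eta) - 1.0 / prior_var


eta_hat, var = laplace_1d(grad, hess, eta0=0.0)
print("mode:", eta_hat, " variance:", var)
```

For vector-valued \( \boldsymbol{\eta} \) the same structure applies, with the division replaced by a linear solve against the Hessian matrix.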
Laplace approximation
$$ q(\boldsymbol{\eta}) \approx \mathcal{N}(\hat{\boldsymbol{\eta}}, \Sigma) $$
2.2 Delta Method for Expectation Approximation
When updating \( q(\boldsymbol{\psi}) \), we are required to compute expectations of the form:
\[ \mathbb{E}_{q(\boldsymbol{\eta})}[g(\boldsymbol{\eta})] \]
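To do so, we expand \( g \) to second order around the mean \( \hat{\boldsymbol{\eta}} \) of the Gaussian \( q(\boldsymbol{\eta}) \). The first-order term vanishes in expectation, which leaves: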
Delta method (second-order Taylor approximation)
$$ \mathbb{E}[g(\boldsymbol{\eta})] \approx g(\hat{\boldsymbol{\eta}}) + \frac{1}{2}\text{tr}\left[\nabla^2 g(\hat{\boldsymbol{\eta}})\Sigma\right] $$
2.3 Complete VB-LD Algorithm
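Before assembling the full algorithm, the delta-method approximation from Section 2.2 can be checked numerically in the scalar case, where the trace term reduces to \( \tfrac{1}{2} g''(\hat{\eta}) \sigma^2 \). The sketch below uses \( g(\eta) = e^{\eta} \) as an assumed test function because its exact Gaussian expectation, \( e^{m + \sigma^2/2} \), is known. The function name `delta_expectation` is hypothetical.

```python
import math


def delta_expectation(g, d2g, mean, var):
    """Second-order delta approximation of E[g(eta)] for scalar eta ~ N(mean, var)."""
    return g(mean) + 0.5 * d2g(mean) * var


mean, var = 0.5, 0.04

# g(eta) = exp(eta), whose second derivative is also exp(eta).
approx = delta_expectation(math.exp, math.exp, mean, var)
exact = math.exp(mean + 0.5 * var)  # mean of a lognormal variable
plug_in = math.exp(mean)            # naive estimate that ignores the variance

print("delta:", approx, " exact:", exact, " plug-in:", plug_in)
```

For quadratic \( g \) the approximation is exact, since all derivatives beyond the second vanish; for other functions its accuracy degrades as the variance of \( q(\boldsymbol{\eta}) \) grows.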
Algorithm 2: VB-LD Implementation
- Initialize $ q^{(0)}(\boldsymbol{\psi}) $
- While stopping criteria not met:
  - Laplace Step:
    - Compute $ \hat{\boldsymbol{\eta}}^{(k)} = \arg\max_{\boldsymbol{\eta}} f^{(k)}(\boldsymbol{\eta}) $, where $ f^{(k)}(\boldsymbol{\eta}) = \mathbb{E}_{q^{(k-1)}(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})] $
    - Evaluate $ \Sigma^{(k)} = -[\nabla^2 f^{(k)}(\hat{\boldsymbol{\eta}}^{(k)})]^{-1} $
  - Non-Conjugate Update: $ q^{(k)}(\boldsymbol{\eta}) = \mathcal{N}(\hat{\boldsymbol{\eta}}^{(k)}, \Sigma^{(k)}) $
  - Delta Step:
    - For all required $ g $, compute: $\mathbb{E}[g(\boldsymbol{\eta})] \approx g(\hat{\boldsymbol{\eta}}^{(k)}) + \frac{1}{2} \text{tr}[ \nabla^2 g(\hat{\boldsymbol{\eta}}^{(k)}) \Sigma^{(k)} ] $
  - Conjugate Update: Use these approximate expectations to update $ q^{(k)}(\boldsymbol{\psi}) $
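The loop in Algorithm 2 can be illustrated end to end on a small model. The following is a sketch under stated assumptions, not the reference implementation: the model \( y_i \sim \text{Poisson}(e^{\eta}) \), \( \eta \mid \tau \sim \mathcal{N}(0, \tau^{-1}) \), \( \tau \sim \text{Gamma}(a_0, b_0) \) is chosen for illustration, with scalar \( \eta \) as the non-conjugate block and \( \tau \) as the conjugate block. Here \( f(\eta) = \eta \sum_i y_i - n e^{\eta} - \tfrac{1}{2}\mathbb{E}[\tau]\eta^2 + \text{const} \), and the only expectation the conjugate update needs is \( \mathbb{E}[\eta^2] \). The function name `vbld_poisson` is hypothetical.

```python
import math


def vbld_poisson(y, a0=2.0, b0=2.0, tol=1e-10, max_iter=100):
    """VB-LD sketch for y_i ~ Poisson(exp(eta)), eta ~ N(0, 1/tau), tau ~ Gamma(a0, b0)."""
    n, s = len(y), sum(y)
    e_tau = a0 / b0                                 # E[tau] under the initial q(tau)
    eta_hat = math.log(s / n) if s > 0 else 0.0     # starting point for Newton
    a_n = a0 + 0.5                                  # Gamma shape is fixed in this model
    var, b_n = 1.0, b0

    for _ in range(max_iter):
        # Laplace step: Newton-Raphson on f(eta), which is concave here.
        for _ in range(100):
            g1 = s - n * math.exp(eta_hat) - e_tau * eta_hat
            g2 = -n * math.exp(eta_hat) - e_tau
            step = g1 / g2
            eta_hat -= step
            if abs(step) < 1e-12:
                break
        var = 1.0 / (n * math.exp(eta_hat) + e_tau)  # Sigma = -1 / f''(eta_hat)

        # Delta step: g(eta) = eta^2 has g'' = 2, so E[eta^2] ~ eta_hat^2 + var.
        e_eta_sq = eta_hat ** 2 + 0.5 * 2.0 * var

        # Conjugate update: q(tau) = Gamma(a_n, b_n).
        b_n = b0 + 0.5 * e_eta_sq
        new_e_tau = a_n / b_n
        converged = abs(new_e_tau - e_tau) < tol
        e_tau = new_e_tau
        if converged:
            break

    return eta_hat, var, a_n, b_n


y = [3, 5, 4, 6, 2]
eta_hat, var, a_n, b_n = vbld_poisson(y)
print("q(eta) = N(%.4f, %.4f),  E[tau] = %.4f" % (eta_hat, var, a_n / b_n))
```

Because \( g(\eta) = \eta^2 \) is quadratic, the delta step is exact in this example; models whose conjugate updates require other functions of \( \boldsymbol{\eta} \) incur a genuine approximation at that step.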
Further Reading
- Original VB-LD paper: Wang & Blei (2013)