Antonio Aguirre · Data Modelling · Bayesian Statistics · Machine Learning

Variational Bayes for Non-Conjugate Models: A Laplace-Delta Tutorial

Introduction

This tutorial presents a comprehensive guide to the Variational Bayes with Laplace-Delta (VB-LD) method, a framework for approximate Bayesian inference in models containing non-conjugate parameter relationships. VB-LD addresses the challenge of efficiently approximating posterior distributions when conditional conjugacy is absent, without resorting to computationally intensive or unstable alternatives such as Importance Sampling.

Methodology Overview

The VB-LD approach combines two techniques:

  • Laplace Approximation for Non-Conjugate Updates:
    • Provides a Gaussian approximation to intractable variational distributions by matching the mode and curvature of the log-posterior
    • Particularly effective when the true posterior is unimodal and approximately Gaussian near its peak
    • Avoids sampling-based approaches that may suffer from high variance or particle degeneracy
  • Delta Method for Expectation Propagation:
    • Enables closed-form approximation of necessary moments through second-order Taylor expansion
    • Crucial for maintaining the mean-field factorization when updating conjugate parameters

1. Preliminaries

1.1 Mean-Field Variational Bayes (MFVB)

Consider the observed data \( D \in \mathcal{D} \) and a parameter vector \( \boldsymbol{\theta} = (\boldsymbol{\eta}, \boldsymbol{\psi}) \in \Theta \) where:

  • \( \boldsymbol{\eta} \): non-conjugate parameters
  • \( \boldsymbol{\psi} \): conjugate parameters

MFVB posits a structured variational family \( \mathcal{Q} \) of distributions that factorize as:

\[\mathcal{Q} = \{ q : q(\boldsymbol{\theta}) = q_{\boldsymbol{\eta}}(\boldsymbol{\eta})q_{\boldsymbol{\psi}}(\boldsymbol{\psi}), q_{\boldsymbol{\eta}} \in \mathcal{Q}_{\boldsymbol{\eta}}, q_{\boldsymbol{\psi}} \in \mathcal{Q}_{\boldsymbol{\psi}} \}\]

The optimal approximation \( q^* \in \mathcal{Q} \) minimizes the Kullback-Leibler divergence:

\[q^* = \arg\min_{q \in \mathcal{Q}} \text{KL}(q\|p) = \arg\min_{q \in \mathcal{Q}} \int q(\boldsymbol{\theta}) \log \frac{q(\boldsymbol{\theta})}{p(\boldsymbol{\theta}|D)} \, d\boldsymbol{\theta}\]

Equivalently, we maximize the Evidence Lower Bound (ELBO):

\[ \text{ELBO}(q) = \mathbb{E}_q[\log p(D, \boldsymbol{\theta})] - \mathbb{E}_q[\log q(\boldsymbol{\theta})] \]

where:

\[ \mathbb{E}_q[\log p(D, \boldsymbol{\theta})] := \text{Expected complete-data log-likelihood} \]

\[ \mathbb{E}_q[\log q(\boldsymbol{\theta})] := \text{Entropy of } q \]
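The ELBO can be checked numerically for any model whose joint density we can evaluate: sample from \( q \), average the log-joint, and add the entropy. A minimal sketch on a toy conjugate model (Gaussian likelihood with unit variance, standard Gaussian prior on the mean; the model and values are illustrative, not from the original):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: x_i ~ N(theta, 1), theta ~ N(0, 1); variational family q(theta) = N(m, s^2)
x = rng.normal(2.0, 1.0, size=50)

def log_joint(theta):
    # log p(D, theta) = sum_i log N(x_i | theta, 1) + log N(theta | 0, 1)
    ll = -0.5 * np.sum((x - theta) ** 2) - 0.5 * len(x) * np.log(2 * np.pi)
    lp = -0.5 * theta ** 2 - 0.5 * np.log(2 * np.pi)
    return ll + lp

def elbo_mc(m, s, n_samples=20000):
    # Monte Carlo estimate of E_q[log p(D, theta)], plus the Gaussian entropy in closed form
    thetas = rng.normal(m, s, size=n_samples)
    entropy = 0.5 * np.log(2 * np.pi * np.e * s ** 2)
    return np.mean([log_joint(t) for t in thetas]) + entropy

# The exact posterior here is N(sum(x)/(n+1), 1/(n+1)); the ELBO is maximized
# (and equals the log evidence) at exactly those variational parameters.
n = len(x)
m_star, s_star = x.sum() / (n + 1), np.sqrt(1.0 / (n + 1))
print(elbo_mc(m_star, s_star), elbo_mc(0.0, 1.0))
```

Because \( \text{KL}(q\|p) \ge 0 \), the estimate at the exact posterior parameters dominates the estimate at any other setting, up to Monte Carlo noise.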

Optimality Conditions

The stationary points of the ELBO satisfy:

$$ \log q^*(\boldsymbol{\eta}) = \mathbb{E}_{q^*(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})] + \text{const.} $$ $$ \log q^*(\boldsymbol{\psi}) = \mathbb{E}_{q^*(\boldsymbol{\eta})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})] + \text{const.} $$

1.2 Coordinate Ascent Variational Inference (CAVI)

The CAVI algorithm iteratively optimizes each variational factor while holding others fixed:

Algorithm 1: CAVI for MFVB

  1. Initialize $ q^{(0)}(\boldsymbol{\eta}) $, $ q^{(0)}(\boldsymbol{\psi}) $ by setting their hyperparameters
  2. While stopping criteria not met:
    1. Update non-conjugate parameters: $$ q^{(k)}(\boldsymbol{\eta}) \propto \exp\left(\mathbb{E}_{q^{(k-1)}(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})]\right) $$
    2. Update conjugate parameters: $$ q^{(k)}(\boldsymbol{\psi}) \propto \exp\left(\mathbb{E}_{q^{(k)}(\boldsymbol{\eta})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})]\right) $$
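To make the loop concrete, here is CAVI for a fully conjugate toy model where both updates have closed forms: the classic Normal mean-precision model \( x_i \sim \mathcal{N}(\mu, \tau^{-1}) \) with \( \mu \mid \tau \sim \mathcal{N}(\mu_0, (\lambda_0\tau)^{-1}) \) and \( \tau \sim \text{Gamma}(a_0, b_0) \). This is a standard illustration, not the non-conjugate setting VB-LD targets; the priors and data below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(3.0, 1.0, size=100)           # data: x_i ~ N(mu, tau^{-1})
n, xbar = len(x), x.mean()

# Priors: mu | tau ~ N(mu0, (lam0 * tau)^{-1}),  tau ~ Gamma(a0, b0)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

# Variational factors: q(mu) = N(mu_n, 1/lam_n), q(tau) = Gamma(a_n, b_n)
mu_n = (lam0 * mu0 + n * xbar) / (lam0 + n)  # fixed point; never changes
a_n = a0 + (n + 1) / 2.0                     # also fixed after one pass
lam_n, b_n = 1.0, 1.0
for _ in range(50):
    e_tau = a_n / b_n                        # E_q[tau] under current q(tau)
    lam_n = (lam0 + n) * e_tau               # update q(mu)
    # Update q(tau) using E_q[(x_i - mu)^2] = (x_i - mu_n)^2 + 1/lam_n
    b_n = b0 + 0.5 * (lam0 * ((mu_n - mu0) ** 2 + 1 / lam_n)
                      + np.sum((x - mu_n) ** 2) + n / lam_n)
print(mu_n, a_n / b_n)                       # approx. posterior means of mu and tau
```

Each factor update conditions on the expected sufficient statistics of the other factor, exactly as in steps 2.1 and 2.2 of Algorithm 1.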

1.3 Non-Conjugate Challenges

When \( \boldsymbol{\eta} \) is non-conjugate, the update:

\[ q(\boldsymbol{\eta}) \propto \exp\left(\mathbb{E}_{q(\boldsymbol{\psi})}[\log p(D, \boldsymbol{\eta}, \boldsymbol{\psi})]\right) \equiv \exp(f(\boldsymbol{\eta})) \]

faces three fundamental difficulties:

  1. The normalizing constant \( Z = \int \exp(f(\boldsymbol{\eta})) d\boldsymbol{\eta} \) is typically intractable
  2. Required expectations \( \mathbb{E}_{q(\boldsymbol{\eta})}[g(\boldsymbol{\eta})] \) lack closed forms
  3. The ELBO cannot be evaluated exactly

Some Solutions

  • Importance Sampling: suffers from high variance in high dimensions and is prone to particle collapse
  • Quadrature Methods: cost grows exponentially in $ k = \dim(\boldsymbol{\eta}) $
  • Conjugate Modifications: impose restrictive modeling assumptions

2. VB-LD Algorithm

2.1 Laplace Approximation Step

Given the variational distribution:

\[ q(\boldsymbol{\eta}) \propto \exp\left(\mathbb{E}_{q(\boldsymbol{\psi})}[\log p(D,\boldsymbol{\eta},\boldsymbol{\psi})]\right) \equiv \exp(f(\boldsymbol{\eta})) \]

  1. Mode Finding:
    • Compute gradient \( \nabla f(\boldsymbol{\eta}) \)
    • Solve \( \nabla f(\hat{\boldsymbol{\eta}}) = 0 \) via Newton-Raphson: \[ \boldsymbol{\eta}^{(k+1)} = \boldsymbol{\eta}^{(k)} - [\nabla^2 f(\boldsymbol{\eta}^{(k)})]^{-1}\nabla f(\boldsymbol{\eta}^{(k)}) \]
    • Terminate when \( \|\boldsymbol{\eta}^{(k+1)} - \boldsymbol{\eta}^{(k)}\| < \epsilon \)
  2. Covariance Approximation: \[ \Sigma = -\left[\nabla^2 f(\hat{\boldsymbol{\eta}})\right]^{-1} \] where the Hessian is computed via: \[ [\nabla^2 f(\boldsymbol{\eta})]_{ij} = \frac{\partial^2 f}{\partial \eta_i \partial \eta_j} \]

Laplace approximation

$$ q(\boldsymbol{\eta}) \approx \mathcal{N}(\hat{\boldsymbol{\eta}}, \Sigma) $$
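A minimal sketch of this step in one dimension, for a model that will recur below: Poisson counts with rate \( e^\eta \) and a Gaussian prior on \( \eta \) (the model and values are illustrative, not from the original). Newton-Raphson finds the mode of \( f \), and the negative inverse Hessian gives the Laplace variance:

```python
import numpy as np

rng = np.random.default_rng(2)
y = rng.poisson(lam=5.0, size=40)            # counts; true rate exp(eta) = 5

# f(eta) = sum_i (y_i * eta - exp(eta)) - eta^2 / (2 * s0^2), prior N(0, s0^2)
s0 = 10.0
grad = lambda eta: y.sum() - len(y) * np.exp(eta) - eta / s0 ** 2
hess = lambda eta: -len(y) * np.exp(eta) - 1 / s0 ** 2

# Newton-Raphson mode finding
eta = 0.0
for _ in range(100):
    step = grad(eta) / hess(eta)
    eta -= step
    if abs(step) < 1e-10:                    # termination criterion
        break

sigma2 = -1.0 / hess(eta)                    # Laplace covariance (scalar here)
print(eta, sigma2)                           # q(eta) ~= N(eta_hat, sigma2)
```

With this nearly flat prior, the mode lands very close to \( \log \bar{y} \), the maximum-likelihood value, as expected.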

2.2 Delta Method for Expectation Approximation

When updating \( q(\boldsymbol{\psi}) \), we must compute expectations of the form:

\[ \mathbb{E}_{q(\boldsymbol{\eta})}[g(\boldsymbol{\eta})] \]

We approximate such quantities as follows:

Delta method (second-order Taylor approximation)

$$ \mathbb{E}[g(\boldsymbol{\eta})] \approx g(\hat{\boldsymbol{\eta}}) + \frac{1}{2}\text{tr}\left[\nabla^2 g(\hat{\boldsymbol{\eta}})\Sigma\right] $$
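A quick numerical check of this approximation: for \( g(\eta) = e^\eta \) and \( q(\eta) = \mathcal{N}(m, s^2) \), the exact expectation is the log-normal mean \( e^{m + s^2/2} \), so the delta estimate can be compared directly (the values below are illustrative):

```python
import numpy as np

def delta_expectation(g, hess_g, mode, cov):
    """Second-order delta approximation of E_q[g(eta)] for q = N(mode, cov)."""
    return g(mode) + 0.5 * np.trace(hess_g(mode) @ cov)

# 1-D example: g(eta) = exp(eta), q(eta) = N(m, s^2)
m, s = 1.0, 0.3
g = lambda eta: np.exp(eta[0])
hess_g = lambda eta: np.array([[np.exp(eta[0])]])

approx = delta_expectation(g, hess_g, np.array([m]), np.array([[s ** 2]]))
exact = np.exp(m + s ** 2 / 2)               # exact log-normal mean
print(approx, exact)
```

For small \( s \) the two agree closely (the neglected Taylor terms are \( O(s^4) \)); the approximation degrades as the variance of \( q(\boldsymbol{\eta}) \) grows.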

2.3 Complete VB-LD Algorithm

Algorithm 2: VB-LD Implementation

  1. Initialize $ q^{(0)}(\boldsymbol{\psi}) $
  2. While stopping criteria not met:
    1. Laplace Step:
      • Compute $ \hat{\boldsymbol{\eta}}^{(k)} = \arg\max_{\boldsymbol{\eta}} f(\boldsymbol{\eta}) $
      • Evaluate $ \Sigma^{(k)} = -[\nabla^2 f(\hat{\boldsymbol{\eta}}^{(k)})]^{-1} $
      • Non-Conjugate Update: $ q^{(k)}(\boldsymbol{\eta}) = \mathcal{N}(\hat{\boldsymbol{\eta}}^{(k)}, \Sigma^{(k)}) $
    2. Delta Step:
      • For all required $ g $, compute: $\mathbb{E}[g(\boldsymbol{\eta})] \approx g(\hat{\boldsymbol{\eta}}^{(k)}) + \frac{1}{2} \text{tr}[ \nabla^2 g(\hat{\boldsymbol{\eta}}^{(k)}) \Sigma^{(k)} ] $
      • Conjugate Update: Use the previous quantity to update $ q^{(k)}(\boldsymbol{\psi}) $
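Putting the two steps together, here is a sketch of the full VB-LD loop on a small non-conjugate model: counts \( y_i \sim \text{Poisson}(e^\eta) \) with \( \eta \sim \mathcal{N}(\mu_0, \psi^{-1}) \) and a conjugate prior \( \psi \sim \text{Gamma}(a_0, b_0) \). The model and hyperparameters are illustrative; note that for the quadratic \( g(\eta) = (\eta - \mu_0)^2 \) needed in the conjugate update, the delta step happens to be exact:

```python
import numpy as np

rng = np.random.default_rng(3)
y = rng.poisson(lam=4.0, size=60)
n, mu0, a0, b0 = len(y), 0.0, 2.0, 2.0

a_n = a0 + 0.5                               # fixed: a single eta in the model
b_n, eta_hat, sigma2 = b0, 0.0, 1.0
for _ in range(30):
    e_psi = a_n / b_n                        # E_q[psi] from current q(psi)

    # Laplace step: f(eta) = sum(y)*eta - n*exp(eta) - E[psi]*(eta - mu0)^2 / 2
    for _ in range(100):
        g = y.sum() - n * np.exp(eta_hat) - e_psi * (eta_hat - mu0)
        h = -n * np.exp(eta_hat) - e_psi
        eta_hat -= g / h                     # Newton-Raphson mode finding
        if abs(g / h) < 1e-10:
            break
    sigma2 = -1.0 / h                        # q(eta) = N(eta_hat, sigma2)

    # Delta step (exact here, since (eta - mu0)^2 is quadratic):
    # E_q[(eta - mu0)^2] = (eta_hat - mu0)^2 + sigma2
    b_n = b0 + 0.5 * ((eta_hat - mu0) ** 2 + sigma2)

print(eta_hat, sigma2, a_n / b_n)
```

Each outer iteration alternates a Laplace update of the non-conjugate factor \( q(\eta) \) with a closed-form Gamma update of the conjugate factor \( q(\psi) \), mirroring Algorithm 2; because the likelihood dominates, \( \hat{\eta} \) settles near \( \log \bar{y} \).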

Further Reading

  1. Original VB-LD paper: Wang, C. & Blei, D. M. (2013), "Variational Inference in Nonconjugate Models," Journal of Machine Learning Research, 14.