Stein’s identity and Stein Variational Gradient Descent


If you have a target distribution p(x) that you want to model, but you can’t compute it exactly (maybe because it involves an intractable partition function), then, starting with an initial set of particles \left\{x_i^0\right\}_{i=1}^n, you can iteratively update them as follows:

\displaystyle x_i^{l+1} \leftarrow x_i^l + \epsilon_l\hat{\phi}^*\left(x_i^l\right)


  • \epsilon_l is a small step size at iteration l
  • \displaystyle \hat{\phi}^*\left(x\right) = \frac{1}{n}\sum_{j=1}^n \left[k\left(x_j^l, x\right)\nabla_{x_j^l}\log p\left(x_j^l\right) + \nabla_{x_j^l}k\left(x_j^l, x\right)\right] is the steepest update direction
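  • k\left( \cdot , \cdot \right) is a kernel, typically the RBF kernel k\left(x, x'\right) = \exp\left(-\parallel x - x'\parallel^2 / h\right) with bandwidth h.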

After enough iterations, the particles \left\{x_i\right\}_{i=1}^n approximate the target distribution.

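Note that to compute the update direction, you only need the gradient of the log of the (unnormalized) target distribution, evaluated at the current samples \left\{x_j^l\right\}_{j=1}^n. The partition function drops out because, writing p(x) = \tilde{p}(x)/Z, we have \nabla_x\log p(x) = \nabla_x\log\tilde{p}(x) - \nabla_x\log Z = \nabla_x\log\tilde{p}(x).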
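The update rule takes only a few lines of NumPy. The sketch below is a minimal version, assuming an RBF kernel with a fixed bandwidth h (a common alternative is the median heuristic), a constant step size \epsilon, and a toy 2-D Gaussian target; the function names rbf_kernel, svgd_direction and svgd are hypothetical. The target enters only through score(x) = \nabla_x\log p(x).

```python
import numpy as np


def rbf_kernel(x, h):
    """K[j, i] = exp(-||x_j - x_i||^2 / h) and grad_K[j, i] = gradient of K[j, i] w.r.t. x_j.

    x has shape (n, d); K has shape (n, n); grad_K has shape (n, n, d).
    """
    diff = x[:, None, :] - x[None, :, :]           # diff[j, i] = x_j - x_i
    K = np.exp(-np.sum(diff ** 2, axis=-1) / h)
    grad_K = -2.0 / h * diff * K[:, :, None]       # d/dx_j exp(-||x_j - x_i||^2 / h)
    return K, grad_K


def svgd_direction(x, score, h):
    """Empirical direction phi_hat^*(x_i) for every particle x_i (one row per particle)."""
    n = x.shape[0]
    K, grad_K = rbf_kernel(x, h)
    driving = K.T @ score(x)          # row i: sum_j k(x_j, x_i) * grad log p(x_j)
    repulsive = grad_K.sum(axis=0)    # row i: sum_j grad_{x_j} k(x_j, x_i)
    return (driving + repulsive) / n


def svgd(x0, score, step=0.1, n_iter=1000, h=1.0):
    """Run x_i <- x_i + step * phi_hat^*(x_i) with a constant step size and bandwidth (assumed)."""
    x = np.array(x0, dtype=float)
    for _ in range(n_iter):
        x = x + step * svgd_direction(x, score, h)
    return x


# Toy target p = N(mu, I) in 2-D. Only grad log p is needed, never the normalizer.
mu = np.array([2.0, -1.0])


def score(x):
    return -(x - mu)


rng = np.random.default_rng(0)
particles = svgd(rng.normal(size=(100, 2)), score, step=0.1, n_iter=2000)
print(particles.mean(axis=0), particles.std(axis=0))
```

The helper builds all pairwise differences at once, so memory grows as O(n^2 d); that is fine for a few hundred particles.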

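This is significant because the update is analogous to gradient descent that minimizes the KL divergence between the particle distribution and p(x). The initial particles \left\{x_i^0\right\}_{i=1}^n therefore do not need to come from p itself; in practice they often come from a simple initial distribution or from another black-box model.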

An interesting intuition: the first term, a kernel-weighted average of the gradients of \log p, pushes the particles into the regions with high values of p(x), while the second term (the derivative of the kernel) has, for the RBF kernel, a “regularization” effect: it acts as a repulsive force that keeps the particles away from each other and prevents them from collapsing into the same mode. This is probably one reason the method can do better than plain MCMC.
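The two forces can be separated in code. This sketch assumes NumPy, an RBF kernel with bandwidth h = 1, a standard normal target, and a hypothetical helper svgd_terms that returns the two parts of \hat{\phi}^* separately. With a single particle the kernel gradient vanishes, so the update reduces to plain gradient ascent on \log p; with two nearby particles the second term pushes them apart.

```python
import numpy as np


def svgd_terms(x, score, h=1.0):
    """Return the driving and repulsive parts of phi_hat^* separately (RBF kernel)."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]                          # diff[j, i] = x_j - x_i
    K = np.exp(-np.sum(diff ** 2, axis=-1) / h)                   # K[j, i] = k(x_j, x_i)
    driving = K.T @ score(x) / n                                  # kernel-weighted average of gradients
    repulsive = (-2.0 / h * diff * K[:, :, None]).sum(axis=0) / n  # average of grad_{x_j} k(x_j, x_i)
    return driving, repulsive


def score(x):
    # grad log p for a standard normal target (assumed for illustration)
    return -x


# A single particle: the repulsive term vanishes, so SVGD reduces to gradient ascent on log p.
single = np.array([[1.5, -0.5]])
drive1, rep1 = svgd_terms(single, score)
print(drive1, rep1)   # drive1 equals score(single), rep1 is zero

# Two particles close together near the mode: the gradient pull is weak,
# while the kernel term pushes the first one left and the second one right.
pair = np.array([[-0.1, 0.0], [0.1, 0.0]])
drive2, rep2 = svgd_terms(pair, score)
print(drive2, rep2)
```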

The short story

Let x be a continuous random variable in \mathcal{X} \subset \mathbb{R}^d and let p\left(x\right) be the (intractable) target distribution. Let \phi\left(x\right) = \left[\phi_1\left(x\right), ..., \phi_d\left(x\right)\right]^T be a smooth vector-valued function. Under mild boundary conditions on \phi (it must belong to the so-called Stein class of p), Stein’s identity says:

\displaystyle \mathbb{E}_{x \sim p}\left[\mathcal{A}_p\phi\left(x\right)\right] = 0

where \mathcal{A}_p\phi\left(x\right) = \phi\left(x\right)\nabla_x\log p\left(x\right)^T + \nabla_x\phi\left(x\right) and \mathcal{A}_p is called the Stein operator.
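To make Stein’s identity concrete, here is a Monte Carlo sanity check. It is a sketch assuming p = N(0, I) in two dimensions (so \nabla_x\log p(x) = -x) and an arbitrarily chosen test function \phi(x) = [\sin(x_1), x_1 x_2]^T; the helper names are hypothetical. \nabla_x\phi is taken as the Jacobian with entry (i, j) = \partial\phi_i/\partial x_j, which is the convention under which the identity holds entry by entry.

```python
import numpy as np


def stein_operator(x, phi, jac_phi, score):
    """A_p phi(x) = phi(x) grad log p(x)^T + grad phi(x), evaluated for each sample.

    x: (N, d) samples. phi(x): (N, d). jac_phi(x): (N, d, d) with entry [n, i, j] = d phi_i / d x_j.
    Returns an (N, d, d) array of matrices.
    """
    return np.einsum("ni,nj->nij", phi(x), score(x)) + jac_phi(x)


def score(x):
    # grad log p for p = N(0, I); the normalizing constant of p is never needed
    return -x


def phi(x):
    # an arbitrary smooth test function phi(x) = [sin(x_1), x_1 * x_2] (chosen for illustration)
    return np.stack([np.sin(x[:, 0]), x[:, 0] * x[:, 1]], axis=1)


def jac_phi(x):
    # Jacobian of phi: [[cos x_1, 0], [x_2, x_1]]
    jac = np.zeros((x.shape[0], 2, 2))
    jac[:, 0, 0] = np.cos(x[:, 0])
    jac[:, 1, 0] = x[:, 1]
    jac[:, 1, 1] = x[:, 0]
    return jac


rng = np.random.default_rng(0)
samples = rng.normal(size=(200_000, 2))   # x ~ p
expectation = stein_operator(samples, phi, jac_phi, score).mean(axis=0)
print(expectation)   # every entry is close to 0, up to Monte Carlo error
```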

Now let q\left(x\right) be another distribution. In general, \mathbb{E}_{x \sim q}\left[\mathcal{A}_p\phi\left(x\right)\right] is no longer zero, and this can be used to define a discrepancy between the two distributions p and q:

\displaystyle \mathbb{S}\left(q, p\right) = \max_\phi \left[\mathbb{E}_{x \sim q}\text{trace}\left(\mathcal{A}_p\phi\left(x\right)\right)\right]^2

That is, we consider all possible smooth functions \phi and use the one that maximizes the violation of Stein’s identity. This maximum violation is defined to be the Stein discrepancy between q and p.
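The violation can be illustrated numerically. Because scaling \phi by a constant scales the violation, the maximum only makes sense over a bounded family of functions; the sketch below (NumPy, 1-D, hypothetical names) simply takes the max over four hand-picked test functions, a crude stand-in for the true Stein discrepancy. It assumes p = N(0, 1) and compares samples drawn from p with samples from a shifted q = N(0.5, 1).

```python
import numpy as np


def score(x):
    # d/dx log p for p = N(0, 1)
    return -x


# A small, hypothetical family of test functions: (phi, phi') pairs in 1-D.
test_functions = {
    "const": (lambda x: np.ones_like(x), lambda x: np.zeros_like(x)),
    "identity": (lambda x: x, lambda x: np.ones_like(x)),
    "sin": (np.sin, np.cos),
    "tanh": (np.tanh, lambda x: 1.0 - np.tanh(x) ** 2),
}


def stein_violation(samples, phi, dphi):
    # In 1-D, trace(A_p phi(x)) = phi(x) * d/dx log p(x) + phi'(x)
    return np.mean(phi(samples) * score(samples) + dphi(samples))


def stein_discrepancy_over_family(samples):
    # Largest squared violation over the finite family
    return max(stein_violation(samples, f, df) ** 2 for f, df in test_functions.values())


rng = np.random.default_rng(0)
from_p = rng.normal(0.0, 1.0, size=100_000)   # q = p
from_q = rng.normal(0.5, 1.0, size=100_000)   # q is p shifted by 0.5
disc_p = stein_discrepancy_over_family(from_p)
disc_q = stein_discrepancy_over_family(from_q)
print(disc_p, disc_q)
```

For the shifted q, the constant test function alone gives \mathbb{E}_{x \sim q}[-x] = -0.5, so disc_q should be at least about 0.25, while disc_p is essentially zero.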

Optimizing over all possible \phi is intractable. But if we restrict \phi to the unit ball of a reproducing kernel Hilbert space (RKHS) \mathcal{H}^d, we obtain the kernelized Stein discrepancy:

\displaystyle \mathbb{KS}\left(q, p\right) = \max_{\phi\in\mathcal{H}^d} \left[\mathbb{E}_{x \sim q}\text{trace}\left(\mathcal{A}_p\phi\left(x\right)\right)\right]^2\quad \text{s.t.} \parallel \phi\parallel_{\mathcal{H}^d} \leq 1

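and this optimization has a closed-form solution: the maximizer is \phi_{q,p}^*/\parallel\phi_{q,p}^*\parallel_{\mathcal{H}^d}, where \displaystyle \phi_{q,p}^*\left(\cdot\right) = \mathbb{E}_{x \sim q}\left[\mathcal{A}_pk\left(x,\cdot\right)\right], and the discrepancy itself is \mathbb{KS}\left(q, p\right) = \parallel\phi_{q,p}^*\parallel_{\mathcal{H}^d}^2.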
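The closed form is what makes the kernelized version computable from samples: the squared norm of \phi_{q,p}^* can be written as a double expectation over x, x' \sim q of a kernel u_p(x, x') that only involves the score of p. The sketch below estimates it with a U-statistic for the RBF kernel k(x, x') = \exp(-\parallel x - x'\parallel^2 / h); NumPy, the fixed bandwidth h = 1, the function name and the toy Gaussians are all assumptions for illustration.

```python
import numpy as np


def ksd_u_statistic(x, score, h=1.0):
    """U-statistic estimate of KS(q, p) from samples x ~ q, RBF kernel exp(-||x - x'||^2 / h)."""
    n, d = x.shape
    s = score(x)                                  # (n, d) scores of p at the samples
    diff = x[:, None, :] - x[None, :, :]          # diff[a, b] = x_a - x_b
    sq_dist = np.sum(diff ** 2, axis=-1)
    K = np.exp(-sq_dist / h)
    # u_p(x_a, x_b) = k * [ s_a.s_b + (2/h) (x_a - x_b).(s_a - s_b) + 2d/h - 4||x_a - x_b||^2 / h^2 ]
    ss = s @ s.T
    cross = np.einsum("abk,ak->ab", diff, s) - np.einsum("abk,bk->ab", diff, s)
    u = K * (ss + (2.0 / h) * cross + 2.0 * d / h - 4.0 * sq_dist / h ** 2)
    np.fill_diagonal(u, 0.0)                      # U-statistic: drop the a == b terms
    return u.sum() / (n * (n - 1))


def score(x):
    # grad log p for p = N(0, I)
    return -x


rng = np.random.default_rng(0)
samples_p = rng.normal(size=(800, 2))             # q = p
samples_q = rng.normal(loc=1.0, size=(800, 2))    # q = N([1, 1], I)
ksd_p = ksd_u_statistic(samples_p, score)
ksd_q = ksd_u_statistic(samples_q, score)
print(ksd_p, ksd_q)
```

For q = N(m, I) and p = N(0, I), the difference of the two scores is the constant -m, and the population value works out to \parallel m\parallel^2\,\mathbb{E}[k(x, x')] = 2 \times 0.2 = 0.4 with h = 1 and d = 2, so ksd_q should land in that vicinity while ksd_p hovers around zero.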

Now take T\left(x\right) = x + \epsilon\phi\left(x\right) and let q_{[T]} be the distribution of T\left(x\right) when x \sim q(x). Then [2] shows that the derivative of the KL divergence between q_{[T]} and p with respect to \epsilon has an interesting form:

\nabla_\epsilon KL\left(q_{[T]} \parallel p\right)\vert_{\epsilon = 0} = -\mathbb{E}_{x \sim q}\left[\text{trace}\left(\mathcal{A}_p\phi\left(x\right)\right)\right]

(Note that the right-hand side is the quantity inside the Stein discrepancy, negated and without the square.)
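This identity can be checked numerically in a case where everything stays Gaussian. The sketch below assumes q = N(0, 1), p = N(\mu, 1) and an affine direction \phi(x) = ax + b (all choices for illustration, NumPy for the sampling). Then q_{[T]} = N(\epsilon b, (1 + \epsilon a)^2), so the left-hand side can be obtained by a finite difference of the closed-form KL, while the right-hand side is a Monte Carlo average. Working the derivative out by hand gives -b\mu for both sides.

```python
import numpy as np


def kl_gauss(m1, s1, m2, s2):
    """KL(N(m1, s1^2) || N(m2, s2^2)) in closed form."""
    return np.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5


mu = 1.0          # target p = N(mu, 1)
a, b = 0.3, 0.5   # test direction phi(x) = a * x + b (arbitrary choice)


def kl_after_transform(eps):
    # x ~ N(0, 1) pushed through T(x) = x + eps * phi(x) is N(eps * b, (1 + eps * a)^2)
    return kl_gauss(eps * b, 1.0 + eps * a, mu, 1.0)


# Left-hand side: central finite difference of the KL at eps = 0
eps = 1e-5
lhs = (kl_after_transform(eps) - kl_after_transform(-eps)) / (2 * eps)

# Right-hand side: -E_q[phi(x) * d/dx log p(x) + phi'(x)], with d/dx log p(x) = mu - x
rng = np.random.default_rng(0)
x = rng.normal(size=1_000_000)
rhs = -np.mean((a * x + b) * (mu - x) + a)

print(lhs, rhs)   # both close to -b * mu = -0.5
```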

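Putting all of this together, among directions \phi in the unit ball of \mathcal{H}^d, the one that decreases the KL divergence the fastest is, up to normalization, the expectation of the Stein operator applied to the kernel:

\displaystyle \phi_{q,p}^*\left(\cdot\right) = \mathbb{E}_{x \sim q}\left[k(x,\cdot)\nabla_x\log p(x) + \nabla_x k(x,\cdot)\right]

Replacing the expectation over q with the empirical average over the current particles \left\{x_j^l\right\}_{j=1}^n gives exactly the update direction \hat{\phi}^* from the beginning of this post.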

The long story

The above is a huge simplification, so don’t take it too seriously. If you are dying for more details, the following might help:

[1] introduces the kernelized Stein discrepancy
[2] presents the SVGD algorithm
[3] uses the algorithm to estimate the parameters of an energy-based deep neural net.