Variational Inference Step-by-Step (Part 2: Mean Field V.I.)


Hi everyone, welcome to the second post of my step-by-step introduction to variational inference! In the previous post we explained the motivation for using variational inference. We also introduced the KL divergence and the role it plays in the main idea of variational inference. In this post, we will discuss how to find the variational distribution, and how mean field variational inference helps us do that.

Deriving Variational Inference

The fundamental point we have learned so far is this: minimizing the KL divergence is equivalent to maximizing the lower bound \(\mathcal{L}\), and the latter is easier and more convenient to do. Now we start the discussion that is the gist of variational inference: how do we find the \(q(Z)\) that maximizes \(\mathcal{L}\)?

Let's suppose we have two sets of variables, namely \(X\) and \(Z\), where \(X = \{x_1, x_2, x_3\}\) and \(Z = \{z_1, z_2, z_3\}\). For the sake of a concrete example, we pick three random variables for each (the number could have been chosen arbitrarily). We denote the joint probability of these sets as \(p(X, Z) = p(x_1,x_2,x_3,z_1,z_2,z_3)\).

Remember that our ultimate objective is to find the conditional probability \(p(Z \mid X)\), namely \(p(z_1, z_2, z_3 \mid x_1, x_2, x_3)\). By the rule of conditional probability, we can express it as the joint over the marginal:

\[ \begin{aligned} p(Z \mid X) = \frac{p(z_1,z_2,z_3,x_1,x_2,x_3)} {\int_{z_1}\int_{z_2}\int_{z_3}p(z_1,z_2,z_3,x_1,x_2,x_3)\, dz_1\, dz_2\, dz_3} \end{aligned} \]

As explained several times previously, computing the marginal in the denominator is really hard and sometimes impossible. So let us restate our goal: we will try to find a \(q(z_1,z_2,z_3)\) that approximates \(p(z_1, z_2, z_3 \mid x_1, x_2, x_3)\), using what we already know: the joint probability \(p(z_1, z_2, z_3, x_1, x_2, x_3)\).
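To make this bottleneck concrete, here is a minimal NumPy sketch (the toy table `joint`, the state count `K`, and all variable names are my own illustration, not part of any real model). It computes the posterior by brute force for discrete variables; notice that the denominator alone sums over every configuration of \(Z\), which grows exponentially with the number of latent variables:

```python
import numpy as np

rng = np.random.default_rng(0)

K = 4  # toy choice: each latent variable takes K discrete states

# p(x_obs, z1, z2, z3) for one *fixed* observation X = x_obs,
# stored as a K x K x K table of unnormalized probabilities.
joint = rng.random((K, K, K))

# The denominator: marginalize the joint over all of Z.
# This sum has K**3 terms; with m latent variables it is K**m.
evidence = joint.sum()

# Brute-force posterior p(z1, z2, z3 | x_obs) = joint / marginal.
posterior = joint / evidence
print(posterior.sum())  # 1.0, as a sanity check
```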

Putting back the pieces from the previous section, we can plug these random variables into the lower bound formula \(\mathcal{L}\), which we are trying to maximize. Namely:

\[ \begin{aligned} \mathcal{L} = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1, z_2, z_3) \ln \frac{p(x_1,x_2,x_3,z_1,z_2,z_3)}{q(z_1,z_2,z_3)} \end{aligned} \]

Note: in the previous parts we used integral notation, since we assumed the variables were continuous. For the sake of clarity, we will use summations in this section.
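Continuing the toy NumPy sketch from above (so this is only an illustrative check, with `q` initialized randomly rather than optimized), the lower bound is likewise one big sum over all configurations of \(Z\):

```python
# A candidate variational distribution q(z1, z2, z3): any
# nonnegative K x K x K table that is normalized to sum to 1.
q = rng.random((K, K, K))
q /= q.sum()

# Lower bound: L = sum_z q(z) * ln( p(x_obs, z) / q(z) ).
lower_bound = np.sum(q * (np.log(joint) - np.log(q)))

# Since ln p(x_obs) = L + KL(q || p(.|x_obs)) and KL >= 0,
# the bound can never exceed the log evidence:
print(lower_bound, "<=", np.log(evidence))
```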

Finding \(q(z_1,z_2,z_3)\) directly is also hard. So we tackle the problem by making a fundamental assumption: the variables in \(Z\) are independent. The joint then factorizes into a product over the individual variables: \(q(z_1,z_2,z_3) = q(z_1)\, q(z_2)\, q(z_3)\). This assumption is the essence of mean field variational inference.
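In code, the mean field assumption replaces the full \(K \times K \times K\) table for \(q\) with three independent \(K\)-vectors. Continuing the toy sketch (the random factors below are just placeholders for the distributions we will eventually optimize):

```python
# Three independent factors q(z1), q(z2), q(z3), each a K-vector.
q1, q2, q3 = (rng.random(K) for _ in range(3))
q1, q2, q3 = q1 / q1.sum(), q2 / q2.sum(), q3 / q3.sum()

# The factorized joint q(z1) q(z2) q(z3) as an outer product:
# 3 * K numbers now describe what took K**3 numbers before.
q_mf = np.einsum("i,j,k->ijk", q1, q2, q3)
print(q_mf.sum())  # still a valid distribution: sums to 1
```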

We can then plug this back into the lower bound equation and express it as: \[ \begin{aligned} \mathcal{L} = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln \frac{p(x_1,x_2,x_3,z_1,z_2,z_3)}{q(z_1) q(z_2) q(z_3)} \end{aligned} \]

At this point, we can simplify the lower bound equation with a few steps of algebraic manipulation:

\[ \begin{aligned} & \text{First we apply the logarithm rules } \ln(a/b) = \ln a - \ln b \text{ and } \ln(ab) = \ln a + \ln b \text{:} \newline \mathcal{L} & = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) - \ln q(z_1) q(z_2) q(z_3)\big] \newline & = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) - \ln q(z_1) - \ln q(z_2) - \ln q(z_3)\big] \end{aligned} \]

One useful trick for attacking this rather complicated equation is to solve for one \(q(z_i)\) at a time, instead of solving for the entire \(q(Z)\) at once. We can do this by assuming that all \(q(z_j)\) with \(j \neq i\) are known. Let's start with \(q(z_1)\), where we assume that we know \(q(z_2)\) and \(q(z_3)\). This assumption is not necessarily true, but we'll see how it helps us. We can distribute the summations and break the large equation into three parts:

\[ \begin{aligned} \textbf{Part 1} & : \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln p(X,Z)\newline \textbf{Part 2} & : -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln q(z_1)\newline \textbf{Part 3} & : -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln q(z_2) + \ln q(z_3) \big] \end{aligned} \]

We go over this large equation part by part, starting from the easiest one, part 3:

\[ \begin{aligned} & = -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \big[ \ln q(z_2) + \ln q(z_3) \big]\newline & = -\sum_{z_1} q(z_1) \sum_{z_2} \sum_{z_3} q(z_2) q(z_3) \big[ \ln q(z_2) + \ln q(z_3) \big]\newline & \text{By our assumption, the terms that contain } q(z_2) \text{ and } q(z_3) \text{ are known, so we can collapse them into a constant } \mathcal{C} \text{:}\newline & = -\sum_{z_1} q(z_1)\, \mathcal{C} \end{aligned} \]

Next up is part 2:

\[ \begin{aligned} & \text{We can do a similar manipulation by reordering the summations:}\newline & = -\sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln q(z_1)\newline & = -\sum_{z_1} q(z_1) \ln q(z_1) \sum_{z_2} \sum_{z_3} q(z_2) q(z_3) \newline & \text{Notice that the remaining summations are over probability distributions that sum to 1:}\newline & = -\sum_{z_1} q(z_1)\ln q(z_1) \cdot 1\newline & = -\sum_{z_1} q(z_1)\ln q(z_1) \end{aligned} \]

(You may recognize this expression as the entropy of \(q(z_1)\).)

The most complicated one would be part 1:

\[ \begin{aligned} & = \sum_{z_1} \sum_{z_2} \sum_{z_3} q(z_1) q(z_2) q(z_3) \ln p(x_1,x_2,x_3,z_1,z_2,z_3)\newline & = \sum_{z_1} q(z_1) \sum_{z_2} \sum_{z_3} q(z_2) q(z_3) \ln p(x_1,x_2,x_3,z_1,z_2,z_3) & & (1) \end{aligned} \]

We can reorder the summations again, but we still cannot remove the joint. Now let's take a look at the definition of the expectation of a probability distribution: \[ \begin{aligned} E[x] & = \sum_x x\, p(x) \newline & \text{Equivalently, for the expectation of a function of a random variable, we have:} \newline E[f(x)] & = \sum_x f(x)\, p(x) \end{aligned} \]
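As a tiny self-contained illustration (the pmf `p` and the function `f` below are arbitrary choices for the example):

```python
import numpy as np

x = np.arange(4)        # support of a discrete random variable
p = np.full(4, 0.25)    # a uniform pmf, just for illustration
f = np.log1p(x)         # some function f(x)

# E[f(x)] = sum_x f(x) p(x)
expectation = np.sum(f * p)
print(expectation)
```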

We can see that each piece of equation (1) corresponds to a part of the definition of expectation; look closely here: \[ \begin{aligned} = & \sum_{z_1} q(z_1) \underbrace{\sum_{z_2}\sum_{z_3}}_{\text{summations}} \overbrace{q(z_2) q(z_3)}^{\text{pmf } p(x)}\overbrace{\ln p(x_1,x_2,x_3,z_1,z_2,z_3)}^{\text{some function } f(x)} \end{aligned} \]

We can thus re-express equation (1) as follows, where \(E_{z_2,z_3}\) denotes the expectation with respect to \(q(z_2)\,q(z_3)\): \[ \begin{aligned} = \sum_{z_1} q(z_1) E_{z_2,z_3}\big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \big] \end{aligned} \]

Finally, we can assemble all three parts back into a formula for \(\mathcal{L}\): \[ \begin{aligned} \mathcal{L} & = \overbrace{\sum_{z_1} q(z_1) E_{z_2,z_3}\big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \big]}^{\text{part 1}} \overbrace{- \sum_{z_1} q(z_1)\ln q(z_1)}^{\text{part 2}} \overbrace{-\sum_{z_1} q(z_1) \mathcal{C}}^{\text{part 3}}\newline & \text{We can take this further by grouping part 1 and part 3, which share the factor } q(z_1) \text{:}\newline & = \sum_{z_1} q(z_1) \Bigg[ E_{z_2,z_3} \Big[ \ln p(x_1,x_2,x_3,z_1,z_2,z_3) \Big] - \mathcal{C} \Bigg] - \sum_{z_1} q(z_1)\ln q(z_1) \end{aligned} \]
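Before closing, we can sanity-check this decomposition numerically, continuing the toy NumPy sketch from earlier (so the check holds only for those randomly initialized tables; it is a sketch, not a proof): the three parts, computed separately, should reassemble into exactly the lower bound evaluated directly with \(q = q(z_1)\,q(z_2)\,q(z_3)\).

```python
log_joint = np.log(joint)

# Part 1: sum_z1 q(z1) * E_{z2,z3}[ ln p(x_obs, z1, z2, z3) ].
expected_log_joint = np.einsum("j,k,ijk->i", q2, q3, log_joint)
part1 = np.sum(q1 * expected_log_joint)

# Part 2: -sum_z1 q(z1) ln q(z1), the entropy of q(z1).
part2 = -np.sum(q1 * np.log(q1))

# Part 3: -sum_z1 q(z1) * C, where the constant C collects
# the known q(z2) and q(z3) terms.
C = np.sum(q2 * np.log(q2)) + np.sum(q3 * np.log(q3))
part3 = -np.sum(q1 * C)

# Direct evaluation of L with the factorized q, for comparison:
direct = np.sum(q_mf * (log_joint - np.log(q_mf)))
print(part1 + part2 + part3, "~=", direct)  # equal up to float error
```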


This is still far from done, but we will stop right here for this post. Hopefully, we can finish the derivation of the variational distribution in the next post. Stay tuned!

Thank you!