Variational Inference Step-by-Step (Part 1: Motivation)
Hi everybody! For the past few days I have been struggling to study variational inference from many resources. I found that many textbooks jump into rigorous details before I could quite grasp the intuition behind it. Then a series of tutorial videos by Chieh Wu from Northeastern University came to the rescue. His explanation is excellent and very intuitive for me. I wrote this step-by-step series of posts mainly as lecture notes for myself, and they heavily follow his videos. I hope they are useful for those who share the same frustration in learning the subject. Enjoy!
Motivation
This note begins by discussing why variational inference is useful in many applications. We assume the reader is already familiar with the basics of probabilistic graphical models. We start by considering the Bayesian network below:
- image here *
The arrows in the network encode the direct dependencies between random variables: each variable depends directly only on its parents, and is conditionally independent of its other non-descendants given those parents. The joint probability density function (PDF) therefore factorizes as a product of these local conditional distributions. Namely, \[ P(x_1,x_2,\dots,x_5) = P(x_1)P(x_3\mid x_1)P(x_2\mid x_1)P(x_4\mid x_2,x_3)P(x_5\mid x_3) \] The graphical model is a very neat tool for reading off the factorization of the joint PDF. Unfortunately, computing a conditional distribution from it is still not so straightforward. Consider the example of finding \(P(x_3,x_4\mid x_1,x_2,x_5)\), and remember the conditional probability rule: \[ P(Z\mid X)=\frac{P(X,Z)}{P(X)} \]
Assuming that the variables are continuous (in particular \(x_3\) and \(x_4\), which we integrate out), the conditional distribution can be expressed as follows: \[P(x_3,x_4\mid x_1,x_2,x_5) = \frac{P(x_1,x_2,x_3,x_4,x_5)} {\int_{x_3}\int_{x_4}P(x_1,x_2,x_3,x_4,x_5)\,dx_3\,dx_4} \]
The computation of the conditional probability gets complicated because we have to marginalize the joint probability over continuous variables. It becomes intractable when the integrals have no analytical solution.
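To make the role of the marginalization concrete, here is a minimal sketch for the case where all five variables are binary, so the integrals turn into small sums that we can actually compute. The conditional probability tables below are made-up numbers and the helper names are hypothetical; only the factorization follows the network above.

```python
from itertools import product

# Hypothetical conditional probability tables for binary variables (0 or 1).
# Each number is P(variable = 1 | parents); the values are invented.
p_x1 = 0.6
p_x2_given_x1 = {0: 0.3, 1: 0.8}
p_x3_given_x1 = {0: 0.5, 1: 0.1}
p_x4_given_x2_x3 = {(0, 0): 0.2, (0, 1): 0.7, (1, 0): 0.4, (1, 1): 0.9}
p_x5_given_x3 = {0: 0.25, 1: 0.65}


def bernoulli(p_one, value):
    """Probability of `value` under a Bernoulli with P(1) = p_one."""
    return p_one if value == 1 else 1.0 - p_one


def joint(x1, x2, x3, x4, x5):
    """P(x1,...,x5) as the product of the local conditionals."""
    return (
        bernoulli(p_x1, x1)
        * bernoulli(p_x3_given_x1[x1], x3)
        * bernoulli(p_x2_given_x1[x1], x2)
        * bernoulli(p_x4_given_x2_x3[(x2, x3)], x4)
        * bernoulli(p_x5_given_x3[x3], x5)
    )


def conditional_x3_x4(x1, x2, x5):
    """P(x3, x4 | x1, x2, x5): the joint divided by its sum over x3 and x4."""
    states = list(product((0, 1), repeat=2))
    evidence = sum(joint(x1, x2, a, b, x5) for a, b in states)
    return {(a, b): joint(x1, x2, a, b, x5) / evidence for a, b in states}


posterior = conditional_x3_x4(x1=1, x2=0, x5=1)
total_joint = sum(joint(*xs) for xs in product((0, 1), repeat=5))
print(posterior)                 # four probabilities, one per (x3, x4) pair
print(sum(posterior.values()))   # should be 1 up to rounding
```

Here the denominator is a sum over only four states. With continuous variables the same step becomes a double integral, which is exactly the part that may have no closed form.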
One might ask: would it be possible to somehow get the conditional probability just from the joint probability, completely bypassing the step that involves the intractable marginalization? The answer is yes, at least approximately. In fact, that is the whole point of variational inference.
Side Notes
As a matter of fact, variational inference is not the only way to address this inference problem. The table below briefly summarizes some widely known solutions:
Metropolis-Hastings | Variational Inference | Laplace Approximation |
---|---|---|
Sampling-based (MCMC), exact only in the limit of infinitely many samples | Deterministic solution | (also) Deterministic solution |
More accurate | Fairly good approximation | Much less accurate |
Takes longer to compute | Takes less time to compute | Fast to compute |
Overview
Recall the conditional probability rule above. Let's break it down from the inference point of view. The conditional \(P(Z\mid X)\), which we strive to find, is unknown. The denominator \(P(X)\) is also unknown (since no analytical solution of the integrals is available). Meanwhile, the numerator, the joint probability \(P(X,Z)\), is readily available from the factorization given by the graphical model.
Next, we define \(q(Z)\) as a probability distribution that approximates \(P(Z\mid X)\). We then use a property of the Kullback–Leibler divergence (KL divergence for short), which measures how much one distribution diverges from another (later sections show how this property is derived). The property is as follows:
\[ \ln p(X) = \mathcal{L}(q) + KL(q\|p) \qquad (1) \]
In general, variational inference strives to approximate \(p(Z\mid X)\) by finding the \(q(Z)\) that minimizes \(KL(q\|p)\). The left-hand side of the equation is fixed, because the outcome of the random variable \(X\) is given. The terms on the right-hand side, however, change as we adjust \(q(Z)\). Since their sum is constant, minimizing \(KL(q\|p)\) is equivalent to maximizing \(\mathcal{L}(q)\) (a later section briefly shows how \(\mathcal{L}\) acts as a lower bound of \(\ln p(X)\)). We now break down the terms on the right-hand side to show why it is more convenient to maximize \(\mathcal{L}\) than to minimize \(KL(q\|p)\).
\[KL(q\|p) = -\sum q(Z)\log \frac{p(Z\mid X)}{q(Z)} \] \[ \mathcal{L} = \sum q(Z)\log \frac{p(X,Z)}{q(Z)}\]
\(KL(q\|p)\) contains the probability distribution \(p(Z\mid X)\), which is exactly what we cannot compute, while the lower bound \(\mathcal{L}(q)\) contains only the known joint distribution \(p(X,Z)\). Thus \(\mathcal{L}\) is the preferable quantity to work with. In other words, variational inference finds the \(q(Z)\) that maximizes the lower bound \(\mathcal{L}(q)\). Namely, it finds \(\arg\max_{q(Z)}\sum q(Z)\log \frac{p(X,Z)}{q(Z)}\). Finding \(q(Z)\) itself is not trivial and will be the subject of later sections.
Deriving KL-Divergence Equation
Before diving deeper into variational inference, it is useful to understand where the KL divergence property stated above comes from. KL divergence is built on top of the entropy of a probability distribution (a quantity first introduced in information theory). Intuitively, entropy measures how much information we gain, on average, by observing an outcome drawn from the distribution. It is given by: \[ \mathcal{H}=-\sum_i P(x_i)\log P(x_i) \] The equivalent measure for a continuous random variable is obtained by replacing the sum with an integral, namely: \[\mathcal{H}= -\int P(x)\log P(x)\,dx\]
KL divergence quantifies how similar two probability distributions are to each other. The smaller the value of \(KL(p\|q)\), the more similar distribution \(p\) is to distribution \(q\). KL divergence defines the difference between the distributions as the distance between their entropies. Well, almost. The most straightforward expression of this intuition would be:
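As a quick illustration of the discrete entropy formula, the sketch below compares a uniform distribution with a sharply peaked one. The function name and the two example distributions are my own choices, and the natural logarithm is used, so the result is in nats.

```python
import math


def entropy(probs):
    """Shannon entropy -sum_i p_i log p_i in nats; zero-probability terms contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0)


uniform = [0.25, 0.25, 0.25, 0.25]   # every outcome equally likely
peaked = [0.97, 0.01, 0.01, 0.01]    # one outcome is almost certain

h_uniform = entropy(uniform)
h_peaked = entropy(peaked)
print(h_uniform, h_peaked)
```

The uniform distribution has the larger entropy (it equals log 4 here), which matches the intuition that its outcomes are the hardest to predict.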
\[ \begin{aligned} KL(p||q) & = \mathcal{H}_q - \mathcal{H}_p \newline & = (-\int q(x)\log q(x)) - (-\int p(x)\log p(x)) \end{aligned} \]
However, KL divergence measures the distance of one distribution relative to the other, which makes it asymmetric: the distance from \(q\) to \(p\) is not the same as the distance from \(p\) to \(q\). To incorporate this, the expression above is modified as follows: \[ KL(p\|q) = \big[-\int p(x)\log q(x)\big] - \big[-\int p(x)\log p(x)\big] \] Notice that the first term is now weighted by \(p(x)\) instead of \(q(x)\): both terms are averages taken with respect to \(p\), and this is what makes the function asymmetric. This is, in fact, the actual equation of KL divergence.
While it is easy to explain the idea of KL divergence by looking at the formula above, it is usually presented in a simplified form. We now show how to simplify the equation:
\[ \begin{aligned} KL(p||q) & = \big[-\int p(x)\log q(x)\big] - \big[-\int p(x)\log p(x)\big] \newline & = \int p(x)\log p(x) - \int p(x)\log q(x) \newline & = \int p(x) \Big[ \log p(x) - \log q(x) \Big] \newline & = \int p(x) \Big[ \log \frac{p(x)}{q(x)}\Big] \newline & = -\int p(x) \Big[ \log \frac{q(x)}{p(x)}\Big] & \text{(equivalent form by logarithm rule)} \end{aligned} \]
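The simplification can be checked numerically for discrete distributions, where the integrals become sums. This is only a sketch: the two distributions are arbitrary made-up examples, the function names are hypothetical, and the code assumes \(q(x) > 0\) wherever \(p(x) > 0\).

```python
import math


def cross_entropy(p, q):
    """-sum p(x) log q(x): the first bracket of the unsimplified formula."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def entropy(p):
    """-sum p(x) log p(x): the second bracket of the unsimplified formula."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def kl_divergence(p, q):
    """Simplified form: sum p(x) log(p(x) / q(x))."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]

kl_from_entropies = cross_entropy(p, q) - entropy(p)
kl_direct = kl_divergence(p, q)
kl_reverse = kl_divergence(q, p)   # roles of p and q swapped

print(kl_from_entropies, kl_direct)  # the two forms agree up to rounding
print(kl_reverse)                    # a different value from kl_direct
```

Swapping the arguments changes the weighting distribution, so the two directions generally give different numbers.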
Properties of KL-Divergence
Knowing how the KL divergence formula is derived, we can now explain its properties. The following are some useful ones:
- \(KL(p\|q)\) is always \(\geq 0\). Gibbs' inequality is used to show this. One simple intuition is that when \(p(x) = q(x)\) (in other words, the distributions are identical), \(\log \frac{p(x)}{q(x)}\) is \(0\) everywhere, so the KL divergence is exactly \(0\), and this is the smallest value it can take.
- \(KL(p\|q)\) is not symmetric, meaning that in general \(KL(p\|q) \neq KL(q\|p)\).
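For readers who want to see why the first property holds, here is a short sketch of one standard argument. It relies on the elementary inequality \(\log t \leq t - 1\) for \(t > 0\), and it assumes that \(p\) and \(q\) are both normalized and that \(q(x) > 0\) wherever \(p(x) > 0\). These assumptions are mine, added for illustration.

\[ \begin{aligned} -KL(p\|q) & = \int p(x) \log \frac{q(x)}{p(x)} \newline & \leq \int p(x) \Big[ \frac{q(x)}{p(x)} - 1 \Big] & \text{(since } \log t \leq t - 1 \text{)} \newline & = \int q(x) - \int p(x) \newline & = 1 - 1 = 0 \end{aligned} \]

Hence \(KL(p\|q) \geq 0\). Equality in \(\log t \leq t - 1\) holds only at \(t = 1\), so the divergence is zero only when \(q(x) = p(x)\) wherever \(p(x) > 0\).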
Variational Inference: Main Idea
This last section is the gist of this first post of the series. We now discuss how to use the properties above to derive equation (1). Let's revisit the objective of variational inference, which is to find a \(q(Z)\) that approximates \(p(Z\mid X)\). We now know that KL divergence measures the quality of the approximation: the smaller its value, the better the approximation. Using the KL divergence formula, we can derive the following:
\[ \begin{aligned} KL(q(Z)||p(Z|X)) {} & = -\int q(Z) \Big[ \log \frac{p(Z|X)}{q(Z)}\Big] \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{p(X)} \frac{1}{q(Z)}\Big] & \text{(by conditional prob. rule)} \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{q(Z)} \frac{1}{p(X)}\Big] & \text{(swapping the denominator)} \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{q(Z)} + \log \frac{1}{p(X)}\Big] & \text{(by logarithm rule)} \newline & = -\int q(Z) \Big[ \log \frac{p(X,Z)}{q(Z)} - \log p(X)\Big] \newline & = \Big[ -\int q(Z) \log \frac{p(X,Z)}{q(Z)}\Big] + \Big[\log p(X) \int q(Z)\Big] & \text{(by distributing the integral)} \newline & = \Big[ -\int q(Z) \log \frac{p(X,Z)}{q(Z)}\Big] + \Big[\log p(X) \Big] & \text{(probability sum to 1)} \end{aligned} \]
We finally arrive at this equality:
\[ \begin{aligned} \log p(X) & = KL(q(Z)||p(Z|X)) + \int q(Z) \log \frac{p(X,Z)}{q(Z)} \newline \log p(X) & = KL(q(Z)||p(Z|X)) + \mathcal{L} \end{aligned} \]
where \(\mathcal{L}\) is defined as \(\mathcal{L} = \int q(Z) \log \frac{p(X,Z)}{q(Z)}\).
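The equality is easy to verify numerically when \(Z\) is discrete and small enough that \(p(X)\) can be computed by brute force. In the sketch below, the joint values and the choice of \(q(Z)\) are made-up numbers, and all variable names are hypothetical.

```python
import math

# Hypothetical values of p(X = observed, Z = z) for three latent states.
joint_xz = [0.10, 0.25, 0.05]

# p(X): tractable here only because Z has three states.
evidence = sum(joint_xz)
# p(Z | X) by the conditional probability rule.
posterior = [j / evidence for j in joint_xz]

# An arbitrary approximation q(Z); any normalized positive vector works.
q = [0.2, 0.5, 0.3]

# L = sum q(Z) log( p(X,Z) / q(Z) ): needs only the joint.
elbo = sum(qz * math.log(j / qz) for qz, j in zip(q, joint_xz))
# KL(q || p(Z|X)) = -sum q(Z) log( p(Z|X) / q(Z) ): needs the posterior.
kl = -sum(qz * math.log(pz / qz) for qz, pz in zip(q, posterior))
log_evidence = math.log(evidence)

print(log_evidence, kl + elbo)   # equal up to rounding
```

Note that `elbo` is computed without ever touching `posterior` or `evidence`, which is the practical reason for preferring \(\mathcal{L}\).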
At this point, one should ask: how is this equality useful? Well, to begin with, we already know from the properties above that KL is always \(\geq 0\). If \(X\) is discrete, the value of \(p(X)\) lies within \([0,1]\), and therefore \(\log p(X)\) is never positive. Since \(\mathcal{L} = \log p(X) - KL\), we can then tell that \(\mathcal{L} \leq \log p(X) \leq 0\). (For continuous \(X\), a density can exceed 1, so \(\log p(X)\) may be positive, but \(\mathcal{L} \leq \log p(X)\) still holds.)
\(\mathcal{L}\) can therefore be interpreted as a lower bound of \(\log p(X)\). One important thing to notice is that the random variable \(X\) is given, and therefore the value of \(p(X)\) is fixed. Because of that, as \(\mathcal{L}\) gets larger and closer to \(\log p(X)\), KL gets smaller and closer to 0.
We can essentially control KL by adjusting \(\mathcal{L}\)! Let's emphasize this point again: by making the lower bound \(\mathcal{L}\) larger, we reduce the KL divergence. And reducing the KL divergence means that \(q(Z)\) becomes a better estimate of \(p(Z\mid X)\).
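To see this trade-off in action, the sketch below sweeps over a one-parameter family \(q(Z) = (t, 1 - t)\) for a two-state latent variable and records both quantities. The joint values and the grid search are illustrative assumptions only, not the optimization method used by variational inference in practice.

```python
import math

# Hypothetical p(X = observed, Z = 0) and p(X = observed, Z = 1).
joint_xz = [0.03, 0.12]
evidence = sum(joint_xz)
posterior = [j / evidence for j in joint_xz]   # roughly [0.2, 0.8]


def elbo(q):
    """L(q) = sum q(Z) log( p(X,Z) / q(Z) ), computed from the joint only."""
    return sum(qz * math.log(j / qz) for qz, j in zip(q, joint_xz))


def kl(q):
    """KL(q || p(Z|X)), written as sum q log(q / p), the same quantity as above."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, posterior))


# Candidate approximations q = (t, 1 - t) on a grid of t values in (0, 1).
grid = [i / 100 for i in range(1, 100)]
rows = [(t, elbo([t, 1 - t]), kl([t, 1 - t])) for t in grid]

best_by_elbo = max(rows, key=lambda r: r[1])   # largest lower bound
best_by_kl = min(rows, key=lambda r: r[2])     # smallest divergence

print(best_by_elbo[0], best_by_kl[0])          # the same t wins both ways
print(best_by_elbo[1], math.log(evidence))     # L reaches log p(X) when KL is ~0
```

The candidate that maximizes \(\mathcal{L}\) is also the one that minimizes KL, and it sits at the true posterior because this family happens to contain it.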
That’s it for this part! I hope you now have a glimpse of what variational inference does and how KL divergence can intuitively be used to achieve the objective. In later parts we will briefly discuss the rationale behind the lower bound \(\mathcal{L}\), and how the variational inference approach finds \(q(Z)\).