Table of Contents
Resources:
- Variational Bayes and The Mean-Field Approximation (blog)
- Variational Inference: Mean Field Approximation (Lecture Notes)
- Graphical Models, Exponential Families, and Variational Inference (M Jordan)
- Why is the Variational Bound Tight: the variational bound compared to the original error surface (reddit!)
Inference and Approximate Inference
-
Inference:
Inference usually refers to computing the probability distribution over one set of variables given another.
Goals:
- Computing the likelihood of observed data (in models with latent variables).
- Computing the marginal distribution over a given subset of nodes in the model.
- Computing the conditional distribution over a subset of nodes given a disjoint subset of nodes.
- Computing a mode of the density (for the above distributions).
Approaches:
- Exact inference algorithms:
- Brute force
- The elimination algorithm
- Message passing (sum-product algorithm, belief propagation)
- Junction tree algorithm
- Approximate inference algorithms:
- Loopy belief propagation
- Variational (Bayesian) inference \(+\) mean field approximations
- Stochastic simulation / sampling / MCMC
Inference in Deep Learning - Formulation:
In the context of Deep Learning, we usually have two sets of variables:
(1) Set of visible (observed) variables: \(\: \boldsymbol{v}\)
(2) Set of latent variables: \(\: \boldsymbol{h}\)
Inference in DL corresponds to computing the likelihood of the observed data \(p(\boldsymbol{v})\).
When training probabilistic models with latent variables, we are usually interested in computing
$$p(\boldsymbol{h} \vert \boldsymbol{v})$$
where \(\boldsymbol{h}\) are the latent variables, and \(\boldsymbol{v}\) are the observed (visible) variables (data).
-
The Challenge of Inference:
Motivation - The Challenge of Inference:
The challenge of inference usually refers to the difficult problem of computing \(p(\boldsymbol{h} \vert \boldsymbol{v})\) or taking expectations wrt it.
Such operations are often necessary for tasks like Maximum Likelihood Learning.
Intractable Inference:
In DL, intractable inference problems, usually, arise from interactions between latent variables in a structured graphical model.
These interactions are usually due to:
- Directed Models: “explaining away” interactions between mutual ancestors of the same visible unit.
- Undirected Models: direct interactions between the latent variables.
In Models:
- Tractable Inference:
  - Many simple graphical models with only one hidden layer have tractable inference.
    E.g. RBMs, PPCA.
- Intractable Inference:
  - Most graphical models with multiple layers of hidden variables have intractable posterior distributions.
    Exact inference requires exponential time.
    E.g. DBMs, DBNs.
  - Even some models with only a single layer can be intractable.
    E.g. Sparse Coding.
Computing the Likelihood of Observed Data:
We usually want to compute the likelihood of the observed data \(p(\boldsymbol{v})\), equivalently the log-likelihood \(\log p(\boldsymbol{v})\).
This usually requires marginalizing out \(\boldsymbol{h}\).
This problem is intractable (difficult) if it is costly to marginalize \(\boldsymbol{h}\).
- Data Likelihood: (intractable)
  \(p_{\theta}(\boldsymbol{v})=\int_\boldsymbol{h} p_{\theta}(\boldsymbol{h}) p_{\theta}(\boldsymbol{v} \vert \boldsymbol{h}) d\boldsymbol{h}\)
- Marginal Likelihood (evidence): is the data likelihood \(p_{\theta}(\boldsymbol{v})\) (intractable)
  \(\int_\boldsymbol{h} p_{\theta}(\boldsymbol{h}) p_{\theta}(\boldsymbol{v} \vert \boldsymbol{h}) d\boldsymbol{h}\)
- Prior:
  \(p(\boldsymbol{h})\)
- (Conditional) Likelihood:
  \(p_{\theta}(\boldsymbol{v} \vert \boldsymbol{h})\)
- Joint:
  \(p_{\theta}(\boldsymbol{v}, \boldsymbol{h})\)
- Posterior: (intractable)
  \(p_{\theta}(\boldsymbol{h} \vert \boldsymbol{v})=\frac{p_{\theta}(\boldsymbol{v}, \boldsymbol{h})}{p_{\theta}(\boldsymbol{v})}=\frac{p_{\theta}(\boldsymbol{v} \vert \boldsymbol{h}) p_{\theta}(\boldsymbol{h})}{\int_{\boldsymbol{h}} p_{\theta}(\boldsymbol{h}) p_{\theta}(\boldsymbol{v} \vert \boldsymbol{h}) d \boldsymbol{h}}\)
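To make the cost of marginalization concrete, here is a minimal sketch (a hypothetical toy model with \(n\) binary latent units and a single binary visible unit; all parameters are random placeholders) that computes \(p(\boldsymbol{v})\) by brute-force summation over all \(2^n\) latent configurations; the cost grows exponentially with \(n\):

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)

n = 10                      # number of binary latent units; cost grows as 2**n
W = rng.normal(size=(n,))   # toy parameters tying h to a single visible unit v
b = 0.1

def log_p_h(h):
    # independent Bernoulli(0.5) prior over each latent unit
    return n * np.log(0.5)

def log_p_v_given_h(v, h):
    # toy Bernoulli likelihood for a single binary visible unit
    logit = W @ h + b
    p1 = 1.0 / (1.0 + np.exp(-logit))
    return np.log(p1) if v == 1 else np.log(1.0 - p1)

def log_p_v(v):
    # brute-force marginalization: sum over all 2**n latent configurations
    log_terms = []
    for h in itertools.product([0, 1], repeat=n):
        h = np.array(h, dtype=float)
        log_terms.append(log_p_h(h) + log_p_v_given_h(v, h))
    return np.logaddexp.reduce(log_terms)

print("log p(v=1) =", log_p_v(1))   # feasible for n=10; hopeless for n=100
```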
-
Approximate Inference:
Approximate Inference is an important and practical approach to confronting the challenge of (intractable) inference.
It poses exact inference as an optimization problem, and constructs approximate inference by approximately solving that underlying optimization problem.
-
Inference as Optimization:
Exact inference can be described as an optimization problem.
- Inference Problem:
  - Compute the log-likelihood of the observed data, \(\log p(\boldsymbol{v} ; \boldsymbol{\theta})\).
    This can be intractable because it requires marginalizing out \(\boldsymbol{h}\).
- Inference Problem as Optimization - Core Idea:
- Choose a family of distributions over the latent variables \(\boldsymbol{h}\), with its own set of variational parameters \(\boldsymbol{\nu}\): \(q(\boldsymbol{h} \vert \boldsymbol{v} ; \boldsymbol{\nu})\).
- Find the setting of the parameters that makes our approximation closest to the posterior distribution over the latent variables \(p(\boldsymbol{h} \vert \boldsymbol{v})\).
I.E. Optimization.
- Use the learned \(q\) in place of the posterior (as an approximation).
- Optimization - Fitting \(q\) to the posterior \(p\):
- Optimize \(q\) to approximate \(p(\boldsymbol{h} \vert \boldsymbol{v})\)
- Similarity Measure: use the KL-Divergence as a similarity measure between the two distributions
$$D_{\mathrm{KL}}(q \| p) = \mathrm{E}_ {h \sim q}\left[\log \frac{q(h)}{p(h\vert {v})}\right] =\int_{h} q(h) \log \left(\frac{q(h)}{p(h\vert {v})}\right) dh$$
- Intractability: minimizing the KL Divergence (above) directly is itself an intractable problem,
  because the expression contains the intractable term \(p(\boldsymbol{h}\vert \boldsymbol{v})\), which is exactly the quantity we were trying to avoid computing.
- Evidence Lower Bound:
- We rewrite the KL Divergence expression in terms of log-likelihood of the data:
$$\begin{aligned} D_{\mathrm{KL}}(q \| p) &=\int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h \vert v)} dh \\ &=\int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh+\int_{\boldsymbol{h}} q(h) \log p(v) dh \\ &=\int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh+\log p(\boldsymbol{v}) \end{aligned}$$
where on the second line we expand \(p(h \vert v)=\frac{p(h, v)}{p(v)}\), and the last integral simplifies to \(\log p(\boldsymbol{v})\) because \(q\) integrates to \(1\) over its support and \(\log p(\boldsymbol{v})\) does not depend on \(h\).
Thus,$$\log p(\boldsymbol{v}) = D_{\mathrm{KL}}(q \| p) - \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh$$
- Notice that since the KL-Divergence is Non-Negative:
$$\begin{align} D_{\mathrm{KL}}(q \| p) &\geq 0 \\ D_{\mathrm{KL}}(q \| p) - \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh &\geq - \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh \\ \log p(\boldsymbol{v}) &\geq - \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh \end{align} $$
Thus, the term \(- \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh\) provides a lower-bound on the log likelihood of the data.
- We rewrite the term as:
$$\mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q) = - \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh$$
the Evidence Lower Bound (ELBO) AKA Variational Free Energy.
Thus,$$\log p(\boldsymbol{v}) \geq \mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q)$$
- The Evidence Lower Bound can also be defined as:
$$\begin{align} \mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q) &= - \int_{\boldsymbol{h}} q(h) \log \frac{q(h)}{p(h, v)} dh \\ &= \log p(\boldsymbol{v} ; \boldsymbol{\theta})-D_{\mathrm{KL}}(q(\boldsymbol{h} \vert \boldsymbol{v}) \| p(\boldsymbol{h} \vert \boldsymbol{v} ; \boldsymbol{\theta})) \\ &= \mathbb{E}_ {\mathbf{h} \sim q}[\log p(\boldsymbol{h}, \boldsymbol{v})]+H(q) \end{align} $$
The latter being the canonical definition of the ELBO.
- Inference with the Evidence Lower Bound:
- For an appropriate choice of \(q, \mathcal{L}\) is tractable to compute.
- For any choice of \(q, \mathcal{L}\) provides a lower bound on the likelihood
- For \(q(\boldsymbol{h} \vert \boldsymbol{v})\) that are better approximations of \(p(\boldsymbol{h} \vert \boldsymbol{v}),\) the lower bound \(\mathcal{L}\) will be tighter
I.E. closer to \(\log p(\boldsymbol{v})\).
- When \(q(\boldsymbol{h} \vert \boldsymbol{v})=p(\boldsymbol{h} \vert \boldsymbol{v}),\) the approximation is perfect, and \(\mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q)=\log p(\boldsymbol{v} ; \boldsymbol{\theta})\).
- Maximizing the ELBO minimizes the KL-Divergence \(D_{\mathrm{KL}}(q \| p)\).
- Inference:
We can thus think of inference as the procedure for finding the \(q\) that maximizes \(\mathcal{L}\):
- Exact Inference: maximizes \(\mathcal{L}\) perfectly by searching over a family of functions \(q\) that includes \(p(\boldsymbol{h} \vert \boldsymbol{v})\).
- Approximate Inference: approximate inference uses approximate optimization to find \(q\).
We can make the optimization procedure less expensive but approximate by:
- Restricting the family of distributions \(q\) that the optimization is allowed to search over.
- Using an imperfect optimization procedure that may not completely maximize \(\mathcal{L}\) but may merely increase it by a significant amount.
- Core Idea of Variational Inference:
We don’t need to explicitly compute the posterior (or the marginal likelihood); instead, we solve an optimization problem by finding the distribution \(q\) that maximizes the Evidence Lower Bound.
Learning and Inference wrt the ELBO - Summary:
The ELBO \(\mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q)\) is a lower bound on \(\log p(\boldsymbol{v} ; \boldsymbol{\theta})\):
- Inference: can be viewed as maximizing \(\mathcal{L}\) with respect to \(q\).
- Learning: can be viewed as maximizing \(\mathcal{L}\) with respect to \(\theta\).
Notes:
- The difference between the ELBO and the KL divergence is the log normalizer (i.e. the evidence), which is the quantity that the ELBO bounds.
- Maximizing the ELBO is equivalent to Minimizing the KL-Divergence.
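As a sanity check on the bound, the following minimal sketch (a toy joint over one binary latent and one binary visible variable; the probability table is arbitrary) evaluates \(\mathcal{L} = \mathbb{E}_ {h \sim q}[\log p(h, v)] + H(q)\) for an arbitrary \(q\) and for the exact posterior, and verifies that \(\mathcal{L} \leq \log p(v)\), with equality when \(q = p(h \vert v)\):

```python
import numpy as np

# Toy joint distribution p(h, v) over a binary latent h and a binary visible v.
# (The numbers are arbitrary; they just need to sum to 1.)
p_joint = np.array([[0.30, 0.10],    # rows: h = 0, 1
                    [0.15, 0.45]])   # cols: v = 0, 1

v = 1                                 # the observed value
p_v = p_joint[:, v].sum()             # exact evidence p(v)
p_h_given_v = p_joint[:, v] / p_v     # exact posterior p(h | v)

def elbo(q):
    """ELBO(q) = E_{h~q}[log p(h, v)] + H(q)."""
    h_entropy = -np.sum(q * np.log(q))
    return np.sum(q * np.log(p_joint[:, v])) + h_entropy

q_bad = np.array([0.5, 0.5])          # an arbitrary approximate posterior
print("log p(v)          =", np.log(p_v))
print("ELBO(q_bad)       =", elbo(q_bad))          # strictly below log p(v)
print("ELBO(exact post.) =", elbo(p_h_given_v))    # equals log p(v)
```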
-
Expectation Maximization:
The Expectation-Maximization Algorithm is an iterative method to find maximum likelihood or maximum a posteriori (MAP) estimates of parameters in statistical models with unobserved latent variables.
It is based on maximizing a lower bound \(\mathcal{L}\).
It is not an approach to approximate inference.
It is an approach to learning with an approximate posterior.
The EM Algorithm:
The EM Algorithm consists of alternating between two steps until convergence:
- The E(xpectation)-step:
- Let \(\theta^{(0)}\) denote the value of the parameters at the beginning of the step.
- Set \(q\left(\boldsymbol{h}^{(i)} \vert \boldsymbol{v}\right)=p\left(\boldsymbol{h}^{(i)} \vert \boldsymbol{v}^{(i)} ; \boldsymbol{\theta}^{(0)}\right)\) for all indices \(i\) of the training examples \(\boldsymbol{v}^{(i)}\) we want to train on (both batch and minibatch variants are valid).
By this we mean \(q\) is defined in terms of the current parameter value of \(\boldsymbol{\theta}^{(0)}\);
if we vary \(\boldsymbol{\theta},\) then \(p(\boldsymbol{h} \vert \boldsymbol{v} ; \boldsymbol{\theta})\) will change, but \(q(\boldsymbol{h} \vert \boldsymbol{v})\) will remain equal to \(p\left(\boldsymbol{h} \vert \boldsymbol{v} ; \boldsymbol{\theta}^{(0)}\right)\).
- The M(aximization)-step:
- Completely or partially maximize
$$\sum_i \mathcal{L}\left(\boldsymbol{v}^{(i)}, \boldsymbol{\theta}, q\right)$$
with respect to \(\boldsymbol{\theta}\) using your optimization algorithm of choice.
Relation to Coordinate Ascent:
The algorithm can be viewed as a Coordinate Ascent algorithm to maximize \(\mathcal{L}\).
On one step, we maximize \(\mathcal{L}\) with respect to \(q,\) and on the other, we maximize \(\mathcal{L}\) with respect to \(\boldsymbol{\theta}\).
Stochastic Gradient Ascent on latent variable models can be seen as a special case of the EM algorithm where the M-step consists of taking a single gradient step.
Other variants of the EM algorithm can make much larger steps. For some model families, the M-step can even be performed analytically, jumping all the way to the optimal solution for \(\theta\) given the current \(q\).
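For concreteness, here is a minimal sketch of the two EM steps on a toy two-component 1-D Gaussian mixture with unit variances (a standard textbook case where the M-step is available in closed form; the data and initialization below are placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data from a two-component 1-D Gaussian mixture (unit variances).
x = np.concatenate([rng.normal(-2.0, 1.0, 300), rng.normal(3.0, 1.0, 700)])

def gaussian(x, mu):
    return np.exp(-0.5 * (x - mu) ** 2) / np.sqrt(2 * np.pi)

# Initial parameters theta = (pi, mu0, mu1)
pi, mu = 0.5, np.array([-1.0, 1.0])

for step in range(50):
    # E-step: q(h | v) is set to the exact posterior p(h | v; theta^(0))
    # (the "responsibilities"), holding the current parameters fixed.
    r1 = pi * gaussian(x, mu[1])
    r0 = (1 - pi) * gaussian(x, mu[0])
    q = r1 / (r0 + r1)                 # q(h = 1 | x_i)

    # M-step: maximize sum_i L(x_i, theta, q) w.r.t. theta with q held fixed.
    # For this model the maximizer is available in closed form.
    pi = q.mean()
    mu = np.array([np.sum((1 - q) * x) / np.sum(1 - q),
                   np.sum(q * x) / np.sum(q)])

print("estimated mixing weight:", pi)
print("estimated means:", mu)
```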
As Approximate Inference - Interpretation:
Even though the E-step involves exact inference, the EM algorithm can be viewed as using approximate inference.
The M-step assumes that the same value of \(q\) can be used for all values of \(\theta\).
This will introduce a gap between \(\mathcal{L}\) and the true \(\log p(\boldsymbol{v})\) as the M-step moves further and further away from the value \(\boldsymbol{\theta}^{(0)}\) used in the E-step.
Fortunately, the E-step reduces the gap to zero again as we enter the loop for the next time.
Insights/Takeaways:
- The Basic Structure of the Learning Process:
We update the model parameters to improve the likelihood of a completed dataset, where all missing variables have their values provided by an estimate of the posterior distribution.
This particular insight is not unique to the EM algorithm. For example, using gradient descent to maximize the log-likelihood also has this same property; the log-likelihood gradient computations require taking expectations with respect to the posterior distribution over the hidden units.
- Reusing \(q\):
We can continue to use one value of \(q\) even after we have moved to a different value of \(\theta\).
This particular insight is used throughout classical machine learning to derive large M-step updates.
In the context of deep learning, most models are too complex to admit a tractable solution for an optimal large M-step update, so this second insight, which is more unique to the EM algorithm, is rarely used.
Notes:
- The Expectation-Maximization Algorithm and Derivation (Blog!)
- The EM algorithm enables us to make large learning steps with a fixed \(q\)
- MAP Inference:
MAP Inference is an alternative form of inference where we are interested in computing the single most likely value of the missing variables, rather than to infer the entire distribution over their possible values \(p(\boldsymbol{h} \vert \boldsymbol{v})\).
In the context of latent variable models, we compute:
$$\boldsymbol{h}^{* }=\underset{\boldsymbol{h}}{\arg \max } p(\boldsymbol{h} \vert \boldsymbol{v})$$
As Approximate Inference:
It is not usually thought of as approximate inference, since it computes the exact most likely value of \(\boldsymbol{h}^{* }\).
However, if we wish to develop a learning process based on maximizing the lower bound \(\mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q)\), then it is helpful to think of MAP inference as a procedure that provides a value of \(q\).
In this sense, we can think of MAP inference as approximate inference, because it does not provide the optimal \(q\).
We can derive MAP Inference as a form of approximate inference by restricting the family of distributions \(q\) may be drawn from.
Derivation:
- We require \(q\) to take on a Dirac distribution:
$$q(\boldsymbol{h} \vert \boldsymbol{v})=\delta(\boldsymbol{h}-\boldsymbol{\mu})$$
- This means that we can now control \(q\) entirely via \(\boldsymbol{\mu}\).
- Dropping terms of \(\mathcal{L}\) that do not vary with \(\boldsymbol{\mu},\) we are left with the optimization problem:
$$\boldsymbol{\mu}^{* }=\underset{\mu}{\arg \max } \log p(\boldsymbol{h}=\boldsymbol{\mu}, \boldsymbol{v})$$
- which is equivalent to the MAP inference problem:
$$\boldsymbol{h}^{* }=\underset{\boldsymbol{h}}{\arg \max } p(\boldsymbol{h} \vert \boldsymbol{v})$$
The Learning Procedure with MAP Inference:
We can, thus, justify a learning procedure similar to EM, where we alternate between:
- Performing MAP inference to infer \(\boldsymbol{h}^{* }\), and
- Updating \(\boldsymbol{\theta}\) to increase \(\log p\left(\boldsymbol{h}^{* }, \boldsymbol{v}\right)\).
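A minimal sketch of this alternation on a toy sparse-coding-style model (Laplace prior over \(\boldsymbol{h}\), Gaussian likelihood; all sizes, data, and hyperparameters below are hypothetical), using proximal gradient steps for the inner MAP inference and a plain gradient step on \(\log p(\boldsymbol{h}^{* }, \boldsymbol{v})\) for the parameter update:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sparse-coding-style model (all sizes and values are placeholders):
#   p(h) proportional to exp(-lam * ||h||_1),  p(v | h) = N(W h, sigma^2 I)
# so log p(h, v) = -0.5/sigma^2 * ||v - W h||^2 - lam * ||h||_1 + const.
D, K, N = 16, 32, 200
sigma2, lam = 0.1, 0.5
V = rng.normal(size=(N, D))             # placeholder "data"
W = rng.normal(scale=0.1, size=(D, K))  # model parameters theta

def map_infer(v, W, n_steps=100):
    """MAP inference h* = argmax_h log p(h, v) via proximal gradient (ISTA)."""
    h = np.zeros(K)
    step = sigma2 / (np.linalg.norm(W, 2) ** 2 + 1e-12)   # a safe step size
    for _ in range(n_steps):
        grad = W.T @ (W @ h - v) / sigma2                  # grad of smooth part
        h = h - step * grad
        h = np.sign(h) * np.maximum(np.abs(h) - step * lam, 0.0)  # soft-threshold
    return h

# Alternate: (1) MAP inference of h*, (2) gradient step on log p(h*, v) w.r.t. W.
lr = 0.01
for it in range(20):
    H = np.stack([map_infer(v, W) for v in V])             # q = delta(h - h*)
    grad_W = (V - H @ W.T).T @ H / (sigma2 * N)            # d/dW of mean log p(h*, v)
    W = W + lr * grad_W
    recon = np.mean(np.sum((V - H @ W.T) ** 2, axis=1))
    if it % 5 == 0:
        print("iter %2d  mean reconstruction error: %.3f" % (it, recon))
```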
As Coordinate Ascent:
As with EM, this is a form of coordinate ascent on \(\mathcal{L},\) where we alternate between using inference to optimize \(\mathcal{L}\) with respect to \(q\) and using parameter updates to optimize \(\mathcal{L}\) with respect to \(\boldsymbol{\theta}\).
Lower Bound (ELBO) Justification:
The procedure as a whole can be justified by the fact that \(\mathcal{L}\) is a lower bound on \(\log p(\boldsymbol{v})\).
In the case of MAP inference, this justification is rather vacuous, because the bound is infinitely loose, due to the Dirac distribution’s differential entropy of negative infinity.
Adding noise to \(\mu\) would make the bound meaningful again.
MAP Inference in Deep Learning - Applications:
MAP Inference is commonly used in deep learning as both a feature extractor and a learning mechanism.
It is primarily used for sparse coding models.
MAP Inference in Sparse Coding Models:
Summary:
Learning algorithms based on MAP inference enable us to learn using a point estimate of \(p(\boldsymbol{h} \vert \boldsymbol{v})\) rather than inferring the entire distribution.
-
Variational Inference and Learning:
Main Idea - Restricting family of distributions \(q\):
The core idea behind variational learning is that we can maximize \(\mathcal{L}\) over a restricted family of distributions \(q\).
This family should be chosen so that it is easy to compute \(\mathbb{E}_ {q} \log p(\boldsymbol{h}, \boldsymbol{v})\).
A typical way to do this is to introduce assumptions about how \(q\) factorizes.
Most commonly, we make a Mean-Field Approximation to \(q\).
Mean-Field Approximation:
Mean-Field Approximation is a type of Variational Bayesian Inference where we assume that the unknown variables can be partitioned so that each partition is independent of the others.
The Mean-Field Approximation assumes the variational distribution over the latent variables factorizes as:
$$q(\boldsymbol{h} \vert \boldsymbol{v})=\prod_{i} q\left(h_{i} \vert \boldsymbol{v}\right)$$
I.E. it imposes the restriction that \(q\) is a factorial distribution.
More generally, we can impose any graphical model structure we choose on \(q,\) to flexibly determine how many interactions we want our approximation to capture.
This fully general graphical model approach is called structured variational inference (Saul and Jordan, 1996).
The Optimal Probability Distribution \(q\):
The beauty of the variational approach is that we do not need to specify a specific parametric form for \(q\).
We specify how it should factorize, but then the optimization problem determines the optimal probability distribution within those factorization constraints.
The Inference Optimization Problem:
- For discrete latent variables: we use traditional optimization techniques to optimize a finite number of variables describing the \(q\) distribution.
- For continuous latent variables: we use calculus of variations to perform optimization over a space of functions and actually determine which function should be used to represent \(q\).
- Calculus of Variations removes much of the responsibility from the human designer of the model, who now must specify only how \(q\) factorizes, rather than needing to guess how to design a specific \(q\) that can accurately approximate the posterior.
Calculus of variations is the origin of the names “variational learning” and “variational inference”, but the names apply in both discrete and continuous cases.
KL-Divergence Optimization:
- The Inference Optimization Problem boils down to maximizing \(\mathcal{L}\) with respect to \(q\).
- This is equivalent to minimizing \(D_{\mathrm{KL}}(q(\boldsymbol{h} \vert \boldsymbol{v}) \| p(\boldsymbol{h} \vert \boldsymbol{v}))\).
- Thus, we are fitting \(q\) to \(p\).
- However, we are doing so using the reverse direction of the KL-Divergence: the varying approximation \(q\) appears in the first argument and the fixed target \(p\) in the second, the opposite of the direction used when fitting a model to data.
- In the inference optimization problem, we choose to use \(D_{\mathrm{KL}}\left(q(\boldsymbol{h} \vert \boldsymbol{v}) \| p(\boldsymbol{h} \vert \boldsymbol{v})\right)\) for computational reasons.
- Specifically, computing \(D_{\mathrm{KL}}\left(q(\boldsymbol{h} \vert \boldsymbol{v}) \| p(\boldsymbol{h} \vert \boldsymbol{v})\right)\) involves evaluating expectations with respect to \(q,\) so by designing \(q\) to be simple, we can simplify the required expectations.
- The opposite direction of the KL divergence would require computing expectations with respect to the true posterior.
Because the form of the true posterior is determined by the choice of model, we cannot design a reduced-cost approach to computing \(D_{\mathrm{KL}}(p(\boldsymbol{h} \vert \boldsymbol{v}) \| q(\boldsymbol{h} \vert \boldsymbol{v}))\) exactly.
- Three Cases for Optimization:
- If \(q\) is high and \(p\) is high, then we are happy (i.e. low KL divergence).
- If \(q\) is high and \(p\) is low then we pay a price (i.e. high KL divergence).
- If \(q\) is low then we don’t care (i.e. also low KL divergence, regardless of \(p\)).
- Optimization-based Inference vs Maximum Likelihood (ML) Learning:
- ML-Learning: fits a model to data by minimizing \(D_{\mathrm{KL}}\left(p_{\text {data }} \| p_{\text {model }}\right)\).
It encourages the model to have high probability everywhere that the data has high probability.
- Optimization-based Inference:
It encourages \(q\) to have low probability everywhere the true posterior has low probability.
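A small numerical illustration of this asymmetry (the bimodal \(p\) and the single-Gaussian family for \(q\) are hypothetical choices): fitting one Gaussian to a two-mode target by grid search shows that the reverse KL \(D_{\mathrm{KL}}(q \| p)\) locks onto a single mode, while the forward KL \(D_{\mathrm{KL}}(p \| q)\) spreads \(q\) across both:

```python
import numpy as np

# Bimodal "true posterior" p (mixture of two Gaussians) on a grid,
# and a single-Gaussian family q(mu, sigma). All values are placeholders.
z = np.linspace(-10, 10, 2001)
dz = z[1] - z[0]

def normal(z, mu, sigma):
    return np.exp(-0.5 * ((z - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

p = 0.5 * normal(z, -3.0, 1.0) + 0.5 * normal(z, 3.0, 1.0)

def kl(a, b):
    # KL(a || b) approximated on the grid; both a and b are densities.
    return np.sum(a * (np.log(a + 1e-300) - np.log(b + 1e-300))) * dz

best_rev, best_fwd = None, None
for mu in np.linspace(-5, 5, 101):
    for sigma in np.linspace(0.5, 5, 46):
        q = normal(z, mu, sigma)
        rev, fwd = kl(q, p), kl(p, q)
        if best_rev is None or rev < best_rev[0]:
            best_rev = (rev, mu, sigma)
        if best_fwd is None or fwd < best_fwd[0]:
            best_fwd = (fwd, mu, sigma)

# Reverse KL (the direction used in variational inference) picks one mode;
# forward KL spreads q over both modes.
print("argmin KL(q||p): mu=%.2f sigma=%.2f" % best_rev[1:])
print("argmin KL(p||q): mu=%.2f sigma=%.2f" % best_fwd[1:])
```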
Variational (Bayesian) Inference:
Variational Bayesian Inference AKA Variational Bayes is most often used to infer the conditional distribution over the latent variables given the observations (and parameters).
This is also known as the posterior distribution over the latent variables:
$$p(z \vert x, \alpha)=\frac{p(z, x \vert \alpha)}{\int_{z} p(z, x \vert \alpha) dz}$$
which is usually intractable.
Notes:
- KL Divergence Optimization:
Optimizing the KL-Divergence given by:
$$D_{\mathrm{KL}}(q \| p) = \mathrm{E}_ {z \sim q}\left[\log \frac{q(z)}{p(z\vert x)}\right] =\int_{z} q(z) \log \left(\frac{q(z)}{p(z\vert x)}\right) dz$$
- Three Cases for Optimization:
- If \(q\) is high and \(p\) is high, then we are happy (i.e. low KL divergence).
- If \(q\) is high and \(p\) is low then we pay a price (i.e. high KL divergence).
- If \(q\) is low then we don’t care (i.e. also low KL divergence, regardless of \(p\)).
- Variational Inference - Discrete Latent Variables:
Variational Inference with Discrete Latent Variables is relatively straightforward.
Representing \(q\):
We define a distribution \(q\) where each factor of \(q\) is just defined by a lookup table over discrete states.
In the simplest case, \(h\) is binary and we make the mean field assumption that \(q\) factorizes over each individual \(h_{i}\).
In this case we can parametrize \(q\) with a vector \(\hat{h}\) whose entries are probabilities.
Then \(q\left(h_{i}=1 \vert \boldsymbol{v}\right)=\hat{h}_ {i}\).
Optimizing \(q\):
After determining how to represent \(q\) we simply optimize its parameters.
For discrete latent variables this is just a standard optimization problem e.g. gradient descent.
However, because this optimization must occur in the inner loop of a learning algorithm, it must be very fast[^1].
A popular choice is to iterate fixed-point equations, i.e., to solve:
$$\frac{\partial}{\partial \hat{h}_ {i}} \mathcal{L}=0$$
for \(\hat{h}_ {i}\).
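As one possible concrete instance of such a fixed-point loop (a toy Boltzmann-machine-style distribution \(\tilde{p}(\boldsymbol{h}) \propto \exp(\boldsymbol{b}^\top\boldsymbol{h} + \frac{1}{2}\boldsymbol{h}^\top W \boldsymbol{h})\) with random placeholder parameters and any conditioning on \(\boldsymbol{v}\) absorbed into \(\boldsymbol{b}\); this is not the binary sparse coding model), the coordinate-wise update reduces to a logistic sigmoid of the other units' means:

```python
import numpy as np

rng = np.random.default_rng(0)

# A toy Boltzmann-machine-style distribution over n binary latent units,
# p~(h) proportional to exp(b.h + 0.5 h'Wh), with conditioning on v absorbed
# into b. (W and b are random placeholders, not a trained model.)
n = 8
W = rng.normal(scale=0.3, size=(n, n))
W = (W + W.T) / 2
np.fill_diagonal(W, 0.0)
b = rng.normal(scale=0.5, size=n)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Mean-field parameters: h_hat[i] = q(h_i = 1 | v).
h_hat = np.full(n, 0.5)

# Iterate the fixed-point equations dL/dh_hat_i = 0, one coordinate at a time.
for sweep in range(100):
    old = h_hat.copy()
    for i in range(n):
        h_hat[i] = sigmoid(b[i] + W[i] @ h_hat)
    if np.max(np.abs(h_hat - old)) < 1e-8:   # convergence criterion
        break

print("sweeps:", sweep + 1)
print("q(h_i = 1 | v):", np.round(h_hat, 4))
```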
We repeatedly update different elements of \(\hat{\boldsymbol{h}}\) until we satisfy a convergence criterion.
Application - Binary Sparse Coding:
-
Variational Inference - Continuous Latent Variables:
Variational Inference and Learning with Continuous Latent Variables requires the use of the calculus of variations for maximizing \(\mathcal{L}\) with respect to \(q(\boldsymbol{h} \vert \boldsymbol{v})\).
In most cases, practitioners need not solve any calculus of variations problems themselves. Instead, there is a general equation for the mean field fixed-point updates.
The General Equation for Mean-Field Fixed-Point Updates:
If we make the mean field approximation
$$q(\boldsymbol{h} \vert \boldsymbol{v})=\prod_{i} q\left(h_{i} \vert \boldsymbol{v}\right)$$
and fix \(q\left(h_{j} \vert \boldsymbol{v}\right)\) for all \(j \neq i,\) then the optimal \(q\left(h_{i} \vert \boldsymbol{v}\right)\) may be obtained by normalizing the unnormalized distribution:
$$\tilde{q}\left(h_{i} \vert \boldsymbol{v}\right) = \exp \left(\mathbb{E}_{\mathbf{h}_{-i} \sim q\left(\mathbf{h}_ {-i} \vert \boldsymbol{v}\right)} \log \tilde{p}(\boldsymbol{v}, \boldsymbol{h})\right) = e^{\mathbb{E}_{\mathbf{h}_ {-i} \sim q\left(\mathbf{h}_ {-i} \vert \boldsymbol{v}\right)} \log \tilde{p}(\boldsymbol{v}, \boldsymbol{h})}$$
as long as \(p\) does not assign \(0\) probability to any joint configuration of variables.
- Carrying out the expectation inside the equation will yield the correct functional form of \(q\left(h_{i} \vert \boldsymbol{v}\right)\).
- The General Equation yields the mean field approximation for any probabilistic model.
- Deriving functional forms of \(q\) directly using calculus of variations is only necessary if one wishes to develop a new form of variational learning.
- The General Equation is a fixed-point equation, designed to be applied iteratively for each value of \(i\) until convergence.
Functional Form of the Optimal Distribution/Solution:
The General Equation tells us the functional form that the optimal solution will take, whether we arrive there by fixed-point equations or not.
This means we can take the functional form from that equation but regard some of the values that appear in it as parameters, which we can optimize with any optimization algorithm we like.
For examples of real applications of variational learning with continuous variables in the context of deep learning, see Goodfellow et al. (2013d).
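A classic worked instance of these fixed-point updates (following the standard bivariate-Gaussian example, e.g. Bishop's PRML Sec. 10.1.2; the numbers below are arbitrary) approximates a correlated 2-D Gaussian "posterior" with a factorized \(q\); the iteration recovers the true means but underestimates the marginal variances:

```python
import numpy as np

# Stand-in for an (in general intractable) posterior: a correlated 2-D Gaussian
# p(h | v) = N(m, Lambda^{-1}). Here everything is known, so the mean-field
# answer can be checked against the exact one. Values are arbitrary.
m = np.array([1.0, -2.0])
Lam = np.array([[2.0, 1.2],     # precision matrix (must be positive definite)
                [1.2, 1.5]])

# Mean-field approximation q(h) = q(h1) q(h2) with Gaussian factors.
# Applying the general fixed-point equation to this model gives factors
# q(h_i) = N(mu_i, Lam[i, i]^{-1}) with the coupled updates below
# (see e.g. Bishop, PRML, Sec. 10.1.2).
mu = np.zeros(2)
for step in range(100):
    old = mu.copy()
    mu[0] = m[0] - (Lam[0, 1] / Lam[0, 0]) * (mu[1] - m[1])
    mu[1] = m[1] - (Lam[1, 0] / Lam[1, 1]) * (mu[0] - m[0])
    if np.max(np.abs(mu - old)) < 1e-12:
        break

print("mean-field means:", mu, "(true mean:", m, ")")
print("mean-field variances:", 1 / np.diag(Lam))
print("true marginal variances:", np.diag(np.linalg.inv(Lam)))
```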
- Interactions between Learning and Inference:
Using approximate inference as part of a learning algorithm affects the learning process, and this in turn affects the accuracy of the inference algorithm.
Analysis:
- The training algorithm tends to adapt the model in a way that makes the approximating assumptions underlying the approximate inference algorithm become more true.
- When training the parameters, variational learning increases
$$\mathbb{E}_ {\mathbf{h} \sim q} \log p(\boldsymbol{v}, \boldsymbol{h})$$
- For a specific \(v\) this:
- increases \(p(\boldsymbol{h} \vert \boldsymbol{v})\) for values of \(\boldsymbol{h}\) that have high probability under \(q(\boldsymbol{h} \vert \boldsymbol{v})\) and
- decreases \(p(\boldsymbol{h} \vert \boldsymbol{v})\) for values of \(\boldsymbol{h}\) that have low probability under \(q(\boldsymbol{h} \vert \boldsymbol{v})\).
- This behavior causes our approximating assumptions to become self-fulfilling prophecies.
If we train the model with a unimodal approximate posterior, we will obtain a model with a true posterior that is far closer to unimodal than we would have obtained by training the model with exact inference.
Computing the Effect (Harm) of using Variational Inference:
Computing the true amount of harm imposed on a model by a variational approximation is thus very difficult.
- There exist several methods for estimating \(\log p(\boldsymbol{v})\):
  We often estimate \(\log p(\boldsymbol{v} ; \boldsymbol{\theta})\) after training the model and find that the gap with \(\mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q)\) is small.
- From this, we can conclude that our variational approximation is accurate for the specific value of \(\boldsymbol{\theta}\) that we obtained from the learning process.
- We should not conclude that our variational approximation is accurate in general or that the variational approximation did little harm to the learning process.
- To measure the true amount of harm induced by the variational approximation:
- We would need to know \(\boldsymbol{\theta}^{* }=\max_{\boldsymbol{\theta}} \log p(\boldsymbol{v} ; \boldsymbol{\theta})\).
- It is possible for \(\mathcal{L}(\boldsymbol{v}, \boldsymbol{\theta}, q) \approx \log p(\boldsymbol{v} ; \boldsymbol{\theta})\) and \(\log p(\boldsymbol{v} ; \boldsymbol{\theta}) \ll \log p\left(\boldsymbol{v} ; \boldsymbol{\theta}^{* }\right)\) to hold simultaneously.
- If \(\max_{q} \mathcal{L}\left(\boldsymbol{v}, \boldsymbol{\theta}^{* }, q\right) \ll \log p\left(\boldsymbol{v} ; \boldsymbol{\theta}^{* }\right),\) because \(\boldsymbol{\theta}^{* }\) induces too complicated of a posterior distribution for our \(q\) family to capture, then the learning process will never approach \(\boldsymbol{\theta}^{* }\).
- Such a problem is very difficult to detect, because we can only know for sure that it happened if we have a superior learning algorithm that can find \(\boldsymbol{\theta}^{* }\) for comparison.
-
Learned Approximate Inference:
Motivation:
Explicitly performing optimization via iterative procedures such as fixed-point equations or gradient-based optimization is often very expensive and time consuming.
Many approaches to inference avoid this expense by learning to perform approximate inference.
Learned Approximate Inference:
We can view the (multistep iterative) optimization process as a function \(f\) that maps an input \(\boldsymbol{v}\) to an approximate distribution \(q^{* }=\arg \max_{q} \mathcal{L}(\boldsymbol{v}, q)\), and then approximate this function with a neural network that implements an approximation \(f(\boldsymbol{v} ; \boldsymbol{\theta})\).
Wake-Sleep:
Motivation:
- One of the main difficulties with training a model to infer \(h\) from \(v\) is that we do not have a supervised training set with which to train the model.
- Given a \(v\) we do not know the appropriate \(h\).
- The mapping from \(v\) to \(h\) depends on the choice of model family, and evolves throughout the learning process as \(\theta\) changes.
Wake-Sleep Algorithm:
The wake-sleep algorithm (Hinton et al., 1995b; Frey et al., 1996) resolves this problem by drawing samples of both \(h\) and \(v\) from the model distribution.
- For example, in a directed model, this can be done cheaply by performing ancestral sampling beginning at \(h\) and ending at \(v\).
The inference network can then be trained to perform the reverse mapping: predicting which \(h\) caused the present \(\boldsymbol{v}\).
Drawbacks:
The main drawback to this approach is that we will only be able to train the inference network on values of \(\boldsymbol{v}\) that have high probability under the model.
Early in learning, the model distribution will not resemble the data distribution, so the inference network will not have an opportunity to learn on samples that resemble data.
Relation to Biological Dreaming:
Generative Modeling - Application:
Learned approximate inference has recently become one of the dominant approaches to generative modeling, in the form of the Variational AutoEncoder (Kingma, 2013; Rezende et al., 2014).
In this elegant approach, there is no need to construct explicit targets for the inference network.
Instead, the inference network is simply used to define \(\mathcal{L},\) and then the parameters of the inference network are adapted to increase \(\mathcal{L}\).
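A minimal VAE-style sketch of this idea (dimensions, architecture, and data below are hypothetical placeholders, and the unit-variance Gaussian decoder is just one convenient choice): an inference network outputs the parameters of \(q(\boldsymbol{h} \vert \boldsymbol{v})\), which defines a Monte-Carlo estimate of \(\mathcal{L}\), and both networks are trained by ascending it:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: an inference network q(h | v) = N(mu(v), sigma(v)^2)
# defines the ELBO, and both it and the generator p(v | h) ascend the ELBO.
D, H = 20, 4                                  # visible / latent dimensions
enc = nn.Sequential(nn.Linear(D, 64), nn.Tanh(), nn.Linear(64, 2 * H))
dec = nn.Sequential(nn.Linear(H, 64), nn.Tanh(), nn.Linear(64, D))
opt = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=1e-3)

def elbo(v):
    mu, log_var = enc(v).chunk(2, dim=-1)     # parameters of q(h | v)
    h = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)  # reparameterize
    recon = dec(h)
    # E_q[log p(v | h)] with a unit-variance Gaussian decoder (up to a constant)
    log_lik = -0.5 * ((v - recon) ** 2).sum(dim=-1)
    # KL(q(h | v) || p(h)) with a standard normal prior, in closed form
    kl = 0.5 * (torch.exp(log_var) + mu ** 2 - 1.0 - log_var).sum(dim=-1)
    return (log_lik - kl).mean()

v = torch.randn(128, D)                       # placeholder "data"
for step in range(200):
    opt.zero_grad()
    loss = -elbo(v)                           # maximize the ELBO
    loss.backward()
    opt.step()
print("final ELBO estimate:", -loss.item())
```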
Mathematics of Approximate Inference
Directional Derivative
The Calculus of Variations (Blog!)
-
Calculus of Variations:
Method for finding the stationary functions of a functional \(I[f]\) (function of functions) by solving a differential equation.
Formally, calculus of variations seeks to find the function \(y=f(x)\) such that the integral (functional):
$$I[y]=\int_{x_{1}}^{x_{2}} L\left(x, y(x), y^{\prime}(x)\right) d x$$
$$\begin{array}{l}{\text {where}}\\{x_{1}, x_{2} \text { are constants, }} \\ {y(x) \text { is twice continuously differentiable, }} \\ {y^{\prime}(x)=d y / d x} \\ {L\left(x, y(x), y^{\prime}(x)\right) \text { is twice continuously differentiable with respect to its arguments } x, y, y^{\prime}}\end{array}$$
is stationary.
Euler Lagrange Equation - Finding Extrema:
Finding the extrema of functionals is similar to finding the maxima and minima of functions. The maxima and minima of a function may be located by finding the points where its derivative vanishes (i.e., is equal to zero). The extrema of functionals may be obtained by finding functions where the functional derivative is equal to zero. This leads to solving the associated Euler–Lagrange equation.
The Euler Lagrange Equation is a second-order partial differential equation whose solutions are the functions for which a given functional is stationary:
$$\frac{\partial L}{\partial f}-\frac{d}{d x} \frac{\partial L}{\partial f^{\prime}} = 0$$
It is defined in terms of the functional derivative:
$$\frac{\delta I}{\delta f(x)} = \frac{\partial L}{\partial f}-\frac{d}{d x} \frac{\partial L}{\partial f^{\prime}} = 0$$
Shortest Path between Two Points:
Find the path such that the distance \(AB\) between two points is minimized.
Using the arc length, we define the following functional:
$$\begin{align} I &= \int_{A}^{B} dS \\ &= \int_{A}^{B} \sqrt{dx^2 + dy^2} \\ &= \int_{A}^{B} \sqrt{1 + \left(\dfrac{dy}{dx}\right)^2} dx \\ &= \int_{x_1}^{x_2} \sqrt{1 + \left(\dfrac{dy}{dx}\right)^2} dx \end{align} $$
- Now, we formulate the variational problem:
Find the extremal function \(y=f(x)\) between two points \(A=(x_1, y_1)\) and \(B=(x_2, y_2)\) such that the following integral is minimized:
$$I[y] = \int_{x_{1}}^{x_{2}} \sqrt{1+\left[y^{\prime}(x)\right]^{2}} d x$$
where \(y^{\prime}(x)=\frac{d y}{d x}, y_{1}=f\left(x_{1}\right), y_{2}=f\left(x_{2}\right)\).
- Solution:
We use the Euler-Lagrange Equation to find the extremal function \(f(x)\) that minimizes the functional \(I[y]\):
$$\frac{\partial L}{\partial f}-\frac{d}{d x} \frac{\partial L}{\partial f^{\prime}}=0$$
where \(L=\sqrt{1+\left[f^{\prime}(x)\right]^{2}}\).
- Since \(f\) does not appear explicitly in \(L,\) the first term in the Euler-Lagrange equation vanishes for all \(f(x)\)
$$\frac{\partial L}{\partial f} = 0$$
- Thus,
$$\frac{d}{d x} \frac{\partial L}{\partial f^{\prime}}=0$$
- Substituting for \(L\) and taking the derivative:
$$\frac{d}{d x} \frac{f^{\prime}(x)}{\sqrt{1+\left[f^{\prime}(x)\right]^{2}}}=0$$
- Since this derivative is zero, the quantity being differentiated must be constant:
$$\frac{f^{\prime}(x)}{\sqrt{1+\left[f^{\prime}(x)\right]^{2}}}=c$$
for some constant \(c\).
- Square both sides:
$$\frac{\left[f^{\prime}(x)\right]^{2}}{1+\left[f^{\prime}(x)\right]^{2}}=c^{2}$$
where \(0 \leq c^{2}<1\).
- Solving:
$$\left[f^{\prime}(x)\right]^{2}=\frac{c^{2}}{1-c^{2}}$$
\(\implies\)
$$f^{\prime}(x)=m$$
for some constant \(m\).
- Integrating:
$$f(x)=m x+b$$
is an equation of a (straight) line, where \(m=\frac{y_{2}-y_{1}}{x_{2}-x_{1}} \quad\) and \(\quad b=\frac{x_{2} y_{1}-x_{1} y_{2}}{x_{2}-x_{1}}\).
In other words, the shortest distance between two points is a straight line.
We have found the extremal function \(f(x)\) that minimizes the functional \(I[y]\) so that \(I[f]\) is a minimum.
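The same result can be checked symbolically; the sketch below (using SymPy's `euler_equations` helper) forms the Euler-Lagrange equation for the arc-length Lagrangian and confirms that its solutions are straight lines:

```python
import sympy as sp
from sympy.calculus.euler import euler_equations

x = sp.symbols('x')
y = sp.Function('y')

# Arc-length Lagrangian L(x, y, y') = sqrt(1 + y'(x)^2)
L = sp.sqrt(1 + y(x).diff(x) ** 2)

# Form the Euler-Lagrange equation for this functional...
eq = euler_equations(L, y(x), x)[0]
print(sp.simplify(eq.lhs))   # proportional to y''(x): extremals satisfy y'' = 0

# ...and solving y'' = 0 recovers a straight line y(x) = C1 + C2*x
print(sp.dsolve(sp.Eq(y(x).diff(x, 2), 0), y(x)))
```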
-
Mean Field Methods:
-
Mean Field Approximations:
-
[^1]: To achieve this speed, we typically use special optimization algorithms that are designed to solve comparatively small and simple problems in few iterations.