Higgins, Irina, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. 2017. "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework."
In order to create a generative model for data $x$, we assume there is some latent distribution $p(z)$ on a latent variable $z$, such that we can transform (decode) samples from this distribution to resemble the data. We maximize a lower bound (ELBO) on the likelihood $p(x)$ over the observed data, thus keeping the data fixed as the truth while letting our model vary (like in Bayesian Methods).
Aside
Within the field of Generative Modeling, VAEs extend autoencoders by imposing an organized, probabilistic structure on the latent space (regularization).
Variational Inference is a method to approximate a distribution. It amounts to parameterizing a family of candidate distributions and using some loss function (such as the KL-divergence) to determine which set of parameters best fits the target distribution.
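As a toy illustration of that idea (a sketch, not from the source): the snippet below parameterizes a Gaussian $q = \mathcal{N}(m, s^2)$ and fits it to a fixed target $\mathcal{N}(3, 2^2)$ by minimizing their closed-form KL-divergence. PyTorch, the specific target, and the optimizer settings are assumptions for illustration only.

```python
import math
import torch

mu_p, sigma_p = 3.0, 2.0                       # fixed target p = N(3, 2^2)
m = torch.tensor(0.0, requires_grad=True)      # variational mean of q
log_s = torch.tensor(0.0, requires_grad=True)  # variational log-std of q (log-space keeps s > 0)
opt = torch.optim.Adam([m, log_s], lr=0.05)

for step in range(2000):
    s = log_s.exp()
    # Closed-form KL( N(m, s^2) || N(mu_p, sigma_p^2) )
    kl = math.log(sigma_p) - log_s + (s**2 + (m - mu_p)**2) / (2 * sigma_p**2) - 0.5
    opt.zero_grad()
    kl.backward()
    opt.step()

print(m.item(), log_s.exp().item())  # approaches (3.0, 2.0): q matches the target
```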
Formulation
The prior is $p(z)$, and $p_\theta(x \mid z)$ is the probabilistic decoder, taking input $z$ and returning a distribution over the data $x$. By using the result above, we can theoretically use the prior and the likelihood to calculate the posterior $p(z \mid x)$ as
$$
p(z \mid x) = \frac{p_\theta(x \mid z)\, p(z)}{p(x)} = \frac{p_\theta(x \mid z)\, p(z)}{\int p_\theta(x \mid z)\, p(z)\, dz},
$$
but the integral in the denominator is an integration in whatever dimension $z$ is, making it intractable.
Thus, we use Variational Inference to approximate $p(z \mid x)$ with a distribution $q_\phi(z \mid x)$. Note that the true posterior depends on $x$, so $q$ must be tied to $x$ for this approximation to be accurate. We choose $q_\phi(z \mid x)$ to be Gaussian, with $q_\phi(z \mid x) = \mathcal{N}\!\left(z;\ \mu_\phi(x),\ \operatorname{diag}(\sigma_\phi^2(x))\right)$. Taking input $x$ and outputting a distribution in the latent space, $q_\phi(z \mid x)$ is the probabilistic encoder, aiming to approximate the posterior.
Similarly, we assume that $p_\theta(x \mid z)$ is also Gaussian, with $p_\theta(x \mid z) = \mathcal{N}\!\left(x;\ \mu_\theta(z),\ \sigma^2 I\right)$. This allows us to write the log-likelihood term as a mean-squared error (MSE). Note that the encoder mean/variance $\mu_\phi(x), \sigma_\phi^2(x)$ and the decoder mean $\mu_\theta(z)$ are often parameterized through the standard machinery in Deep Learning: Neural Networks.
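As a concrete sketch (assuming PyTorch), the encoder and decoder below each output the parameters of their Gaussian; the dimensions `x_dim`, `h_dim`, `z_dim` and the single hidden layer are hypothetical placeholders, not a prescribed architecture.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """q_phi(z|x): maps x to the mean and std of a diagonal Gaussian over z."""
    def __init__(self, x_dim=784, h_dim=256, z_dim=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)          # mu_phi(x)
        self.log_sigma = nn.Linear(h_dim, z_dim)   # log sigma_phi(x), log-space keeps sigma positive

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.log_sigma(h).exp()

class Decoder(nn.Module):
    """p_theta(x|z): maps z to the mean of a Gaussian over x (fixed noise std)."""
    def __init__(self, x_dim=784, h_dim=256, z_dim=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))

    def forward(self, z):
        return self.net(z)  # mu_theta(z)
```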
Evidence Lower Bound (ELBO)
Begin with
$$
p(x) = \frac{p_\theta(x \mid z)\, p(z)}{p(z \mid x)} = p_\theta(x \mid z)\cdot\frac{p(z)}{q_\phi(z \mid x)}\cdot\frac{q_\phi(z \mid x)}{p(z \mid x)}.
$$
Take the $\log$, then split into three terms:
$$
\log p(x) = \log p_\theta(x \mid z) + \log\frac{p(z)}{q_\phi(z \mid x)} + \log\frac{q_\phi(z \mid x)}{p(z \mid x)}.
$$
Then, take the expectation over $z$ drawn from the approximate posterior, $z \sim q_\phi(z \mid x)$, which turns the fraction terms into KL-divergence terms:
$$
\log p(x) = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right) + D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z \mid x)\right).
$$
The left-hand side does not depend on $z$, so its expectation is dropped.
The $D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z \mid x)\right)$ term involves the intractable posterior. This term is also nonnegative, thus dropping it gives the inequality
$$
\log p(x) \ \ge\ \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right) \ =: \ \mathrm{ELBO}.
$$
Note that the ELBO is also equal to $\log p(x) - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z \mid x)\right)$, so the gap in the bound is exactly the divergence between the approximate and true posteriors.
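As a sanity check on the bound (a sketch, not from the source), the snippet below evaluates it in a toy model where everything is analytic: $p(z) = \mathcal{N}(0,1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, so $p(x) = \mathcal{N}(0, 2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$. Any Gaussian $q$ gives $\mathrm{ELBO} \le \log p(x)$, with equality at the true posterior. The specific toy model and test points are assumptions for illustration.

```python
import numpy as np

def log_px(x):
    return -0.5 * np.log(2 * np.pi * 2.0) - x**2 / 4.0  # log N(x; 0, 2)

def elbo(x, m, s):
    # E_q[log p(x|z)] for p(x|z) = N(x; z, 1) and q(z|x) = N(m, s^2)
    recon = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m)**2 + s**2)
    # Analytic KL( N(m, s^2) || N(0, 1) )
    kl = 0.5 * (m**2 + s**2 - 1.0 - np.log(s**2))
    return recon - kl

x = 1.3
for m, s in [(0.0, 1.0), (0.65, 0.7), (x / 2, np.sqrt(0.5))]:  # last pair is the true posterior
    print(f"m={m:.2f}, s={s:.2f}: ELBO={elbo(x, m, s):.4f} <= log p(x)={log_px(x):.4f}")
```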
Info-VAE or MMD-VAE
A standard VAE forms a loss as a reconstruction term plus a latent-space structure term, such as
$$
\mathcal{L}(x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] + D\!\left(q_\phi(z \mid x)\,\|\,p(z)\right),
$$
where $D$ is taken as the KL-Divergence. In the above, the $D\!\left(q_\phi(z \mid x)\,\|\,p(z)\right)$ term encourages $q_\phi(z \mid x)$ to match $p(z)$, regardless of $x$. So the solution preferred by that term alone would be $q_\phi(z \mid x) = p(z)$ for every $x$, which would imply that the encoder actually contains no information from the data $x$. So while we could tune a weight $\beta$ on the divergence term (as in β-VAE) to find a balancing point between these competing objectives, we can instead loosen the requirement itself. While we want the latent space to have some structure, we want this in aggregate (in expectation over the data), not for each data sample individually.
Motivated by this, we can replace the per-sample latent-space structure term with a term on the average (aggregate) posterior $q_\phi(z) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[q_\phi(z \mid x)\right]$. By Jensen's Inequality, if $D$ is convex in the first argument (like the KL-Divergence is),
$$
D\!\left(\mathbb{E}_{x}\!\left[q_\phi(z \mid x)\right]\,\|\,p(z)\right) \ \le\ \mathbb{E}_{x}\!\left[D\!\left(q_\phi(z \mid x)\,\|\,p(z)\right)\right].
$$
Thus, the aforementioned loss is unnecessarily strict for our goal. The new form would be
$$
\mathcal{L}(x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] + D\!\left(q_\phi(z)\,\|\,p(z)\right),
$$
where we approximate the expectation over $x$ by cycling through the training data. The key difference here is that the divergence term depends on the distribution $q_\phi(z)$, i.e. the encoder distribution marginalized over $x$.
In this form, we can't compute $D\!\left(q_\phi(z)\,\|\,p(z)\right)$ analytically like before. Thus, we'd have to estimate it anyways, so we may as well take $D$ as another divergence, such as the Maximum Mean Discrepancy (MMD). To do this, we take a batch of $x_1, \dots, x_B$, compute the encoder outputs (such as $\mu_\phi(x_i)$, $\sigma_\phi(x_i)$), and take one $z_i$ from each of these (reparameterization trick). The prior $p(z)$ is easier to sample from directly. We then use these samples to compute an empirical estimate of the MMD (replace expectations with empirical means). The negative log-likelihood term stays the same.
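Here is a minimal sketch of that empirical MMD estimate, assuming PyTorch and a Gaussian (RBF) kernel; the kernel choice and `bandwidth` value are illustrative assumptions rather than anything prescribed by these notes.

```python
import torch

def rbf_kernel(a, b, bandwidth=1.0):
    # a: (n, d), b: (m, d) -> (n, m) matrix of k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * bandwidth^2))
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2 * bandwidth**2))

def mmd(z_q, z_p, bandwidth=1.0):
    """Empirical (biased) MMD^2 between samples z_q ~ q_phi(z) and z_p ~ p(z)."""
    return (rbf_kernel(z_q, z_q, bandwidth).mean()
            - 2 * rbf_kernel(z_q, z_p, bandwidth).mean()
            + rbf_kernel(z_p, z_p, bandwidth).mean())

# Usage sketch: one z per data point in the batch (reparameterized), plus prior samples.
# mu, sigma = encoder(x_batch)                 # shapes (B, z_dim)
# z_q = mu + sigma * torch.randn_like(sigma)   # samples from the aggregate posterior q_phi(z)
# z_p = torch.randn_like(z_q)                  # samples from the prior N(0, I)
# mmd_term = mmd(z_q, z_p)
```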
From Expectation to a Practical Loss
For the vanilla VAE, we had the sample loss as
$$
\mathcal{L}(x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] + D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right).
$$
The first expectation is approximated empirically with a single sample $z \sim q_\phi(z \mid x)$. This is the basis of the reparameterization trick, which effectively allows sampling from a parameterized distribution without breaking the computational graph for automatic differentiation. Because $q_\phi(z \mid x)$ is Gaussian, we can draw $\epsilon \sim \mathcal{N}(0, I)$ from the standard normal, then scale by the standard deviation and add the mean, giving $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ (these are really the objects the encoder outputs anyways). The KL-divergence has an analytic form when it is between Gaussians, so it simplifies to an expression just in terms of the mean and standard deviation of $q_\phi(z \mid x)$, which are $\mu_\phi(x)$ and $\sigma_\phi(x)$ (diagonal covariance).
So the loss is really
$$
\mathcal{L}(x) \approx -\log p_\theta(x \mid z) + \tfrac{1}{2}\sum_{j=1}^{d}\left(\mu_{\phi,j}^2(x) + \sigma_{\phi,j}^2(x) - 1 - \log \sigma_{\phi,j}^2(x)\right), \qquad z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon.
$$
Further, by assuming that there is additive Gaussian noise on the reconstruction target (not that the data is Gaussian), with mean $\mu_\theta(z)$ and standard deviation $\sigma$, the reconstruction term simplifies to the Mean Squared Error, giving the final form as
$$
\mathcal{L}(x) \approx \frac{1}{2\sigma^2}\left\|x - \mu_\theta(z)\right\|^2 + \tfrac{1}{2}\sum_{j=1}^{d}\left(\mu_{\phi,j}^2(x) + \sigma_{\phi,j}^2(x) - 1 - \log \sigma_{\phi,j}^2(x)\right) + \text{const}.
$$
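A minimal sketch of this per-sample loss under the Gaussian assumptions above (PyTorch assumed); `sigma_x` is a hypothetical knob for the reconstruction noise scale, and the additive constant is dropped.

```python
import torch

def reparameterize(mu, sigma):
    eps = torch.randn_like(sigma)      # eps ~ N(0, I); carries all the randomness
    return mu + sigma * eps            # z ~ N(mu, diag(sigma^2)), differentiable in mu, sigma

def vae_loss(x, x_mean, mu, sigma, sigma_x=1.0):
    # Reconstruction: MSE arising from the Gaussian log-likelihood (up to an additive constant)
    recon = ((x - x_mean) ** 2).sum(dim=-1) / (2 * sigma_x**2)
    # Analytic KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions
    kl = 0.5 * (mu**2 + sigma**2 - 1.0 - torch.log(sigma**2)).sum(dim=-1)
    return (recon + kl).mean()         # batch mean of the per-sample losses
```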
In practice, we use some Unconstrained Optimization method over "batch" estimates of this per-sample loss as a surrogate for the expected loss:
$$
\mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[\mathcal{L}(x)\right] \ \approx\ \frac{1}{B}\sum_{i=1}^{B} \mathcal{L}(x_i).
$$
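Stitching the sketches above together, a minimal training loop might look like the following; `loader`, the optimizer choice, learning rate, and epoch count are placeholders, not from the source.

```python
import torch

encoder, decoder = Encoder(), Decoder()          # hypothetical modules from the sketch above
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(10):
    for x in loader:                              # mini-batches approximate the expectation over the data
        mu, sigma = encoder(x)
        z = reparameterize(mu, sigma)             # one latent sample per data point
        x_mean = decoder(z)
        loss = vae_loss(x, x_mean, mu, sigma)     # batch estimate of the expected per-sample loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```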