VAEs

Resources

Main Idea

In order to create a generative model for p(x), we assume there is some latent distribution over z such that we can transform (decode) samples from it into samples that resemble the data. We maximize a lower bound (the ELBO) on the likelihood of the observed data, thus keeping the data fixed as the truth while letting our model vary (as in Bayesian Methods).

Aside

Within the field of Generative Modeling, VAEs extend autoencoders by imposing an organized, probabilistic structure on the latent space (regularization).

Later, we will use this result from Bayes' Theorem:

$$p(x) = \frac{p(x \mid z)\, p(z)}{p(z \mid x)}$$

Variational Inference

Variational Inference is a method for approximating a distribution: we choose a parameterized family of distributions and use a loss function (such as the KL-divergence) to determine which set of parameters best matches the target distribution.

$$D_{\mathrm{KL}}\big(q(y) \,\|\, p(y)\big) = \mathbb{E}_{y \sim q}\left[\log \frac{q(y)}{p(y)}\right]$$
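To make this concrete, here is a minimal sketch, assuming PyTorch; the 1-D target density, optimizer settings, and sample size are arbitrary illustrative choices, not anything prescribed above. It fits the mean and log-standard-deviation of a Gaussian q by minimizing a Monte Carlo estimate of $D_{\mathrm{KL}}(q \,\|\, p)$:

```python
import torch

# Hypothetical target: an unnormalized log-density we want to approximate
# (a Gaussian with mean 3 and std 2, so the fitted parameters can be checked).
def log_p(y):
    return -0.5 * ((y - 3.0) / 2.0) ** 2

# Variational parameters of q(y) = N(mu, exp(log_sigma)^2)
mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=1e-2)

for step in range(2000):
    sigma = log_sigma.exp()
    eps = torch.randn(256)                     # reparameterized samples y ~ q
    y = mu + sigma * eps
    log_q = -0.5 * ((y - mu) / sigma) ** 2 - log_sigma   # log q(y), up to a constant
    kl_estimate = (log_q - log_p(y)).mean()    # Monte Carlo E_{y~q}[log q - log p]
    opt.zero_grad()
    kl_estimate.backward()
    opt.step()

print(mu.item(), log_sigma.exp().item())       # should move toward roughly (3, 2)
```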

Formulation

The prior is $p(z) = \mathcal{N}(0, I)$, and $p_\theta(x \mid z)$ is the probabilistic decoder, taking input z and returning a distribution over the data x. Using the result above, we could in principle combine the prior and the likelihood to compute $p(z \mid x)$ as

$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)} = \frac{p(x \mid z)\, p(z)}{\int_z p(x \mid z)\, p(z)\, dz},$$

but the denominator requires integrating over the entire latent space (in whatever dimension z lives), which is generally intractable.

Thus, we use Variational Inference to approximate $p(z \mid x) \approx q_\phi(z \mid x)$. Note that the p distributions depend on θ, so ϕ must be fit jointly with θ for this approximation to stay accurate. We choose $q_\phi$ to be Gaussian, with $q_\phi(z \mid x) = \mathcal{N}\big(\mu_\phi(x), \sigma_\phi(x)^2 I\big)$. Taking input x and outputting a distribution over the latent space, $q_\phi(z \mid x)$ is the probabilistic encoder, aiming to approximate the posterior.

Similarly, we assume that $p_\theta(x \mid z)$ is also Gaussian, $p_\theta(x \mid z) = \mathcal{N}\big(D_\theta(z), cI\big)$ with $c > 0$. This allows us to write the log-likelihood term as a mean-squared error (MSE). Note that the encoder mean/variance and the decoder mean are typically parameterized through the standard machinery of Deep Learning: Neural Networks.
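As a rough illustration of that parameterization, here is a minimal sketch, assuming PyTorch; the one-hidden-layer MLPs, layer widths, and activations are placeholder choices, not anything prescribed above:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """q_phi(z|x) = N(mu_phi(x), diag(sigma_phi(x)^2))."""
    def __init__(self, x_dim, z_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)          # mu_phi(x)
        self.log_sigma = nn.Linear(hidden, z_dim)   # log sigma_phi(x), so sigma > 0

    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.log_sigma(h)

class Decoder(nn.Module):
    """p_theta(x|z) = N(D_theta(z), c I): only the mean D_theta(z) is learned."""
    def __init__(self, z_dim, x_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, x_dim))

    def forward(self, z):
        return self.net(z)   # D_theta(z)
```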

Evidence Lower Bound (ELBO)

Begin with

$$p(x) = p(x \mid z)\,\frac{p(z)}{q(z \mid x)}\,\frac{q(z \mid x)}{p(z \mid x)}$$

Take the log, then split into three terms

$$\log p(x) = \log p(x \mid z) - \log \frac{q(z \mid x)}{p(z)} + \log \frac{q(z \mid x)}{p(z \mid x)}$$

Then, take the expectation over z drawn from the approximate posterior, $z \sim q(z \mid x)$, which turns the log-fraction terms into $D_{\mathrm{KL}}$ terms. The left-hand side does not depend on z, so taking the expectation leaves it unchanged.

$$\log p(x) = \mathbb{E}_{z \sim q(z \mid x)}\big[\log p(x \mid z)\big] - D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big) + D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z \mid x)\big)$$

The term $D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z \mid x)\big)$ involves the intractable posterior. It is also nonnegative, so dropping it gives the inequality

$$\log p(x) \;\geq\; \underbrace{\mathbb{E}_{z \sim q(z \mid x)}\big[\log p(x \mid z)\big] - D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big)}_{\mathrm{ELBO}}$$

Note that the gap between $\log p(x)$ and the ELBO is exactly the dropped term $D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z \mid x)\big)$, so maximizing the ELBO over ϕ also drives $q(z \mid x)$ toward the true posterior.

Info-VAE or MMD-VAE

A standard VAE forms its loss from a reconstruction term and a latent-space structure term, such as

$$\mathcal{L}(x) = -\mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p(x \mid z)\big] + \beta\, D\big(q_\phi(z \mid x) \,\|\, p(z)\big),$$

where D is taken as the KL-divergence. The term $D\big(q_\phi(z \mid x) \,\|\, p(z)\big)$ encourages $q_\phi(z \mid x)$ to match p(z) regardless of x, so its minimizer is $q_\phi(z \mid x) = p(z)$, which would mean the encoder $q_\phi$ carries no information about the data x. We could tune β to balance these competing objectives, but we can instead loosen the requirement itself: we want the latent space to have structure on average (in expectation over the data), not for each individual sample.

Motivated by this, we can replace the per-sample latent-structure term with one on the aggregated (marginal) posterior. By Jensen's Inequality, if D is convex in its first argument (as the KL-divergence is), this is the weaker requirement:

$$D\Big(\underbrace{q_\phi(z)}_{=\,\mathbb{E}_x[q_\phi(z \mid x)]} \,\Big\|\, p(z)\Big) \;\leq\; \mathbb{E}_x\Big[D\big(q_\phi(z \mid x) \,\|\, p(z)\big)\Big].$$

Thus, the aforementioned loss is unnecessarily strict for our goal. The new form is

$$\mathcal{L}(x) = -\mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p(x \mid z)\big] + \beta\, D\big(q_\phi(z) \,\|\, p(z)\big),$$

where we approximate $\mathbb{E}_x$ by cycling through the training data. The key difference is that the divergence term now depends on the latent distribution marginalized over x.

In this form, we can no longer compute $D_{\mathrm{KL}}\big(q_\phi(z) \,\|\, p(z)\big)$ analytically like before. Since we would have to estimate it anyway, we may as well take D to be another divergence, such as the Maximum Mean Discrepancy (MMD). To do this, we take a batch of x, compute $q_\phi(z \mid x)$ (i.e. $\mu_\phi(x)$ and $\sigma_\phi^2(x)$), and draw one z from each of these (reparameterization trick); sampling from p(z) is straightforward. We then use these samples to compute an empirical estimate of the MMD (replacing expectations with empirical means), as sketched below. The negative log-likelihood term stays the same.
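A minimal sketch of that empirical MMD estimate, assuming PyTorch; the RBF kernel and its bandwidth are common choices but are assumptions here, not specified above:

```python
import torch

def rbf_kernel(a, b, bandwidth=1.0):
    # k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * bandwidth^2))
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2 * bandwidth ** 2))

def mmd(z_q, z_p, bandwidth=1.0):
    """Biased empirical MMD^2 between samples z_q ~ q_phi(z) and z_p ~ p(z)."""
    k_qq = rbf_kernel(z_q, z_q, bandwidth).mean()
    k_pp = rbf_kernel(z_p, z_p, bandwidth).mean()
    k_qp = rbf_kernel(z_q, z_p, bandwidth).mean()
    return k_qq + k_pp - 2 * k_qp

# Usage inside a training step (mu, log_sigma come from the encoder):
# eps = torch.randn_like(mu)
# z_q = mu + log_sigma.exp() * eps   # one z per batch element (reparameterization)
# z_p = torch.randn_like(z_q)        # samples from the prior N(0, I)
# latent_term = mmd(z_q, z_p)
```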

From Expectation to a Practical Loss

For the vanilla VAE, we had the per-sample loss

$$\mathcal{L}(x) = -\mathbb{E}_{z \sim q(z \mid x)}\big[\log p(x \mid z)\big] + D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big).$$

The first expectation is approximated empirically with a single sample, drawn via the reparameterization trick, which allows sampling from a parameterized distribution without breaking the computational graph for automatic differentiation. Because $q(z \mid x)$ is Gaussian, we can draw from a standard normal, scale by the standard deviation of $q(z \mid x)$, and add its mean (these are exactly the objects the encoder outputs anyway). The KL-divergence between Gaussians has an analytic form, so it reduces to an expression in the mean and (diagonal) standard deviation of $q_\phi$, namely $\mu_\phi$ and $\sigma_\phi$; below, $d$ denotes the latent dimension.
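A minimal sketch of the reparameterized sample and the closed-form Gaussian KL, assuming PyTorch and that the encoder outputs $\mu_\phi(x)$ and $\log \sigma_\phi(x)$:

```python
import torch

def reparameterize(mu, log_sigma):
    """Draw z ~ N(mu, diag(sigma^2)) while keeping gradients w.r.t. mu and sigma."""
    eps = torch.randn_like(mu)          # eps ~ N(0, I), involves no parameters
    return mu + log_sigma.exp() * eps   # scale and shift the standard normal

def kl_to_standard_normal(mu, log_sigma):
    """Closed-form D_KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over dimensions."""
    sigma2 = (2 * log_sigma).exp()
    return 0.5 * torch.sum(sigma2 + mu ** 2 - 1 - torch.log(sigma2), dim=-1)
```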

So the loss is really

$$\hat{z} \sim q_\phi(z \mid x), \qquad \mathcal{L}(x) = -\log p_\theta(x \mid \hat{z}) + \frac{1}{2}\left(\operatorname{tr}\big(\operatorname{diag}(\sigma_\phi^2(x))\big) + \mu_\phi^{T}(x)\,\mu_\phi(x) - d - \log \prod_{j=1}^{d} \big(\sigma_\phi^2(x)\big)_j\right).$$

Further, by assuming additive Gaussian noise on the reconstruction target (not that the data itself is Gaussian), with mean $D_\theta(\hat{z})$ and standard deviation $\sigma$ (so $c = \sigma^2$ above), the reconstruction term simplifies to a scaled Mean Squared Error, giving the final form

$$\hat{z} \sim q_\phi(z \mid x), \qquad \mathcal{L}(x) = \frac{1}{2\sigma^2}\,\big\|x - D_\theta(\hat{z})\big\|_2^2 + \frac{1}{2}\left(\operatorname{tr}\big(\operatorname{diag}(\sigma_\phi^2(x))\big) + \mu_\phi^{T}(x)\,\mu_\phi(x) - d - \log \prod_{j=1}^{d} \big(\sigma_\phi^2(x)\big)_j\right).$$
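As a sketch, this per-sample loss might look as follows, assuming PyTorch, a decoder output `x_recon` standing in for $D_\theta(\hat{z})$, and a hypothetical noise scale `noise_sigma` playing the role of $\sigma$:

```python
import torch

def vae_loss(x, x_recon, mu, log_sigma, noise_sigma=1.0):
    """Per-sample loss: ||x - D_theta(z_hat)||^2 / (2 sigma^2) + Gaussian KL term."""
    recon = ((x - x_recon) ** 2).sum(dim=-1) / (2 * noise_sigma ** 2)
    sigma2 = (2 * log_sigma).exp()
    # Matches the bracketed term above: 0.5 * sum_j (sigma_j^2 + mu_j^2 - 1 - log sigma_j^2)
    kl = 0.5 * torch.sum(sigma2 + mu ** 2 - 1 - torch.log(sigma2), dim=-1)
    return (recon + kl).mean()   # average the per-sample losses over the batch
```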

In practice, we apply some Unconstrained Optimization method to mini-batch estimates of this per-sample loss as a surrogate for the expected loss:

$$\min_{\theta,\,\phi}\; \mathbb{E}_x\big[\mathcal{L}(x)\big].$$
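Tying the pieces together, here is a minimal training-loop sketch under the same assumptions; the `Encoder`, `Decoder`, `reparameterize`, and `vae_loss` names come from the earlier sketches, and the synthetic data stands in for a real dataset:

```python
import torch

# Reuses the Encoder, Decoder, reparameterize, and vae_loss sketches above.
x_dim, z_dim = 784, 16                      # hypothetical data / latent sizes
encoder, decoder = Encoder(x_dim, z_dim), Decoder(z_dim, x_dim)
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

# Synthetic stand-in data so the sketch runs; replace with a real data loader.
loader = [torch.randn(128, x_dim) for _ in range(50)]

for epoch in range(10):
    for x in loader:                        # mini-batch estimate of E_x[L(x)]
        mu, log_sigma = encoder(x)          # parameters of q_phi(z|x)
        z_hat = reparameterize(mu, log_sigma)
        x_recon = decoder(z_hat)            # D_theta(z_hat)
        loss = vae_loss(x, x_recon, mu, log_sigma, noise_sigma=1.0)
        opt.zero_grad()
        loss.backward()
        opt.step()
```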