Variational Autoencoders (VAEs)


A Variational Autoencoder (VAE) is a probabilistic generative model that can learn the underlying distribution of data and generate new samples that look like the original data.

Think of VAEs as autoencoders, but with a twist that allows them to generate entirely new, meaningful data rather than just compressing and reconstructing the input.


The Intuition

Let's say you show a VAE thousands of handwritten digits (the MNIST dataset). It learns:

  • How to compress each digit into a compact “essence” (called the latent space).
  • How to generate new digits by sampling from that latent space.

So instead of memorizing data, it learns to model the probability distribution of digits. Once trained, you can sample a point in latent space, and the VAE will imagine a new digit that “could exist.”


Architecture Overview

A VAE consists of two main parts:

1. Encoder (Inference Network)

  • Compresses the input data into a latent vector (z).
  • Outputs two vectors: the mean (μ) and the standard deviation (σ).
  • Instead of outputting a fixed point, it defines a distribution in latent space.

2. Decoder (Generative Network)

  • Takes the sampled latent vector (z) and tries to reconstruct the original data.
  • Learns to decode the “essence” back into a complete image, sentence, etc.

The Reparameterization Trick

To allow backpropagation through the random sampling step, we use this trick:

Instead of sampling z directly from N(μ, σ²), we rewrite it as:

z = μ + σ * ε

where ε ~ N(0, 1) is standard normal noise. Because the randomness is isolated in ε, gradients can flow through μ and σ.

This lets the model be trained end-to-end with gradient descent.
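
For illustration, here is a minimal, self-contained PyTorch sketch of the trick (the 2-dimensional latent size is just an example, not tied to the model below):

import torch

# Parameters that would normally come from the encoder
mu = torch.zeros(2, requires_grad=True)       # predicted mean
logvar = torch.zeros(2, requires_grad=True)   # predicted log-variance

std = torch.exp(0.5 * logvar)   # convert log-variance to standard deviation
eps = torch.randn_like(std)     # noise sampled outside the computation graph
z = mu + eps * std              # sampled latent vector, differentiable w.r.t. mu and logvar

z.sum().backward()              # gradients flow back into mu and logvar
print(mu.grad, logvar.grad)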


VAE Loss Function

The VAE loss has two parts:

  1. Reconstruction Loss
    Measures how well the output matches the input (e.g., MSE or binary cross-entropy).
  2. KL Divergence (Regularization Term)
    Measures how close the learned latent distribution is to a standard normal distribution.

Full Loss

L = Reconstruction Loss + β · D_KL(q(z|x) ‖ p(z))

  • q(z|x) is the distribution produced by the encoder (the approximate posterior)
  • p(z) is the prior (usually a standard normal N(0, 1))
  • β controls how strong the regularization is (in a β-VAE it is tunable; the standard VAE uses β = 1)

VAE in Action – A Simple PyTorch Example

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder: 784-dim input (flattened 28x28 image) -> 400 hidden units
        self.fc1 = nn.Linear(784, 400)
        self.fc_mu = nn.Linear(400, 20)      # mean of the latent distribution
        self.fc_logvar = nn.Linear(400, 20)  # log-variance of the latent distribution
        # Decoder: 20-dim latent vector -> 400 hidden units -> 784-dim output
        self.fc2 = nn.Linear(20, 400)
        self.fc3 = nn.Linear(400, 784)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps ~ N(0, 1)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))  # pixel values in [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

Loss Function

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy between the input and its reconstruction
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior p(z) = N(0, 1)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD  # multiply KLD by beta to get the beta-VAE objective
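
To see how the model and loss fit together, here is a minimal training-loop sketch. It assumes the MNIST dataset via torchvision; the batch size, learning rate, and epoch count are illustrative choices rather than values from the example above:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Assumed setup: MNIST images as tensors in [0, 1]; the model flattens them to 784 values.
train_data = datasets.MNIST('./data', train=True, download=True,
                            transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    total_loss = 0.0
    for x, _ in train_loader:                 # labels are not needed
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {epoch}: avg loss {total_loss / len(train_data):.2f}")

# Once trained, new digits can be generated by sampling from the prior:
with torch.no_grad():
    z = torch.randn(16, 20)                   # 16 points from N(0, I), latent size 20
    samples = model.decode(z).view(-1, 28, 28)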

Comparison: Autoencoder vs. VAE

Feature               | Autoencoder                  | VAE
Type                  | Deterministic                | Probabilistic
Latent space          | Fixed encoding               | Distribution (mean & variance)
Generation capability | Poor                         | Excellent
Use case              | Compression/Reconstruction   | Generation/Interpolation

Use Cases of VAEs

  1. Image Generation
    Create realistic faces, handwritten digits, or other image types.
  2. Anomaly Detection
    Train a VAE on normal data; inputs with unusually high reconstruction error are flagged as anomalies (see the sketch after this list).
  3. Denoising
    VAE learns to reconstruct clean images from noisy inputs.
  4. Representation Learning
    Learn structured, low-dimensional latent representations of data.
  5. Semi-Supervised Learning
    Leverage unsupervised VAEs with a few labeled examples.
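
As a rough illustration of the anomaly-detection use case, here is a minimal scoring sketch built on the model defined earlier; the per-sample error metric and the threshold value are assumptions made for the example:

def anomaly_score(model, x):
    """Per-sample reconstruction error; higher values suggest more unusual inputs."""
    with torch.no_grad():
        recon, _, _ = model(x)
        # Sum the binary cross-entropy over the 784 pixels of each sample
        errors = F.binary_cross_entropy(recon, x.view(-1, 784), reduction='none').sum(dim=1)
    return errors

# Hypothetical usage: x_batch is a batch of inputs to check.
scores = anomaly_score(model, x_batch)
is_anomaly = scores > 550.0   # example threshold; tune it on held-out normal data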

Visualizing Latent Space

One of the coolest features of VAEs: you can plot the latent space (especially when it is only 2- or 3-dimensional).

  • Nearby points often generate similar outputs.
  • You can smoothly interpolate between two points to blend one digit into another.

Example

Interpolate between "3" and "8" to generate morphing digits.
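
A minimal sketch of such an interpolation, assuming a trained model and two input images x_a and x_b containing a "3" and an "8" (the step count of 10 is arbitrary):

with torch.no_grad():
    # Encode both inputs and use the means as anchor points in latent space
    mu_a, _ = model.encode(x_a.view(-1, 784))
    mu_b, _ = model.encode(x_b.view(-1, 784))

    # Walk a straight line between the two points and decode each intermediate vector
    steps = torch.linspace(0, 1, 10)
    frames = [model.decode((1 - t) * mu_a + t * mu_b).view(28, 28) for t in steps]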


Advantages of VAEs

✅ Simple and stable to train
✅ Smooth and interpretable latent space
✅ Good for scientific applications (due to their probabilistic nature)


Limitations

โŒ Often generate blurry images
โŒ Struggles to match GANs in visual realism
โŒ Needs carefully tuned KL loss to avoid under/over-regularization


Summary

Component          | Role
Encoder            | Maps data → latent distribution
Reparameterization | Enables gradient flow
Decoder            | Generates data from latent z
KL Divergence      | Regularizes latent space
Reconstruction     | Forces fidelity to input

VAEs are a powerful tool for generating new data with structure, and they serve as a bridge between simple autoencoders and more complex models like GANs or Diffusion Models.



Updated on June 6, 2025