Reparameterization Trick

Explanation of the reparameterization trick introduced in VAEs, covering the mathematical principle and practical implementation using PyTorch's rsample() method.

In deep learning models, sampling from a probability distribution is a non-differentiable operation: gradients cannot flow through the sampling step, so backpropagation cannot be used to update the parameters that define the distribution.
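As a minimal illustration of this (assuming PyTorch and its torch.distributions API, which is also used later in this article), a value drawn with the ordinary sample() method is detached from the computation graph, so no gradient can reach the distribution's parameters:

import torch
from torch.distributions import Normal

mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)

# Plain sampling: the result carries no gradient history
z = Normal(mu, sigma).sample()
print(z.requires_grad)  # False -- the graph is cut at the sampling step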

To address this issue, D. P. Kingma and M. Welling introduced the Reparameterization Trick in their Variational Autoencoder (VAE) paper.

Principle

The Reparameterization Trick makes the sampling process differentiable by separating the sampling operation into two parts: random noise independent of the parameters and a deterministic transformation dependent on the parameters.

For example, consider sampling a random variable $z$ from a normal distribution $\mathcal{N}(\mu, \sigma^2)$ with mean $\mu$ and variance $\sigma^2$:

$$ z \sim \mathcal{N}(\mu, \sigma^2) $$

This sampling operation is non-differentiable. However, by using random noise $\epsilon$ sampled from a standard normal distribution $\mathcal{N}(0, 1)$, $z$ can be expressed as:

$$ z = \mu + \sigma \odot \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0, 1) $$

($\odot$ denotes element-wise multiplication. In the multi-dimensional case with a diagonal covariance, $\sigma$ is the vector of per-dimension standard deviations; for a full covariance matrix $\Sigma$, the element-wise product is replaced by the matrix-vector product $z = \mu + L\epsilon$, where $L$ is a square root of $\Sigma$, such as its Cholesky factor.)

This transformation makes $z$ a deterministic, differentiable function of $\mu$ and $\sigma$ (which are typically outputs of a neural network). The randomness is isolated in $\epsilon$, so gradients of a downstream loss can be backpropagated through $z$ to $\mu$ and $\sigma$.
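As a concrete sketch, the trick can be written in a few lines of PyTorch. The function name reparameterize and the log-variance parameterization below are illustrative assumptions in the style of a typical VAE encoder, not part of the original text:

import torch

def reparameterize(mu, log_var):
    # sigma recovered from the log-variance output of an encoder (assumed parameterization)
    sigma = torch.exp(0.5 * log_var)
    # noise drawn independently of the parameters
    eps = torch.randn_like(sigma)
    # deterministic, differentiable transform of mu and sigma
    return mu + sigma * eps

# Example: gradients flow back to mu and log_var
mu = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()
print(mu.grad, log_var.grad)  # both are populated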


Sampling Using the Reparameterization Trick in PyTorch

PyTorch’s torch.distributions module provides distributions that support the reparameterization trick. For such distributions, including the Normal class, a reparameterized sample is drawn with the rsample() method.

import torch
from torch.distributions import Normal

# Mean and standard deviation obtained as neural network outputs
# As an example, consider a normal distribution with mu=0, sigma=1
mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)

# Create a normal distribution object
m = Normal(mu, sigma)

# Sample using the rsample() method
# This sampling applies the reparameterization trick, making gradient computation possible
z = m.rsample()

print(f"Sampled z: {z}")
print(f"z requires_grad flag: {z.requires_grad}") # True

You can check whether a distribution supports the reparameterization trick by examining the m.has_rsample property.

print(f"Does Normal distribution support rsample: {m.has_rsample}") # True

The reparameterization trick is a crucial technique not only in VAEs but also more broadly for training deep learning models with stochastic components, for example stochastic policies in reinforcement learning (as in Soft Actor-Critic).
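For instance, a reparameterized stochastic policy can be sketched as follows. This is a minimal illustration assuming a tanh-squashed Gaussian policy in the spirit of Soft Actor-Critic; the variable names and the loss are placeholders rather than a full algorithm:

import torch
from torch.distributions import Normal

# Policy parameters (in practice, outputs of a policy network)
mean = torch.zeros(2, requires_grad=True)
log_std = torch.zeros(2, requires_grad=True)

dist = Normal(mean, log_std.exp())
action = torch.tanh(dist.rsample())  # differentiable sample, squashed into [-1, 1]

loss = -action.sum()                 # placeholder for a critic-based objective
loss.backward()
print(mean.grad, log_std.grad)       # gradients reach the policy parameters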

References