In deep learning models, sampling from a probability distribution is a non-differentiable operation: gradients cannot be backpropagated through the sampling step, so the distribution's parameters cannot be updated by gradient-based optimization.
To address this issue, D. P. Kingma and M. Welling introduced the Reparameterization Trick in their Variational Autoencoder (VAE) paper.
Principle
The Reparameterization Trick makes the sampling process differentiable by separating the sampling operation into two parts: random noise independent of the parameters and a deterministic transformation dependent on the parameters.
For example, consider sampling a random variable $z$ from a normal distribution $\mathcal{N}(\mu, \sigma^2)$ with mean $\mu$ and variance $\sigma^2$:
$$ z \sim \mathcal{N}(\mu, \sigma^2) $$
This sampling operation is non-differentiable. However, by using random noise $\epsilon$ sampled from a standard normal distribution $\mathcal{N}(0, 1)$, $z$ can be expressed as:
$$ z = \mu + \sigma \odot \epsilon, \quad \text{where } \epsilon \sim \mathcal{N}(0, 1) $$
($\odot$ denotes element-wise multiplication. In the multivariate case with a diagonal covariance, $\sigma$ is the vector of per-dimension standard deviations; with a full covariance matrix $\Sigma$, one instead uses $z = \mu + L\epsilon$, where $L$ is a matrix square root of $\Sigma$ such as its Cholesky factor.)
This transformation makes $z$ a deterministic function of $\mu$ and $\sigma$ (which are typically outputs of a neural network). The randomness is confined to $\epsilon$, which does not depend on the parameters, so gradients can be backpropagated through $z$ to $\mu$ and $\sigma$.
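As a minimal sketch of this idea in code (assuming PyTorch; the tensor names mu, sigma, and eps are chosen here for illustration), the sample is built from parameter-free noise and a deterministic transform:
import torch
# Distribution parameters; in practice these would be neural network outputs
mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(2.0, requires_grad=True)
# Parameter-free noise drawn from the standard normal distribution N(0, 1)
eps = torch.randn(())
# Deterministic transform of mu and sigma
z = mu + sigma * eps
# Gradients now flow back to mu and sigma
z.backward()
print(mu.grad)     # dz/dmu = 1
print(sigma.grad)  # dz/dsigma = eps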

Sampling Using the Reparameterization Trick in PyTorch
PyTorch’s torch.distributions module provides distribution classes that support the reparameterization trick. Distribution objects such as Normal expose reparameterized sampling through the rsample() method.
import torch
from torch.distributions import Normal
# Mean and standard deviation obtained as neural network outputs
# As an example, consider a normal distribution with mu=0, sigma=1
mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)
# Create a normal distribution object
m = Normal(mu, sigma)
# Sample using the rsample() method
# This sampling applies the reparameterization trick, making gradient computation possible
z = m.rsample()
print(f"Sampled z: {z}")
print(f"z requires_grad flag: {z.requires_grad}") # True
You can check whether a distribution supports the reparameterization trick by examining the m.has_rsample property.
print(f"Does Normal distribution support rsample: {m.has_rsample}") # True
The reparameterization trick is a crucial technique not only in VAEs but also in other deep learning models with stochastic components, such as reinforcement learning methods that train stochastic policies by backpropagating through sampled actions.
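As an illustration of how the trick is used in a VAE-style encoder (a hedged sketch with illustrative layer sizes, not the exact architecture from the paper), the network predicts mu and the log-variance, and the latent code is drawn with rsample():
import torch
import torch.nn as nn
from torch.distributions import Normal

class Encoder(nn.Module):
    def __init__(self, in_dim=784, latent_dim=20):
        super().__init__()
        self.hidden = nn.Linear(in_dim, 400)
        self.mu_head = nn.Linear(400, latent_dim)
        self.logvar_head = nn.Linear(400, latent_dim)

    def forward(self, x):
        h = torch.relu(self.hidden(x))
        mu = self.mu_head(h)
        sigma = torch.exp(0.5 * self.logvar_head(h))  # standard deviation
        # Reparameterized sample: gradients flow back to mu and sigma
        z = Normal(mu, sigma).rsample()
        return z, mu, sigma

x = torch.randn(8, 784)      # a dummy batch of inputs
z, mu, sigma = Encoder()(x)
print(z.shape)               # torch.Size([8, 20])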
References
- Kingma, D. P., & Welling, M. (2013). “Auto-Encoding Variational Bayes”. arXiv preprint arXiv:1312.6114.
- ReNom, Variational Autoencoder