EMアルゴリズム(Expectation-Maximization Algorithm)は、観測データに隠れた変数(潜在変数)が含まれる統計モデルのパラメータを推定するための反復的なアルゴリズムです。
例えば、フィットネスクラブの利用者の年齢分布を考えた場合、筋トレ目的の20代とメタボ対策目的の50代という2つのグループが存在し、それぞれの年齢分布が正規分布に従うと仮定できます。この場合、各利用者がどちらのグループに属するかは直接観測できない「潜在変数」となります。
EMアルゴリズムの基本的な考え方は、以下の2つのステップを交互に繰り返すことで、モデルのパラメータと潜在変数の推定を同時に行うというものです。
- Eステップ (Expectation Step): 現在のモデルパラメータを使って、観測データから潜在変数の期待値(または確率分布)を推定します。上記の例では、各利用者がどちらのグループに属するかの「負担率」を計算します。
- Mステップ (Maximization Step): Eステップで推定された潜在変数の期待値を用いて、モデルのパラメータを最大化します。上記の例では、負担率を重みとして、各グループの正規分布の平均や分散といったパラメータを更新します。
このプロセスを繰り返すことで、モデルのパラメータと潜在変数の推定が徐々に改善され、最終的に対数尤度が局所最大値に収束します。
混合ガウスモデルとEMアルゴリズム
EMアルゴリズムは、複数のガウス分布が重なり合って観測データが生成されると仮定する混合ガウスモデル (Gaussian Mixture Model, GMM) のパラメータ推定によく用いられます。
Q関数
EMアルゴリズムでは、観測データ \(x\) と潜在変数 \(z\) の両方がわかっている仮想的な状況を「完全データ」と呼びます。完全データにおける同時分布 \(p(x, z | \theta)\) の対数尤度を考えます。
しかし、実際には潜在変数 \(z\) は未知です。そこで、現在のパラメータ推定値 \(\hat{\theta}\) を用いて、潜在変数 \(z\) の事後分布 \(p(z|x, \hat{\theta})\) を計算し、この分布に関する完全データの対数尤度の期待値を求めます。これがQ関数です。
\[ Q(\theta, \hat{\theta}) = \mathbb{E}\_{p(z|x,\hat{\theta})}[\log p(x,z|\theta)] = \int p(z|x,\hat{\theta})\log p(x,z|\theta)dz \]EMアルゴリズムのMステップでは、このQ関数を最大化するパラメータ \(\theta\) を求めます。
EMアルゴリズムによる混合ガウスモデルのパラメータ更新式
混合ガウスモデルのパラメータ(各ガウス分布の重み \(\pi_j\)、平均 \(\mu_j\)、分散 \(\sigma_j^2\))をEMアルゴリズムで推定する際の更新式は以下の通りです。

- 初期化: 各ガウス分布のパラメータ \(\hat{\pi}_j^{(0)}, \hat{\mu}_j^{(0)}, \hat{\sigma}_j^{2(0)}\) をランダムな値で初期化します。
- Eステップ: 各データ点 \(x_i\) が、どのガウス分布から生成されたかを示す「負担率」 \(r_{ij}\) を計算します。 \( r*{ij} = p(z*{ij}=1 | x*i, \hat{\theta}^{(t)}) = \frac{\hat{\pi}\_j^{(t)} \mathcal{N}(x_i | \hat{\mu}\_j^{(t)}, \hat{\sigma}\_j^{2(t)})}{\sum*{k=1}^K \hat{\pi}_k^{(t)} \mathcal{N}(x_i | \hat{\mu}\_k^{(t)}, \hat{\sigma}\_k^{2(t)})} \) ここで \(z_{ij}=1\) はデータ点 \(x_i\) が \(j\) 番目のガウス分布に属することを示します。
- Mステップ: Eステップで計算された負担率 \(r_{ij}\) を用いて、新しいパラメータ \(\hat{\theta}^{(t+1)}\) を計算します。 \( N*j = \sum*{i=1}^N r*{ij} \) \( \hat{\pi}\_j^{(t+1)} = \frac{N_j}{N} \) \( \hat{\mu}\_j^{(t+1)} = \frac{1}{N_j} \sum*{i=1}^N r*{ij} x_i \) \( \hat{\sigma}\_j^{2(t+1)} = \frac{1}{N_j} \sum*{i=1}^N r\_{ij} (x_i - \hat{\mu}\_j^{(t+1)})^2 \)
- 収束判定: パラメータの変化が十分に小さくなるか、最大反復回数に達するまでステップ2と3を繰り返します。
参考
- 手塚 太郎, 『しくみがわかるベイズ統計と機械学習』, 講談社 (2017)