はじめに
Transformerは自然言語処理(NLP)を起点に、画像認識や音声処理など幅広い分野で標準的なアーキテクチャとなっています。その中核を担うのがSelf-Attention機構です。
Self-Attentionは、入力系列の各要素が他のすべての要素との関連度を動的に計算し、文脈に応じた表現を獲得する仕組みです。固定的な重みで信号を処理する従来のフィルタ(たとえば指数移動平均)とは異なり、入力データに依存して重みが変化する点が大きな特徴です。
本記事では、Scaled Dot-Product AttentionとMulti-Head Attentionの数式を導出し、NumPyでスクラッチ実装した上で、PyTorchのnn.MultiheadAttentionとの比較を通じて仕組みを理解します。
なぜAttentionが必要なのか
RNN(再帰型ニューラルネットワーク)は系列データの処理に広く使われてきましたが、2つの根本的な課題があります。
- 逐次処理: 時刻 \(t\) の計算が時刻 \(t-1\) の結果に依存するため、並列化が困難
- 長距離依存性の学習困難: 系列が長くなると勾配消失・爆発により、離れた位置間の関係を学習しにくい
Attention機構はこれらの問題を解決します。各位置が他のすべての位置に直接アクセスでき、系列長に依存する逐次計算が不要です。さらに、Attention重みは入力から動的に計算されるため、固定的な構造に縛られません。
Scaled Dot-Product Attention
Query, Key, Valueの導出
入力系列 \(X \in \mathbb{R}^{n \times d_{\text{model}}}\)(\(n\) はトークン数、\(d_{\text{model}}\) はモデルの次元数)に対し、3つの線形変換を適用してQuery、Key、Valueを生成します。
\[Q = XW_Q, \quad K = XW_K, \quad V = XW_V \tag{1}\]ここで \(W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}\)、\(W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}\) は学習可能な重み行列です。直感的には、Queryは「何を探しているか」、Keyは「何を持っているか」、Valueは「実際の情報」に対応します。
Attention計算
Attention関数は以下で定義されます。
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \tag{2}\]この計算を分解して理解します。
ステップ1: 類似度の計算
\[S = QK^T \in \mathbb{R}^{n \times n} \tag{3}\]\(S_{ij}\) はトークン \(i\) のQueryとトークン \(j\) のKeyの内積であり、2つのトークン間の類似度を表します。
ステップ2: スケーリング
\[S_{\text{scaled}} = \frac{S}{\sqrt{d_k}} \tag{4}\]\(d_k\) が大きいとき、内積の値も大きくなります。\(q\) と \(k\) が平均0、分散1の独立な成分を持つとき、内積 \(q \cdot k = \sum_{i=1}^{d_k} q_i k_i\) の分散は \(d_k\) になります。大きな値はsoftmaxを飽和領域に押し込み、勾配が極端に小さくなります。\(\sqrt{d_k}\) で割ることで分散を1に正規化し、この問題を回避します。
ステップ3: Attention重みの計算
\[A = \text{softmax}(S_{\text{scaled}}) \tag{5}\]softmaxにより各行が確率分布(合計1)になります。\(A_{ij}\) はトークン \(i\) がトークン \(j\) にどれだけ注目するかを表します。
ステップ4: 重み付き和
\[\text{Output} = AV \tag{6}\]各トークンの出力は、すべてのValueベクトルのAttention重みによる加重和です。これはソフトな辞書検索とみなせます。Queryで検索し、Keyとのマッチ度に応じてValueを取り出す操作です。
NumPyによるスクラッチ実装
Scaled Dot-Product Attention
import numpy as np
import matplotlib.pyplot as plt
def softmax(x, axis=-1):
"""数値的に安定なsoftmax"""
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / np.sum(e_x, axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V):
"""
Scaled Dot-Product Attention(式2)
Parameters:
Q: Query行列 (n, d_k)
K: Key行列 (n, d_k)
V: Value行列 (n, d_v)
Returns:
output: Attention出力 (n, d_v)
weights: Attention重み (n, n)
"""
d_k = Q.shape[-1]
# ステップ1-2: 類似度計算とスケーリング
scores = Q @ K.T / np.sqrt(d_k)
# ステップ3: Attention重み
weights = softmax(scores)
# ステップ4: 重み付き和
output = weights @ V
return output, weights
動作確認と可視化
np.random.seed(42)
# 入力: 4トークン、モデル次元8
n_tokens = 4
d_model = 8
d_k = d_v = 8
# ランダムな入力系列
X = np.random.randn(n_tokens, d_model)
# 重み行列(通常は学習で獲得される)
W_Q = np.random.randn(d_model, d_k) * 0.1
W_K = np.random.randn(d_model, d_k) * 0.1
W_V = np.random.randn(d_model, d_v) * 0.1
# Q, K, V の計算(式1)
Q = X @ W_Q
K = X @ W_K
V = X @ W_V
# Attention計算
output, weights = scaled_dot_product_attention(Q, K, V)
print("入力形状:", X.shape)
print("出力形状:", output.shape)
print("Attention重み:\n", np.round(weights, 3))
# Attention重みのヒートマップ
token_labels = ["Token 0", "Token 1", "Token 2", "Token 3"]
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(weights, cmap="Blues", vmin=0, vmax=1)
ax.set_xticks(range(n_tokens))
ax.set_yticks(range(n_tokens))
ax.set_xticklabels(token_labels)
ax.set_yticklabels(token_labels)
ax.set_xlabel("Key position")
ax.set_ylabel("Query position")
ax.set_title("Attention Weights")
for i in range(n_tokens):
for j in range(n_tokens):
ax.text(j, i, f"{weights[i, j]:.2f}",
ha="center", va="center", fontsize=11)
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
Multi-Head Attention
動機と定式化
単一のAttentionヘッドでは、1つの表現空間でしか類似度を計算できません。Multi-Head Attentionは、Q, K, Vを \(h\) 個の異なる部分空間に射影し、それぞれで独立にAttentionを計算することで、多様な関係性を同時に捉えます。
\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W_O \tag{7}\]\[\text{head}_i = \text{Attention}(QW_Q^i, KW_K^i, VW_V^i) \tag{8}\]各ヘッドの次元は \(d_k = d_v = d_{\text{model}} / h\) とするのが標準的です。\(W_O \in \mathbb{R}^{hd_v \times d_{\text{model}}}\) は出力射影行列です。パラメータ数は単一ヘッドのAttentionとほぼ同じですが、異なる部分空間での表現を獲得できます。
NumPy実装
class MultiHeadAttention:
"""Multi-Head Attention(式7, 8)"""
def __init__(self, d_model, n_heads, seed=0):
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
rng = np.random.RandomState(seed)
scale = 0.1
# 各ヘッドの射影行列
self.W_Q = rng.randn(n_heads, d_model, self.d_k) * scale
self.W_K = rng.randn(n_heads, d_model, self.d_k) * scale
self.W_V = rng.randn(n_heads, d_model, self.d_k) * scale
# 出力射影行列
self.W_O = rng.randn(n_heads * self.d_k, d_model) * scale
def forward(self, X):
"""
Parameters:
X: 入力系列 (n, d_model)
Returns:
output: Multi-Head Attention出力 (n, d_model)
all_weights: 各ヘッドのAttention重み (n_heads, n, n)
"""
head_outputs = []
all_weights = []
for i in range(self.n_heads):
Q = X @ self.W_Q[i]
K = X @ self.W_K[i]
V = X @ self.W_V[i]
head_out, weights = scaled_dot_product_attention(Q, K, V)
head_outputs.append(head_out)
all_weights.append(weights)
# ヘッドの結合と出力射影(式7)
concat = np.concatenate(head_outputs, axis=-1)
output = concat @ self.W_O
return output, np.array(all_weights)
Multi-Head Attentionの可視化
np.random.seed(42)
n_tokens = 6
d_model = 16
n_heads = 4
X = np.random.randn(n_tokens, d_model)
mha = MultiHeadAttention(d_model, n_heads, seed=42)
output, all_weights = mha.forward(X)
print("入力形状:", X.shape)
print("出力形状:", output.shape)
print("Attention重み形状:", all_weights.shape)
# 各ヘッドのAttention重みを可視化
fig, axes = plt.subplots(1, n_heads, figsize=(16, 4))
for h in range(n_heads):
im = axes[h].imshow(all_weights[h], cmap="Blues", vmin=0, vmax=1)
axes[h].set_title(f"Head {h}")
axes[h].set_xlabel("Key")
axes[h].set_ylabel("Query")
plt.suptitle("Multi-Head Attention Weights", fontsize=14)
plt.tight_layout()
plt.show()
位置エンコーディング
Self-Attentionは入力の順序に対して不変(permutation-invariant)です。つまり、トークンの並びを変えても出力は同じ集合になります。系列中の位置情報を注入するために、位置エンコーディング(Positional Encoding)が必要です。
Transformerの原論文では正弦波ベースの位置エンコーディングが提案されています。
\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \tag{9}\]\[PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \tag{10}\]ここで \(pos\) は系列中の位置、\(i\) は次元のインデックスです。各次元で異なる周波数の正弦波を使うことで、位置ごとにユニークなパターンを生成します。
実装と可視化
def positional_encoding(max_len, d_model):
"""
正弦波位置エンコーディング(式9, 10)
Parameters:
max_len: 最大系列長
d_model: モデル次元
Returns:
PE: 位置エンコーディング行列 (max_len, d_model)
"""
PE = np.zeros((max_len, d_model))
position = np.arange(max_len)[:, np.newaxis]
div_term = 10000 ** (2 * np.arange(d_model // 2) / d_model)
PE[:, 0::2] = np.sin(position / div_term)
PE[:, 1::2] = np.cos(position / div_term)
return PE
# 位置エンコーディングの可視化
max_len = 50
d_model = 64
PE = positional_encoding(max_len, d_model)
fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(PE, cmap="RdBu", aspect="auto")
ax.set_xlabel("Dimension")
ax.set_ylabel("Position")
ax.set_title("Sinusoidal Positional Encoding")
plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
PyTorchのnn.MultiheadAttentionとの比較
PyTorchにはnn.MultiheadAttentionが実装されています。ここではスクラッチ実装と同一の重みを設定し、出力が一致することを確認します。
import torch
import torch.nn as nn
np.random.seed(42)
torch.manual_seed(42)
n_tokens = 4
d_model = 8
n_heads = 2
d_k = d_model // n_heads
# 入力データ
X_np = np.random.randn(n_tokens, d_model).astype(np.float32)
# --- スクラッチ実装 ---
mha_np = MultiHeadAttention(d_model, n_heads, seed=0)
out_np, _ = mha_np.forward(X_np)
# --- PyTorch実装 ---
mha_pt = nn.MultiheadAttention(d_model, n_heads, bias=False, batch_first=False)
# PyTorchのin_proj_weightにスクラッチ実装の重みをコピー
# PyTorchは[W_Q; W_K; W_V]を結合した形で保持 (3*d_model, d_model)
W_Q_cat = np.concatenate([mha_np.W_Q[i] for i in range(n_heads)], axis=1).T
W_K_cat = np.concatenate([mha_np.W_K[i] for i in range(n_heads)], axis=1).T
W_V_cat = np.concatenate([mha_np.W_V[i] for i in range(n_heads)], axis=1).T
in_proj_weight = np.concatenate([W_Q_cat, W_K_cat, W_V_cat], axis=0)
with torch.no_grad():
mha_pt.in_proj_weight.copy_(torch.from_numpy(in_proj_weight))
mha_pt.out_proj.weight.copy_(torch.from_numpy(mha_np.W_O.T))
# PyTorchはデフォルトで (seq_len, batch, d_model) の形式
X_pt = torch.from_numpy(X_np).unsqueeze(1) # (n_tokens, 1, d_model)
out_pt, _ = mha_pt(X_pt, X_pt, X_pt)
out_pt = out_pt.squeeze(1).detach().numpy()
# 比較
print("スクラッチ実装の出力:\n", np.round(out_np, 4))
print("PyTorchの出力:\n", np.round(out_pt, 4))
print("最大誤差:", np.max(np.abs(out_np - out_pt)))
同一の重みを設定することで、両者の出力が数値誤差の範囲で一致することを確認できます。
Attentionの直感的な理解
Self-Attentionの本質は、データに依存した重み付き平均です。
通常のフィルタ(移動平均やEMAなど)では、重みは入力に関係なく事前に決められています。一方、Self-Attentionでは重み(Attention重み行列 \(A\))が入力 \(X\) 自身から計算されます。
| 特性 | 固定重みフィルタ | Self-Attention |
|---|---|---|
| 重みの決定 | 事前に固定 | 入力から動的に計算 |
| 参照範囲 | 局所(窓幅に依存) | 系列全体(大域的) |
| 適応性 | なし | 入力ごとに重みが変化 |
| 計算量 | \(O(n)\) | \(O(n^2)\)(系列長の2乗) |
| 用途 | 信号平滑化、ノイズ除去 | 系列間の関係性モデリング |
この「入力依存の動的な重み付け」が、Attentionの強力さの源泉です。各トークンが文脈に応じて注目すべき相手を選択的に決定できるため、長距離依存性や複雑な構造を柔軟に捉えることができます。
まとめ
- Self-Attentionは入力系列の各要素間の関連度を動的に計算し、文脈に応じた表現を獲得する機構
- Scaled Dot-Product AttentionはQuery-Key間の内積を \(\sqrt{d_k}\) でスケーリングし、softmaxでAttention重みを求め、Valueの加重和を計算する
- Multi-Head Attentionは複数の部分空間で独立にAttentionを計算することで、多様なパターンを同時に捉える
- 位置エンコーディングは順序不変なSelf-Attentionに位置情報を注入する
- Attentionの本質はデータ依存の重み付き平均であり、固定重みフィルタとは根本的に異なる
関連記事
- 確率的勾配降下法からAdamまで - Transformerの学習に使われる最適化手法を解説しています。
- 指数移動平均(EMA)フィルタの周波数特性 - 固定重みフィルタとAttentionの動的重み付けの対比を理解するための参考記事です。
- アンサンブル学習の手法と比較 - 機械学習の別のアプローチであるアンサンブル手法を解説しています。
- SVM(サポートベクターマシン)とカーネル法 - カーネルトリックによる非線形変換とAttentionの類似性について理解を深められます。
- 移動平均フィルタの種類と比較 - Attentionを固定重みフィルタと対比するための参考記事です。
参考文献
- Vaswani, A., et al. (2017). “Attention Is All You Need.” NeurIPS 2017.
- PyTorch Documentation:
nn.MultiheadAttention. https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html - Alammar, J. (2018). “The Illustrated Transformer.” https://jalammar.github.io/illustrated-transformer/