深層学習モデルのパラメータ数は年々増加の一途をたどっており、高性能化の一方で計算リソースやストレージの消費量が課題となっています。特にモバイル機器やエッジデバイスへの展開において、モデルの軽量化は不可欠な技術です。本稿では、モデル圧縮の主要な手法である「低ランク近似」「枝刈り(プルーニング)」「知識蒸留」の原理と実装例について解説します。
モデル圧縮の基本概念
- モデル圧縮: 精度を維持しつつ、モデルのパラメータ数やサイズを削減する技術。
- 推論高速化: 計算量を減らし、レイテンシを短縮すること。
- 精度: データセットに対する予測の正確さ。
- モデルサイズ: モデルの保存に必要なメモリ容量。
主要な圧縮アルゴリズムの原理
1. 低ランク近似
重み行列を特異値分解(SVD)などにより分解し、低ランクの行列積で近似することでパラメータ数を削減します。元の行列 W を U, S, V に分解し、重要度の低い特異値をカットすることで次元を圧縮します。
数式: W ≈ U' * S' * V'^T
2. 枝刈り(プルーニング)
モデル内の重要度の低いニューロンや接続(重み)を削除する手法です。削除後のモデルはスパースな構造となり、専用のライブラリやハードウェアによる高速化が期待できます。
3. 知識蒸留
大規模な「教師モデル」の出力確率分布を、小規模な「生徒モデル」に学習させる手法です。正解ラベルだけでなく、教師モデルの持つ「暗黙知」を引き継ぐことで、小型モデルでも高い精度を実現します。
実装例と解説
低ランク近似による線形層の実装
PyTorchを用いて、重み行列を低ランク近似する層を定義します。ここではランク r を指定し、2つの小さな行列の積で元の重みを近似します。
import torch
import torch.nn as nn
class LowRankLinear(nn.Module):
def __init__(self, input_dim, output_dim, rank):
super().__init__()
# 元の重み行列 (input_dim x output_dim) を2つの小さい行列に分解
# U: (input_dim x rank), V: (rank x output_dim)
self.matrix_u = nn.Parameter(torch.randn(input_dim, rank))
self.matrix_v = nn.Parameter(torch.randn(rank, output_dim))
# 残差項やバイアスが必要な場合はここに追加可能
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, x):
# 通常の線形変換 W*x を (U * V) * x で近似
approx_weight = self.matrix_u @ self.matrix_v
return x @ approx_weight + self.bias
この実装では、パラメータ数を input_dim * output_dim から (input_dim + output_dim) * rank に削減でき、rank が小さいほど圧縮率が高まります。
枝刈りの実装
PyTorchの torch.nn.utils.prune モジュールを使用して、重みのいくつかをランダムに削除(マスク処理)します。
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
def apply_pruning(model, amount=0.3):
"""モデル内の線形層に対してランダム枝刈りを適用する"""
for module in model.modules():
if isinstance(module, nn.Linear):
# 指定した割合の重みをランダムに選んで削除(マスクを適用)
prune.random_unstructured(module, name='weight', amount=amount)
# 枝刈りを永続化する場合(マスクを重みに統合)
prune.remove(module, 'weight')
return model
# 使用例
net = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
pruned_net = apply_pruning(net)
この処理により、重みテンソルの一部がゼロになります。実際の推論高速化には、スパース行列最適化対応のランタイムが必要です。
知識蒸留の実装
教師モデルの出力をソフトターゲットとして生徒モデルを学習させます。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.fc = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.fc(x)
# 教師モデル(大規模、または事前学習済み)と生徒モデル(軽量)
teacher = SimpleNet(784, 10)
student = SimpleNet(784, 10)
# 教師モデルの重みは固定
for param in teacher.parameters():
param.requires_grad = False
optimizer = torch.optim.Adam(student.parameters())
def distillation_loss(student_output, teacher_output, labels, temperature=5.0, alpha=0.5):
# ソフトターゲット損失 (教師と生徒の確率分布の近さ)
soft_loss = nn.KLDivLoss(reduction='batchmean')(
F.log_softmax(student_output / temperature, dim=1),
F.softmax(teacher_output / temperature, dim=1)
)
# ハードターゲット損失 (正解ラベルとの誤差)
hard_loss = nn.CrossEntropyLoss()(student_output, labels)
return alpha * soft_loss + (1.0 - alpha) * hard_loss
# 学習ループ(擬似コード)
# for data, labels in dataloader:
# t_out = teacher(data)
# s_out = student(data)
# loss = distillation_loss(s_out, t_out, labels)
# loss.backward()
# optimizer.step()
温度パラメータ temperature を調整することで、確率分布を滑らかにし、生徒モデルが教師モデルのクラス間の関係性を学習しやすくします。
応用分野とツール
- エッジコンピューティング: スマートフォンやIoTデバイスなど、リソース制約のある環境でのAI実行に必須。
- クラウドサーバー: スループットの向上とコスト削減。
- ツール: TensorFlow Model Optimization Toolkit, PyTorch Pruning, NVIDIA TensorRT (量子化やレイヤーフュージョン含む)。