RNNの進化:LSTMとGRUを理解する

1. 基本的なRNN (再帰的ニューラルネットワーク)

1.1 RNNの基本的なアイデア

例えば、動物を推測するゲームを考えると、新しい手がかり(入力)を得たときに、前の記憶(隠れ状態)を考慮して現在の推測(出力)を更新します。これがRNNの本質です。

1.2 RNNの欠点

RNNは長距離依存性を持つシーケンスを処理する際に、勾配消失/爆発問題に直面します。

1.3 PyTorchでのRNNモデルの実装


import torch
import torch.nn as nn

# 1. RNNモデルの作成
rnn = nn.RNN(input_size=5, hidden_size=6, num_layers=1)

# 2. 入力データの準備
input_data = torch.randn(20, 3, 5)

# 3. 隠れ状態の初期化
hidden_state = torch.randn(1, 3, 6)

# 4. モデルの実行
output, hidden_state = rnn(input_data, hidden_state)

print(f"出力形状: {output.shape}") # torch.Size([20, 3, 6])
print(f"最終的な隠れ状態: {hidden_state.shape}") # torch.Size([1, 3, 6])
  

2. LSTM (長短期記憶ネットワーク)

2.1 LSTMの主要な設計

LSTMは「セル状態」を導入し、忘れゲート、入力ゲート、出力ゲートを通じて情報を制御します。

2.2 PyTorchでのLSTMの実装


import torch
import torch.nn as nn

# 1. LSTMモデルの作成
lstm = nn.LSTM(input_size=5, hidden_size=6, num_layers=1, bidirectional=False)

# 2. 入力データの準備
input_data = torch.randn(4, 3, 5)

# 3. 隠れ状態とセル状態の初期化
hidden_state = torch.randn(1, 3, 6)
cell_state = torch.randn(1, 3, 6)

# 4. モデルの実行
output, (hidden_state, cell_state) = lstm(input_data, (hidden_state, cell_state))

print(f"LSTM出力形状: {output.shape}") # torch.Size([4, 3, 6])
print(f"隠れ状態形状: {hidden_state.shape}") # torch.Size([1, 3, 6])
print(f"セル状態形状: {cell_state.shape}") # torch.Size([1, 3, 6])
  

3. GRU (ゲート付き再帰ユニット)

3.1 GRUの主要なアイデア

GRUはLSTMの簡略化版で、更新ゲートとリセットゲートを使用して計算量を削減します。

3.2 PyTorchでのGRUの実装


import torch
import torch.nn as nn

# 1. GRUモデルの作成
gru = nn.GRU(input_size=5, hidden_size=6, num_layers=1)

# 2. 入力データの準備
input_data = torch.randn(2, 3, 5)

# 3. 隠れ状態の初期化
hidden_state = torch.randn(1, 3, 6)

# 4. モデルの実行
output, hidden_state = gru(input_data, hidden_state)

print(f"GRU出力形状: {output.shape}") # torch.Size([2, 3, 6])
print(f"最終的な隠れ状態: {hidden_state.shape}") # torch.Size([1, 3, 6])
  

4. RNN、LSTM、GRUの比較

特性 伝統的なRNN LSTM GRU
**主要な構造** シンプルな再帰行列 忘れゲート + 入力ゲート + 出力ゲート + セル状態 更新ゲート + リセットゲート (状態の統合)
**長距離依存性の処理** 悪い (勾配消失しやすい) 良い (長距離依存性を解決) 中程度 (LSTMに近い)
**計算複雑度** 低 (高速) 高 (遅く、パラメータ数が多い) 中 (LSTMより20-30%速い)
**メモリ使用量** 少ない 多い (セル状態を保存する必要がある) 中程度
**コードの実装難易度** 簡単 複雑 (二つの状態を管理する必要がある) 簡単 (一つの状態を管理する)
**2026年の推奨用途** 短いテキスト、教育用 重要なタスク、小さなデータセット、高精度が必要 大量のデータ、リアルタイム予測、モバイル端末

タグ: RNN LSTM GRU PyTorch 時間系列予測

6月15日 18:06 投稿