PyTorch自動微分のリバースモード実装

基礎理論

演算子オーバーロード(OO)とリバースモード自動微分の概念を組み合わせた実装手法について解説する。Pythonを用いてPyTorchのコアメカニズムを再現する。

演算子オーバーロード手法

現代プログラミング言語の多態性を利用し、基本演算子の微分規則をカプセル化する。演算子の再定義により計算グラフを構築し、連鎖律で微分結果を合成する。

利点

  • 実装が簡潔で言語の基本機能のみを要求
  • ネイティブなコーディングスタイルを維持

制約

  • 制御フロー処理が困難
  • 高階微分への拡張性に課題

リバースモード自動微分

計算過程を特殊なデータ構造に記録し、出力から入力に向かって連鎖律を適用する。各演算ノードでの局所的な勾配計算を可能にする。

リバースモード計算フロー図

計算過程: $$\frac{\partial f}{\partial x}=\sum_{k=1}^{N} \frac{\partial f}{\partial v_{k}} \frac{\partial v_{k}}{\partial \boldsymbol{x}}$$

実装例

Tensor計算を追跡する変数クラス:

import numpy as np

class Tensor:
    _counter = 0
    
    def __init__(self, data, label=None):
        self.data = data
        self.label = label or f'tensor_{Tensor._counter}'
        Tensor._counter += 1
    
    def __repr__(self):
        return f"{self.label}: {self.data}"
    
    @staticmethod
    def create(data, label=None):
        return Tensor(data, label)

演算記録用テープ機構:

class OperationRecord:
    def __init__(self, inputs, outputs, backward_func):
        self.input_keys = inputs
        self.output_keys = outputs
        self.backward_fn = backward_func

operation_history = []

def clear_history():
    global operation_history
    operation_history = []
    Tensor._counter = 0

乗算演算子の実装例:

def multiply(a, b):
    result = Tensor(a.data * b.data)
    
    def backward(grad_output):
        grad_a = grad_output * b.data
        grad_b = grad_output * a.data
        return [grad_a, grad_b]
    
    op_record = OperationRecord(
        inputs=[a.label, b.label],
        outputs=[result.label],
        backward_func=backward
    )
    operation_history.append(op_record)
    return result

勾配計算関数:

def compute_gradients(target, sources):
    grad_map = {target.label: 1.0}
    
    for record in reversed(operation_history):
        output_grads = [grad_map.get(out) for out in record.output_keys]
        input_grads = record.backward_fn(*output_grads)
        
        for key, grad_val in zip(record.input_keys, input_grads):
            grad_map[key] = grad_val if key not in grad_map else grad_map[key] + grad_val
    
    return [grad_map[src.label] for src in sources]

検証例

計算グラフの構築:

clear_history()
x = Tensor.create(2.0, 'x')
y = Tensor.create(5.0, 'y')

z1 = np.log(x.data)
z2 = x.data * y.data
z3 = z1 + z2
z4 = np.sin(y.data)
result = Tensor(z3 - z4)

勾配計算実行:

dx, dy = compute_gradients(result, [x, y])
print(f"∂result/∂x = {dx}")
print(f"∂result/∂y = {dy}")

タグ: 自動微分 リバースモード 演算子オーバーロード PyTorch チェーンルール

5月16日 18:26 投稿