基礎理論
演算子オーバーロード(OO)とリバースモード自動微分の概念を組み合わせた実装手法について解説する。Pythonを用いてPyTorchのコアメカニズムを再現する。
演算子オーバーロード手法
現代プログラミング言語の多態性を利用し、基本演算子の微分規則をカプセル化する。演算子の再定義により計算グラフを構築し、連鎖律で微分結果を合成する。
利点
- 実装が簡潔で言語の基本機能のみを要求
- ネイティブなコーディングスタイルを維持
制約
- 制御フロー処理が困難
- 高階微分への拡張性に課題
リバースモード自動微分
計算過程を特殊なデータ構造に記録し、出力から入力に向かって連鎖律を適用する。各演算ノードでの局所的な勾配計算を可能にする。
計算過程: $$\frac{\partial f}{\partial x}=\sum_{k=1}^{N} \frac{\partial f}{\partial v_{k}} \frac{\partial v_{k}}{\partial \boldsymbol{x}}$$
実装例
Tensor計算を追跡する変数クラス:
import numpy as np
class Tensor:
_counter = 0
def __init__(self, data, label=None):
self.data = data
self.label = label or f'tensor_{Tensor._counter}'
Tensor._counter += 1
def __repr__(self):
return f"{self.label}: {self.data}"
@staticmethod
def create(data, label=None):
return Tensor(data, label)
演算記録用テープ機構:
class OperationRecord:
def __init__(self, inputs, outputs, backward_func):
self.input_keys = inputs
self.output_keys = outputs
self.backward_fn = backward_func
operation_history = []
def clear_history():
global operation_history
operation_history = []
Tensor._counter = 0
乗算演算子の実装例:
def multiply(a, b):
result = Tensor(a.data * b.data)
def backward(grad_output):
grad_a = grad_output * b.data
grad_b = grad_output * a.data
return [grad_a, grad_b]
op_record = OperationRecord(
inputs=[a.label, b.label],
outputs=[result.label],
backward_func=backward
)
operation_history.append(op_record)
return result
勾配計算関数:
def compute_gradients(target, sources):
grad_map = {target.label: 1.0}
for record in reversed(operation_history):
output_grads = [grad_map.get(out) for out in record.output_keys]
input_grads = record.backward_fn(*output_grads)
for key, grad_val in zip(record.input_keys, input_grads):
grad_map[key] = grad_val if key not in grad_map else grad_map[key] + grad_val
return [grad_map[src.label] for src in sources]
検証例
計算グラフの構築:
clear_history()
x = Tensor.create(2.0, 'x')
y = Tensor.create(5.0, 'y')
z1 = np.log(x.data)
z2 = x.data * y.data
z3 = z1 + z2
z4 = np.sin(y.data)
result = Tensor(z3 - z4)
勾配計算実行:
dx, dy = compute_gradients(result, [x, y])
print(f"∂result/∂x = {dx}")
print(f"∂result/∂y = {dy}")