結腸がん(CRC)は世界で3番目に多い悪性腫瘍であり、内視鏡下ポリペクトミーはCRCの予防と死亡率の低下に不可欠な手段です。しかし、内視鏡検査環境は複雑であり、低コントラスト、不均一な照明、手術器具の干渉などが診断精度の制約要因となることがあります。最近、ハルビン工科大学をはじめとする研究機関のチームが、ウェーブレット変換と拡散モデルを融合させた**WDNet**という新しいディープラーニングネットワークを提案し、ポリープと手術器具の高精度かつリアルタイムなセグメンテーションを実現しました。本論文は医学AIのトップカンファレンスであるMICCAI 2025で発表され、ここではその原理と核心モジュールのコードを解説します。
1. 背景問題
背景:内視鏡画像セグメンテーションが直面する課題
ロボット支援手術において、ポリープと手術器具をリアルタイムかつ正確に特定することは極めて重要です。しかし、既存のセグメンテーションアルゴリズムは以下の課題に直面しがちです:
- 環境干渉の深刻化:煙、血液、反射、低コントラストにより微小な病変の捕捉が困難になります。
- 対象スケールの多様性:ポリープと器具のサイズや形状は多様であり、単一スケールの特徴抽出では局所的詳細とグローバルな文脈を両立させることが困難です。
- 多目的タスクの欠如:既存研究の多くは単一対象のセグメンテーションに重点を置いており、複雑な臨床手術シーンで病変と器具を同時に処理することは困難です。
2. 核心的な突破:WDNetの革新的アーキテクチャ
WDNetは多段階の特徴抽出と融合手法を採用しており、その核心アーキテクチャは3つの重要なモジュールから構成されています:
2.1 ウェーブレット変換モジュール
WDNetはまず離散ウェーブレット変換(DWT)を用いて画像を低周波数(LL)と高周波数(LH、HL、HH)サブバンドに分解します。その中で:
- 低周波数サブバンド:グローバルな解剖構造情報を捕捉する役割を担います。
- 高周波数サブバンド:画像のエッジ、テクスチャなどのミクロな詳細を正確に抽出します。
この周波数域分解手法は、ノイズの多い環境下でのモデルの耐干渉能力を効果的に向上させます。
この公式は本質的に階層的周波数域融合フレームワークであり、その各部分の具体的な意味は以下の通りです:
- F_fused:融合特徴マップで、このモジュールが最終的に出力する特徴マップです。異なる周波数における画像特徴を統合し、後続のセグメンテーションネットワークに豊富な入力を提供します。
- ∑l=1^L:多段階分解を表し、Lはウェーブレット分解の段数を表します。ネットワークは画像を一度見るだけでなく、多段階分解(レイヤーごとに視野を拡大または縮小するようなもの)を通じて、異なるスケールで情報を抽出します。
- 左側の部分は低周波数経路で、核心コンポーネントはG_ϕ^(l)(DWT_ϕ^(l)(I))です。これは画像Iを第l段のウェーブレット分解することを意味し、画像の低周波数成分(LLサブバンド)を抽出します。この部分はポリープと手術器具のグローバルな解剖構造(物体の輪郭と位置)を保持します。
- 右側の部分は高周波数経路で、核心コンポーネントは∑k∈Ω G_ψ^(l,k)(∂/∂k DWT_ψ^(l)(I))です。∂/∂kは高周波数成分を抽出することを意味し、Ωは異なる高周波数方向(水平、垂直、対角線など)を表します。画像の高周波数詳細(LH、HL、HHサブバンド)を捕捉し、この部分には結腸壁のテクスチャ、ポリープの微小なエッジ、手術器具の金属エッジが含まれます。
最後に抽出された「大まかな輪郭」(低周波)と「細かなエッジ」(高周波)を強制的に組み合わせることで、モデルは複雑な内視鏡画像を処理する際に、大きな方向性を見失うことなく、微小な詳細も逃しません。
2.2 拡散特徴精細化モジュール(DFR)
本研究で最も注目すべき革新的な部分です。このモジュールは確率微分方程式(SDE)に基づく反復最適化プロセスを導入し、内視鏡手術で一般的に見られる境界のぼやけや手術器具による鏡面反射などの課題を解決することを目的としています。このモジュールは確率微分方程式(SDE)の特徴進化プロセスをシミュレーションすることで、対象特徴の反復最適化を実現し、以下を含みます:
- 適応的ノイズゲート:4段階の反復拡散を通じて、対象境界の明確さを強化し、背景の干渉を抑制します。
- 精度向上:消融実験によると、このモジュールにより器具先端の境界ぼやけ度が12%有意に低下しました。
従来の静的特徴抽出とは異なり、特徴精細化を「時間」tに沿って進化する動的プロセスと見なし、反復的にノイズを除去し構造化情報を強化します。以下に少し複雑な公式が示されます:
dX_t = f_θ^(t)(X_(t-1))dt + σ_t Γ(X_(t-1))dW_t
深度学習ネットワークで実現するため、論文ではこれを離散反復形式に変換しています:
X^(t) = ReLU(BN(K^(t) * X^(t-1))) + σ_t Γ(X^(t-1)) ⊙ ε_t
この公式は2つの主要な部分で構成されます:
- 確定的ドリフト項:f_θ^(t)(X_(t-1))dt、つまり公式実現の前半部分ReLU(BN(K^(t) * X^(t-1)))です。これは特徴進化の予測可能な部分です。公式に記述されているf_θ^(t)は実際には標準的な畳み込み操作(ReLU活性化とBN正規化を含む)です。これは特徴がどの明確な方向に進化すべきか(例えば、初期特徴抽出から高レベルの意味的特徴へ)を決定します。これはモデルが既存の経験に基づいて、特徴が「本来」どうあるべきかと判断するものです。学習された3×3畳み込みカーネルK^(t)を用いて前のレイヤーの特徴を変換・集約し、核心的な意味的特徴を抽出します。これは共通性の学習に相当し、医学画像には安定した解剖構造(例:結腸内壁のマクロな形状)が存在します。この項は学習された畳み込みカーネルを通じて、特徴を「ポリープや器具のように見える」という大方向へ導き、セグメンテーションの正確性を保証します。
- 適応的拡散項:つまり公式の後半部分σ_t Γ(X^(t-1)) ⊙ ε_tです。これは特徴進化のランダム/不確定な部分です。この部分は不確定性の処理を行い、内視鏡画像にはノイズ、反射、境界ぼやけ(不確定性が非常に高い)が満ちています。ランダム項(dW_t)を加えることで、モデルはぼやけた領域で微細な「ランダム探索」を行い、局所的最適解から脱出し、真の境界を特定しやすくなります。その中で:
- dW_t(ウィナ過程):白ノイズやランダムな干渉と考えることができます。
- σ_t(ノイズ強度):どれだけのランダム性を加えるかを制御し、文中では固定値0.1に設定されています。
- Γ(X_(t-1))(適応的ゲート):この公式の「魂」であり、「相手によって対応を変える」を実現します。画像のすべての領域が同等のノイズ処理を必要とするわけではありません。シグモイド関数を用いて現在の特徴マップに基づきどこにランダムなノイズを加えるかを決定します。その巧妙さは手術器具の境界/エッジに対して非常に敏感である点にあります。画像内部の平坦な領域ではそれほど効果を発揮しませんが、ぼやけた境界では微小なランダムな変動(拡散)を導入します。
なぜこの複雑な公式を使うのでしょうか。従来のネットワークでは特徴抽出は機械的です。しかし、このSDE公式を通じて、ネットワークは以下の目的を実現できます:
- 反復最適化:複数回の反復(T=4)を通じて、特徴をランダムな変動の助けを借りて、徐々に最も正確なエッジを「探し当て」ます。
- 強靭性の向上:ランダム項の存在は特徴に「フィルター」を加えるようなもので、低コントラスト、反射干渉の複雑な環境でも依然として対象境界を明確に捕捉し、背景ノイズを抑制します。
この公式により、ネットワークは単に畳み込みを積み重ねるのではなく、水滴が拡散するように、確定的な指導の方向性の下で微小なランダムな調整を自動的に行い、最も正確な病変と器具の輪郭を描き出します。
反復ステップ数(T)に関して、論文の研究によると4ステップ(T=4)の拡散反復が正確性と計算効率の最適なバランス点であることが示されています。2ステップと比較して、4ステップはIoUを大幅に向上させますが、6ステップに増やすと精度はわずかに向上しますが、推論速度(FPS)は約30%大幅に低下します。Tステップの反復後、最終特徴X^(T)は全平均プーリング(GAP)を通じて出力チャネルにマッピングされ、最終的な精細化された表現が生成されます。
DFRモジュールは複雑な内視鏡シーンで顕著な優位性を示しています:定性テストにおいて、このモジュールは手術器具先端の境界ぼやけ度を12%低下させ、血液、遮蔽物、または不均一な照明による誤報を効果的に処理しました。DFRを通常の畳み込み層に置き換えると、EPIDデータセットでのIoUが6.71%低下し、Kvasirデータセットでは8.4%低下しました。これは、ノイズ処理と精細なセグメンテーションにおけるこの拡散メカニズムの必要性を十分に証明しています。
区別すべき問題:論文の拡散(Diffusion)メカニズムは、私たちがよく知る生成型拡散モデル(Stable DiffusionやDDPMなど)とは、設計の初期目的、数学的実装、実行効率において顕著に異なります。生成型拡散モデルは「無から有を生む」(純粋なノイズから画像を生成)もので、通常「前向きノイズ追加」と「後向きノイズ除去」の2つの完全な段階を含み、多くの場合ノイズ残差を予測する必要があります。
一方、WDNetの拡散モジュールは「精緻な彫刻」(既存の特徴を詳細に修正)であり、確率微分方程式(SDE)に基づく特徴進化に基づいています。これは新しいコンテンツを生成するのではなく、特徴正則化の手段として機能し、境界の明確さを強化し、背景干渉を抑制することを目的としてセグメンテーション精度を向上させます。
2.3 双流階層的特徴融合
内視鏡画像では、対象物体のサイズは非常に大きく異なります。例えば、微小な早期ポリープと画面の半分を占める手術器具などです。このような多スケールの対象を同時に処理するため、著者は式5で記述される双流構造を設計しました:
X̃^(l) = F_θ1^(l)(X^(l-1))
X^(l) = Φ(F_θ2^(l)(X̃^(l)))
その中で:
- F_θ1(局所パターンの集約):空間制限付き畳み込みを利用して局所的なテクスチャと微小な特徴を捕捉します。
- F_θ2(意味的一貫性の強化):非線形多様体学習を通じてグローバルな意味の一貫性を強化し、大型物体(例:器具の柄)のセグメンテーションが断片化されないようにします。
- Φ(·)(遷移射影):これは適応的選択メカニズムです。特徴レイヤー変換段階にある場合は、1×1畳み込みを用いて射影マッピング(P(·))を行い、そうでなければ元の情報を保持する恒等写像(I(·))を使用します。
モデルは非線形変換を連鎖させることで、局所パターンの集約と高レベルの意味融合を組み合わせ、異なるサイズの対象(微小なポリープと大型器具など)に対する強固なセグメンテーション能力を確保します。この設計により、ネットワークは「局所的なエッジを見る」と「グローバルな形状を把握する」の間でバランスを達成しています。
2.4 高品質データセットの貢献:EPID
モデル性能を検証するため、研究チームはEPID(EndoPolyp-Instrument Dataset)という包括的なデータセットを提供しました。このデータセットは100症例から10,046フレームの画像を含み、ポリープと手術器具に対する精細な双ターゲット注釈が行われており、業界の空白を埋めるものです。
3. コード解説
プロジェクトのソースコードとデータセットはGitHubでオープンソース化されています:https://github.com/hedongdong6060/WDNet。モデルの核心ファイルは`Wavelet.py`と`build.py`の2つです。
3.1 双流ネットワークアーキテクチャの具体的な実装(`WaveletNet`)
`Wavelet.py`では、`WaveletNet`クラスが論文で述べられている双流特徴抽出を完全に実装しています:
- 入力処理:`forward`関数は2つの入力`input1`と`input2`を受け取り、それぞれウェーブレット変換後の低周波数成分と高周波数成分に対応します。
- ブランチ設計:コードでは`branch1`(低周波を処理)と`branch2`(高周波を処理)が定義されています。これら2つのブランチは、`BasicBlock`と`DoubleBasicBlock`を使用して階層化された特徴抽出を行うという点で高度に対称的な構造をしています。
- 特徴融合:`forward`関数の末尾では、コードは`torch.cat`を用いて2つのブランチを異なるスケール(c1からc5)で特徴マップを連結(Concatenation)し、これは論文の多段階特徴融合の考え方に対応します。
class WaveletNet(nn.Module):
def __init__(self, in_channels, num_classes, num_diffusion_steps=4):
super(WaveletNet, self).__init__()
# 5つの異なるレイヤーのチャネル数を設定:64, 128, 256, 512, 1024
l1c, l2c, l3c, l4c, l5c = 64, 128, 256, 512, 1024
# ブランチ1(Branch1):低周波サブバンド特徴を処理する
# 5つのレイヤーのDoubleBasicBlockの組み合わせを含む
self.b1_1_1 = nn.Sequential(conv3x3(in_channels, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c))
self.b1_1_2_down = down_conv(l1c, l2c) # ダウンサンプリング
# ...(b1ブランチの他の同様な定義を省略、すべてDoubleBasicBlock構造に従う)
# ブランチ2(Branch2):高周波サブバンド特徴を処理する
# ブランチ1と高度に対称的な構造で、特徴次元が一致し後続の融合を容易にする
self.b2_1_1 = nn.Sequential(conv3x3(1, l1c), conv3x3(l1c, l1c), BasicBlock(l1c, l1c))
self.b2_1_2_down = down_conv(l1c, l2c) # ダウンサンプリング
# ...(b2ブランチの他の同様な定義を省略)
# デコーダーモジュールを初期化
self.decoder = WaveletDecoder(num_classes, num_diffusion_steps)
def forward(self, input1, input2):
# フロー1:低周波ブランチの前方伝播、段階的にダウンサンプリングし多スケール特徴を抽出
x1_1 = self.b1_1_1(input1)
x1_2 = self.b1_2_1(self.b1_1_2_down(x1_1))
x1_3 = self.b1_3_1(self.b1_2_2_down(x1_2))
x1_4_1 = self.b1_4_1(self.b1_3_2_down(x1_3))
x1_4_2 = self.b1_5_1(self.b1_4_2_down(x1_4_1))
# フロー2:高周波ブランチの前方伝播、低周波ブランチと同期的に実行
x2_1 = self.b2_1_1(input2)
x2_2 = self.b2_2_1(self.b2_1_2_down(x2_1))
x2_3 = self.b2_3_1(self.b2_2_2_down(x2_2))
x2_4_1 = self.b2_4_1(self.b2_3_2_down(x2_3))
x2_4_2 = self.b2_5_1(self.b2_4_2_down(x2_4_1))
# 重要なステップ:特徴融合。catを用いて2つの流れを5つの解像度レベルでそれぞれ結合
c5 = torch.cat([x1_4_2, x2_4_2], dim=1) # 最深層の特徴融合
c4 = torch.cat([x1_4_1, x2_4_1], dim=1)
c3 = torch.cat([x1_3, x2_3], dim=1)
c2 = torch.cat([x1_2, x2_2], dim=1)
c1 = torch.cat([x1_1, x2_1], dim=1) # 最浅層の特徴融合
# 融合された特徴を拡散デコーダーに送り最終的なセグメンテーション画像を生成
out = self.decoder(c5, c4, c3, c2, c1)
return out
3.2 拡散特徴精細化モジュール(`DiffusionBlock`)
これは論文で最も核心的な革新的な点であり、コードでは以下のように実装されています:
- 反復プロセス:`DiffusionBlock`クラスは`nn.ModuleList`を用いて`num_diffusion_steps`(デフォルトは4ステップ)個の処理層を保持します。
- ノイズ駆動:各反復ステップで、コードは`noise = torch.randn_like(x) * 0.1`でガウスノイズを生成し、それを特徴マップに加えます(`x = x + noise`)。その後、`nn.Sequential(Conv2d, BatchNorm2d, ReLU)`を通じて処理します。
- 論理的一貫性:これは論文の式(3)の離散反復形式を完全に再現したものです:まず特徴にノイズを加え、次に畳み込みと非線形変換を通じて精細化を行います。
class DiffusionBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_diffusion_steps=4):
super(DiffusionBlock, self).__init__()
self.in_channels = in_channels # 入力チャネル数
self.out_channels = out_channels # 出力チャネル数
self.num_diffusion_steps = num_diffusion_steps # 拡散反復ステップ数
# 拡散層リスト、各層は3x3畳み込み、正規化、ReLU活性化を含む
self.noise_layers = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True),
)
for _ in range(num_diffusion_steps)
]
)
# 1x1畳み込み、拡散終了後にチャネル数を目標値に調整
self.channel_project = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
# 設定されたステップ数に従って反復
for layer in self.noise_layers:
# 強度0.1のランダムなガウスノイズを導入
noise = torch.randn_like(x) * 0.1
x = x + noise # ノイズを特徴マップに重ねる
x = layer(x) # 畳み込み層を通じてノイズ除去と特徴精細化を行う
x = self.channel_project(x) # 最終的に出力次元に射影
return x
3.3 デコーダーとアップサンプリング(`WaveletDecoder`)
`WaveletDecoder`クラスは抽出・結合された多スケール特徴をセグメンテーション画像に復元する役割を担います:
- 段階的デコード:`up_conv`(転置畳み込み`nn.ConvTranspose2d`)を使用してアップサンプリングし、各レイヤーの特徴を精細化し境界強化するために`DiffusionBlock`を組み合わせます。
- スキップ接続:各レベルのアップサンプリング後、エンコーダ段階で結合された対応する特徴(`c4`、`c3`など)と再度結合し、低レベルの詳細情報が失われないようにします。
class WaveletDecoder(nn.Module):
def __init__(self, num_classes, num_diffusion_steps=4):
super(WaveletDecoder, self).__init__()
# 5つのレベルの拡散ブロック、異なる空間解像度の特徴を処理
# 入力チャネル数はレイヤーの深さの増加に伴って倍増
self.diff5 = DiffusionBlock(in_channels=2048, out_channels=1024, num_diffusion_steps=num_diffusion_steps)
self.up4 = up_conv(1024, 512) # 2倍アップサンプリング
self.diff4 = DiffusionBlock(in_channels=1536, out_channels=512, num_diffusion_steps=num_diffusion_steps)
self.up3 = up_conv(512, 256) # 2倍アップサンプリング
self.diff3 = DiffusionBlock(in_channels=768, out_channels=256, num_diffusion_steps=num_diffusion_steps)
self.up2 = up_conv(256, 128) # 2倍アップサンプリング
self.diff2 = DiffusionBlock(in_channels=384, out_channels=128, num_diffusion_steps=num_diffusion_steps)
self.up1 = up_conv(128, 64) # 2倍アップサンプリング
self.diff1 = DiffusionBlock(in_channels=192, out_channels=64, num_diffusion_steps=num_diffusion_steps)
# 最終畳み込み層、64チャネルの特徴をクラス数(多目的セグメンテーション)にマッピング
self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
def forward(self, c5, c4, c3, c2, c1):
# 下から上へ段階的に特徴精細化、アップサンプリング、結合融合を行う
d5 = self.diff5(c5)
d4_up = self.up4(d5)
d4_in = torch.cat([d4_up, c4], dim=1) # 前レベルの特徴とエンコーダのスキップ接続特徴を融合
d4 = self.diff4(d4_in)
d3_up = self.up3(d4)
d3_in = torch.cat([d3_up, c3], dim=1)
d3 = self.diff3(d3_in)
d2_up = self.up2(d3)
d2_in = torch.cat([d2_up, c2], dim=1)
d2 = self.diff2(d2_in)
d1_up = self.up1(d2)
d1_in = torch.cat([d1_up, c1], dim=1)
d1 = self.diff1(d1_in)
out = self.out_conv(d1) # 出力マスクを生成
return out
3.4 ウェーブレット分解
`dataloader/custom_transforms.py`では、主に`WaveletTransform`クラスを通じて実現され、`pywt`(PyWavelets)ライブラリを利用して画像の周波数域分解を行います。
class WaveletTransform(object):
"""
2次元離散ウェーブレット変換(DWT)の前処理クラス
"""
def __init__(self, wavelet="db2", level=1):
self.wavelet = wavelet # ウェーブレット基底の種類を設定、デフォルトはdb2
self.level = level # 分解レベルを設定
def __call__(self, sample):
img = sample["image"] # 入力の原始画像を取得
mask = sample["label"] # 対応する注釈マスクを取得
# 入力画像が3チャネル(BGR)の場合、まずグレースケールに変換
if len(img.shape) == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# pywtライブラリを用いて単段2次元離散ウェーブレット変換を実行
# LL: 低周波数近似成分(Low-frequency approximation)
# LH: 水平方向高周波数成分(Horizontal detail)
# HL: 垂直方向高周波数成分(Vertical detail)
# HH: 対角線方向高周波数成分(Diagonal detail)
LL, (LH, HL, HH) = pywt.dwt2(img, self.wavelet)
# 低周波数成分をMin-Max正規化
LL = (LL - LL.min()) / (LL.max() - LL.min())
# 3方向の高周波数成分をそれぞれ正規化
LH = (LH - LH.min()) / (LH.max() - LH.min())
HL = (HL - HL.min()) / (HL.max() - HL.min())
HH = (HH - HH.min()) / (HH.max() - HH.min())
# 核心的なステップ:水平、垂直、対角線の3方向の高周波数成分を加算融合
# これは論文で述べられている高周波サブバンドを統合してエッジ詳細を捕捉する操作に対応
merge1 = HH + HL + LH
# 融合された高周波特徴マップを再度正規化
merge1 = (merge1 - merge1.min()) / (merge1.max() - merge1.min())
# 分解後の低周波と高周波特徴を返却、後続の双流ネットワーク(Branch1 & Branch2)の入力として使用
return {"low_freq": LL, "high_freq": merge1, "label": mask}