P5298 [PKUWC2018] Minimax 解説:セグメント木マージによる木DP最適化

問題の分析

この問題は、セグメント木のマージ操作を用いて木構造上の動的計画法を最適化する手法が鍵となります。 まず、値の範囲が最大で10^9まで及ぶため、離散化(座標圧縮)が必要です。各値の出現確率を管理する必要があるため、基本的な木DPを考えます。

動的計画法の設計

dp[v][j] を頂点vにおいて、j番目に小さい値が出現する確率と定義します。遷移は以下の3パターンに分類されます。

1. 葉ノードの場合

そのノードの値の出現確率を1とします。

2. 子ノードが1つの場合

子ノードの確率をそのまま継承します。

3. 子ノードが2つの場合

左右の子ノードl, rについて、確率p_vを用いて以下のように遷移します:

dp[v][j] = dp[l][j] × [p_v × Σ(k=1 to j-1) dp[r][k] + (1-p_v) × Σ(k=j+1 to m) dp[r][k]] + dp[r][j] × [p_v × Σ(k=1 to j-1) dp[l][k] + (1-p_v) × Σ(k=j+1 to m) dp[l][k]]

この遷移式は、値域上の累積和(前缀和・後缀和)を必要とすることがわかります。これをセグメント木で効率的に管理できます。

セグメント木マージの実装

各ノードで値域全体の情報を管理する必要があるため、セグメント木のマージを活用します。マージの過程で、確率の累積和を計算しつつ、各値に係数を掛ける処理を行います。

int mergeNodes(int nodeA, int nodeB, int left, int right, 
               int coefA, int coefB, int prob) {
    if (nodeA == 0 && nodeB == 0) return 0;
    if (nodeA == 0) {
        applyMultiply(nodeB, coefB);
        return nodeB;
    }
    if (nodeB == 0) {
        applyMultiply(nodeA, coefA);
        return nodeA;
    }
    propagate(nodeA);
    propagate(nodeB);
    
    int mid = (left + right) >> 1;
    int leftSumA = treeSum[leftChild[nodeA]];
    int rightSumA = treeSum[rightChild[nodeA]];
    int leftSumB = treeSum[leftChild[nodeB]];
    int rightSumB = treeSum[rightChild[nodeB]];
    
    int invProb = (1 - prob + MOD) % MOD;
    
    leftChild[nodeA] = mergeNodes(
        leftChild[nodeA], leftChild[nodeB], left, mid,
        (coefA + rightSumB * invProb % MOD) % MOD,
        (coefB + rightSumA * invProb % MOD) % MOD, prob
    );
    
    rightChild[nodeA] = mergeNodes(
        rightChild[nodeA], rightChild[nodeB], mid + 1, right,
        (coefA + leftSumB * prob % MOD) % MOD,
        (coefB + leftSumA * prob % MOD) % MOD, prob
    );
    
    updateSum(nodeA);
    return nodeA;
}

完全な実装例

#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long ll;

const int MAXN = 300005;
const int NODES = MAXN * 40;
const ll MOD = 998244353;

int nodeCount;
int valueCount, sortedVals[MAXN];
int root[MAXN], resultProb[MAXN];
int parent[MAXN], children[MAXN][2], childCnt[MAXN], probVal[MAXN];
int leftChild[NODES], rightChild[NODES], treeSum[NODES], lazyMul[NODES];

ll readInt() {
    ll res = 0;
    char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = getchar();
    }
    return res;
}

ll modPow(ll base, ll exp) {
    ll result = 1;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

void updateNode(int idx) {
    treeSum[idx] = (treeSum[leftChild[idx]] + treeSum[rightChild[idx]]) % MOD;
}

void applyMultiply(int idx, ll mult) {
    if (idx == 0) return;
    treeSum[idx] = treeSum[idx] * mult % MOD;
    lazyMul[idx] = lazyMul[idx] * mult % MOD;
}

void propagate(int idx) {
    if (lazyMul[idx] == 1) return;
    if (leftChild[idx]) applyMultiply(leftChild[idx], lazyMul[idx]);
    if (rightChild[idx]) applyMultiply(rightChild[idx], lazyMul[idx]);
    lazyMul[idx] = 1;
}

int createNode() {
    int idx = ++nodeCount;
    leftChild[idx] = rightChild[idx] = treeSum[idx] = 0;
    lazyMul[idx] = 1;
    return idx;
}

void insertValue(int &idx, int l, int r, int pos, int val) {
    if (idx == 0) idx = createNode();
    if (l == r) {
        treeSum[idx] = val;
        return;
    }
    propagate(idx);
    int mid = (l + r) >> 1;
    if (pos <= mid) 
        insertValue(leftChild[idx], l, mid, pos, val);
    else 
        insertValue(rightChild[idx], mid + 1, r, pos, val);
    updateNode(idx);
}

int mergeNodes(int nodeA, int nodeB, int l, int r, 
               ll coefA, ll coefB, ll p) {
    if (nodeA == 0 && nodeB == 0) return 0;
    if (nodeA == 0) {
        applyMultiply(nodeB, coefB);
        return nodeB;
    }
    if (nodeB == 0) {
        applyMultiply(nodeA, coefA);
        return nodeA;
    }
    propagate(nodeA);
    propagate(nodeB);
    
    int mid = (l + r) >> 1;
    ll lSumA = treeSum[leftChild[nodeA]];
    ll rSumA = treeSum[rightChild[nodeA]];
    ll lSumB = treeSum[leftChild[nodeB]];
    ll rSumB = treeSum[rightChild[nodeB]];
    ll invP = (1 - p + MOD) % MOD;
    
    leftChild[nodeA] = mergeNodes(
        leftChild[nodeA], leftChild[nodeB], l, mid,
        (coefA + rSumB * invP % MOD) % MOD,
        (coefB + rSumA * invP % MOD) % MOD, p
    );
    
    rightChild[nodeA] = mergeNodes(
        rightChild[nodeA], rightChild[nodeB], mid + 1, r,
        (coefA + lSumB * p % MOD) % MOD,
        (coefB + lSumA * p % MOD) % MOD, p
    );
    
    updateNode(nodeA);
    return nodeA;
}

void buildTree(int v) {
    if (childCnt[v] == 0) {
        insertValue(root[v], 1, valueCount, probVal[v], 1);
    } else if (childCnt[v] == 1) {
        buildTree(children[v][0]);
        root[v] = root[children[v][0]];
    } else {
        buildTree(children[v][0]);
        buildTree(children[v][1]);
        root[v] = mergeNodes(
            root[children[v][0]], root[children[v][1]], 
            1, valueCount, 0, 0, probVal[v]
        );
    }
}

void extractResult(int idx, int l, int r) {
    if (idx == 0) return;
    if (l == r) {
        resultProb[l] = treeSum[idx];
        return;
    }
    propagate(idx);
    int mid = (l + r) >> 1;
    extractResult(leftChild[idx], l, mid);
    extractResult(rightChild[idx], mid + 1, r);
}

int main() {
    int n = readInt();
    for (int i = 1; i <= n; i++) {
        parent[i] = readInt();
        if (parent[i]) {
            children[parent[i]][childCnt[parent[i]]++] = i;
        }
    }
    
    for (int i = 1; i <= n; i++) {
        probVal[i] = readInt();
        if (childCnt[i] == 0) {
            sortedVals[++valueCount] = probVal[i];
        } else {
            probVal[i] = probVal[i] * modPow(10000, MOD - 2) % MOD;
        }
    }
    
    sort(sortedVals + 1, sortedVals + valueCount + 1);
    for (int i = 1; i <= n; i++) {
        if (childCnt[i] == 0) {
            probVal[i] = lower_bound(sortedVals + 1, sortedVals + valueCount + 1, probVal[i]) - sortedVals;
        }
    }
    
    buildTree(1);
    extractResult(root[1], 1, valueCount);
    
    ll answer = 0;
    for (int i = 1; i <= valueCount; i++) {
        answer = (answer + (ll)i * sortedVals[i] % MOD * resultProb[i] % MOD * resultProb[i] % MOD) % MOD;
    }
    
    printf("%lld\n", answer);
    return 0;
}

計算量解析

セグメント木のマージは、全体としてO(n log n)の時間計算量で動作します。各ノードのマージ操作は、2つのセグメント木の対応するノードを統合する形で行われ、全ノードを通しての合計計算量はノード総数に比例するためです。

空間計算量はO(n log n)となり、動的にノードを確保することで効率的にメモリを使用します。

タグ: セグメント木 木DP 確率 データ構造 競技プログラミング

6月28日 02:04 投稿