CodeForces 258E - Little Elephant and Tree 木の部分木操作における複数解法

概要

本問題は複数の解法が存在する良問である。木の2頂点に対して各操作で部分木を合併するクエリが与えられる。各頂点について、自身を含む集合に一切に寄与していない(挿入回数が0の)頂点数を求める問題に帰着する。

問題構造の解析

各ノードに対して自分を祖先または子孫に含む操作回数を記録する。与えられた各クエリ \\( (a_i, b_i) \\) について、\\( subtree(a_i) \cup subtree(b_i) \\) の全ノードのカウントを+1する。 DFS順序を考えると、部分木 \\( x \\) は区間 \\( [dfn[x], mxdfn[x]] \\) で表現される。因此,两つの部分木の和集合は、場合によって1つまたは2つの区間になる。

解法1:平方分割

各ブロック内にバケットと全体オフセット量を保持する数据结构を構成する。ブロックサイズを \\( \sqrt{n} \\) とすると、時間・空間計算量は \\( O(n\sqrt{n}) \\) となる。 #include <bits/stdc++.h> using namespace std; const int MAXN = 100000; const int BLOCK_SIZE = 320; int nodeCount, queryCount; vector<int> adj[MAXN + 1]; int entryTime[MAXN + 1], exitTime[MAXN + 1], nodeAtPos[MAXN + 1], timer; void dfs(int v = 1, int parent = 0) { entryTime[v] = exitTime[v] = ++timer; nodeAtPos[timer] = v; for (int u : adj[v]) { if (u == parent) continue; dfs(u, v); exitTime[v] = exitTime[u]; } } int prefix[MAXN + 2]; vector<pair<pair<int,int>, int>> operations[MAXN + 2]; struct SqrtDecomp { int size; int bucketCount; int baseArr[MAXN + 1]; struct Block { int left, right; int offset; short bucket[2 * MAXN + 1]; } blocks[BLOCK_SIZE]; void build(int n) { size = sqrt(n); bucketCount = (n + size - 1) / size; for (int i = 1; i <= bucketCount; i++) { int l = (i - 1) * size + 1; int r = min(i * size, n); blocks[i].left = l; blocks[i].right = r; blocks[i].offset = 0; memset(blocks[i].bucket, 0, sizeof(blocks[i].bucket)); for (int j = l; j <= r; j++) { blocks[i].bucket[baseArr[j] + n] = 1; } } } void rangeAdd(int L, int R, int value) { int startBucket = (L + size - 1) / size; int endBucket = (R + size - 1) / size; if (startBucket == endBucket) { for (int i = L; i <= R; i++) { int oldVal = baseArr[i]; blocks[startBucket].bucket[oldVal + n]--; baseArr[i] += value; blocks[startBucket].bucket[baseArr[i] + n]++; } return; } for (int i = L; i <= blocks[startBucket].right; i++) { int oldVal = baseArr[i]; blocks[startBucket].bucket[oldVal + n]--; baseArr[i] += value; blocks[startBucket].bucket[baseArr[i] + n]++; } for (int i = blocks[endBucket].left; i <= R; i++) { int oldVal = baseArr[i]; blocks[endBucket].bucket[oldVal + n]--; baseArr[i] += value; blocks[endBucket].bucket[baseArr[i] + n]++; } for (int i = startBucket + 1; i <= endBucket - 1; i++) { blocks[i].offset += value; } } int countZeros() { int total = 0; for (int i = 1; i <= bucketCount; i++) { total += blocks[i].bucket[n - blocks[i].offset]; } return total; } } solver; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> nodeCount >> queryCount; for (int i = 1; i < nodeCount; i++) { int x, y; cin >> x >> y; adj[x].push_back(y); adj[y].push_back(x); } dfs(); for (int i = 1; i <= queryCount; i++) { int a, b; cin >> a >> b; prefix[entryTime[a]]++; prefix[exitTime[a] + 1]--; prefix[entryTime[b]]++; prefix[exitTime[b] + 1]--; if (entryTime[a] <= entryTime[b] && exitTime[b] <= exitTime[a]) { operations[entryTime[a]].push_back({{entryTime[a], exitTime[a]}, 1}); operations[exitTime[a] + 1].push_back({{entryTime[a], exitTime[a]}, -1}); } else if (entryTime[b] <= entryTime[a] && exitTime[a] <= exitTime[b]) { operations[entryTime[b]].push_back({{entryTime[b], exitTime[b]}, 1}); operations[exitTime[b] + 1].push_back({{entryTime[b], exitTime[b]}, -1}); } else { operations[entryTime[a]].push_back({{entryTime[a], exitTime[a]}, 1}); operations[exitTime[a] + 1].push_back({{entryTime[a], exitTime[a]}, -1}); operations[entryTime[b]].push_back({{entryTime[a], exitTime[a]}, 1}); operations[exitTime[b] + 1].push_back({{entryTime[a], exitTime[a]}, -1}); operations[entryTime[a]].push_back({{entryTime[b], exitTime[b]}, 1}); operations[exitTime[a] + 1].push_back({{entryTime[b], exitTime[b]}, -1}); operations[entryTime[b]].push_back({{entryTime[b], exitTime[b]}, 1}); operations[exitTime[b] + 1].push_back({{entryTime[b], exitTime[b]}, -1}); } } for (int i = 1; i <= nodeCount; i++) { prefix[i] += prefix[i - 1]; } solver.build(nodeCount); vector<int> result(nodeCount + 1); for (int i = 1; i <= nodeCount; i++) { for (auto& op : operations[i]) { int l = op.first.first; int r = op.first.second; int delta = op.second; solver.rangeAdd(l, r, delta); } result[nodeAtPos[i]] = nodeCount - solver.countZeros() - (prefix[i] ? 1 : 0); } for (int i = 1; i <= nodeCount; i++) { cout << result[i] << (i == nodeCount ? '\n' : ' '); } return 0; }

解法2:線分ツリー(标记永久化)

クエリは区間への加算と減算(取消し)である。0の個数を直接管理するよりも、标记の性質を利用することでより効率的に处理できる。 各ノードにはその区間内の0の個数を存储する。标记が正であればその区間内の0は全て消除されるため、0の個数は0となる。减算操作は常に既存の加算を取消すため、标记は非負であることが保证される。 #include <bits/stdc++.h> using namespace std; const int MAXN = 100000; int N, Q; vector<int> graph[MAXN + 1]; int inTime[MAXN + 1], outTime[MAXN + 1], order[MAXN + 1], timer; void build(int v = 1, int p = 0) { inTime[v] = outTime[v] = ++timer; order[timer] = v; for (int u : graph[v]) { if (u == p) continue; build(u, v); outTime[v] = outTime[u]; } } int anc[MAXN + 2]; struct SegTree { struct Node { int l, r; int zeroCount; int lazy; }; vector<Node> t; SegTree(int n) : t(n * 4) {} void init(int v, int l, int r) { t[v].l = l; t[v].r = r; t[v].lazy = 0; t[v].zeroCount = r - l + 1; if (l == r) return; int mid = (l + r) / 2; init(v * 2, l, mid); init(v * 2 + 1, mid + 1, r); } void pull(int v) { if (t[v].lazy > 0) { t[v].zeroCount = 0; } else if (t[v].l == t[v].r) { t[v].zeroCount = 1; } else { t[v].zeroCount = t[v * 2].zeroCount + t[v * 2 + 1].zeroCount; } } void modify(int v, int l, int r, int val) { if (l <= t[v].l && t[v].r <= r) { t[v].lazy += val; pull(v); return; } int mid = (t[v].l + t[v].r) / 2; if (l <= mid) modify(v * 2, l, r, val); if (r > mid) modify(v * 2 + 1, l, r, val); pull(v); } int query() { return t[1].zeroCount; } }; vector<pair<pair<int,int>, int>> ops[MAXN + 2]; int answer[MAXN + 1]; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> N >> Q; for (int i = 1; i < N; i++) { int x, y; cin >> x >> y; graph[x].push_back(y); graph[y].push_back(x); } build(); for (int i = 1; i <= Q; i++) { int a, b; cin >> a >> b; anc[inTime[a]]++; anc[outTime[a] + 1]--; anc[inTime[b]]++; anc[outTime[b] + 1]--; if (inTime[a] <= inTime[b] && outTime[b] <= outTime[a]) { ops[inTime[a]].push_back({{inTime[a], outTime[a]}, 1}); ops[outTime[a] + 1].push_back({{inTime[a], outTime[a]}, -1}); } else if (inTime[b] <= inTime[a] && outTime[a] <= outTime[b]) { ops[inTime[b]].push_back({{inTime[b], outTime[b]}, 1}); ops[outTime[b] + 1].push_back({{inTime[b], outTime[b]}, -1}); } else { ops[inTime[a]].push_back({{inTime[a], outTime[a]}, 1}); ops[outTime[a] + 1].push_back({{inTime[a], outTime[a]}, -1}); ops[inTime[b]].push_back({{inTime[a], outTime[a]}, 1}); ops[outTime[b] + 1].push_back({{inTime[a], outTime[a]}, -1}); ops[inTime[a]].push_back({{inTime[b], outTime[b]}, 1}); ops[outTime[a] + 1].push_back({{inTime[b], outTime[b]}, -1}); ops[inTime[b]].push_back({{inTime[b], outTime[b]}, 1}); ops[outTime[b] + 1].push_back({{inTime[b], outTime[b]}, -1}); } } for (int i = 1; i <= N; i++) anc[i] += anc[i - 1]; SegTree st(N); st.init(1, 1, N); for (int i = 1; i <= N; i++) { for (auto& op : ops[i]) { st.modify(1, op.first.first, op.first.second, op.second); } answer[order[i]] = N - st.query() - (anc[i] ? 1 : 0); } for (int i = 1; i <= N; i++) { cout << answer[i] << (i == N ? '\n' : ' '); } return 0; }

解法3:最小値と出現回数の管理

全区間の最小値を追跡し、その最小値が0であるかどうかで判定する方式もある。各ノードに(最小値, 最小値の出現回数)を存储し、上传時に两子の情報を適切にマージする。 #include <bits/stdc++.h> using namespace std; const int MAXN = 100000; int vertexCnt, queryCnt; vector<int> tree[MAXN + 1]; int tin[MAXN + 1], tout[MAXN + 1], rev[MAXN + 1], curTime; void traverse(int v = 1, int p = 0) { tin[v] = tout[v] = ++curTime; rev[curTime] = v; for (int u : tree[v]) { if (u == p) continue; traverse(u, v); tout[v] = tout[u]; } } int diff[MAXN + 2]; vector<pair<pair<int,int>, int>> pending[MAXN + 2]; class SegmentTree { public: struct Cell { int l, r; int minVal; int minCount; int addTag; }; vector<Cell> node; SegmentTree(int n) : node(n * 4) {} void build(int idx, int l, int r) { node[idx].l = l; node[idx].r = r; node[idx].addTag = 0; node[idx].minVal = 0; node[idx].minCount = r - l + 1; if (l == r) return; int mid = (l + r) / 2; build(idx * 2, l, mid); build(idx * 2 + 1, mid + 1, r); } void merge(int idx) { int left = idx * 2, right = idx * 2 + 1; if (node[left].minVal == node[right].minVal) { node[idx].minVal = node[left].minVal; node[idx].minCount = node[left].minCount + node[right].minCount; } else if (node[left].minVal < node[right].minVal) { node[idx].minVal = node[left].minVal; node[idx].minCount = node[left].minCount; } else { node[idx].minVal = node[right].minVal; node[idx].minCount = node[right].minCount; } } void applyTag(int idx, int val) { node[idx].minVal += val; node[idx].addTag += val; } void pushDown(int idx) { if (node[idx].addTag != 0) { applyTag(idx * 2, node[idx].addTag); applyTag(idx * 2 + 1, node[idx].addTag); node[idx].addTag = 0; } } void rangeAdd(int idx, int L, int R, int val) { if (L <= node[idx].l && node[idx].r <= R) { applyTag(idx, val); return; } pushDown(idx); int mid = (node[idx].l + node[idx].r) / 2; if (L <= mid) rangeAdd(idx * 2, L, R, val); if (R > mid) rangeAdd(idx * 2 + 1, L, R, val); merge(idx); } int countZeros() { return node[1].minVal == 0 ? node[1].minCount : 0; } }; int result[MAXN + 1]; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> vertexCnt >> queryCnt; for (int i = 1; i < vertexCnt; i++) { int x, y; cin >> x >> y; tree[x].push_back(y); tree[y].push_back(x); } traverse(); for (int i = 1; i <= queryCnt; i++) { int a, b; cin >> a >> b; diff[tin[a]]++; diff[tout[a] + 1]--; diff[tin[b]]++; diff[tout[b] + 1]--; if (tin[a] <= tin[b] && tout[b] <= tout[a]) { pending[tin[a]].push_back({{tin[a], tout[a]}, 1}); pending[tout[a] + 1].push_back({{tin[a], tout[a]}, -1}); } else if (tin[b] <= tin[a] && tout[a] <= tout[b]) { pending[tin[b]].push_back({{tin[b], tout[b]}, 1}); pending[tout[b] + 1].push_back({{tin[b], tout[b]}, -1}); } else { pending[tin[a]].push_back({{tin[a], tout[a]}, 1}); pending[tout[a] + 1].push_back({{tin[a], tout[a]}, -1}); pending[tin[b]].push_back({{tin[a], tout[a]}, 1}); pending[tout[b] + 1].push_back({{tin[a], tout[a]}, -1}); pending[tin[a]].push_back({{tin[b], tout[b]}, 1}); pending[tout[a] + 1].push_back({{tin[b], tout[b]}, -1}); pending[tin[b]].push_back({{tin[b], tout[b]}, 1}); pending[tout[b] + 1].push_back({{tin[b], tout[b]}, -1}); } } for (int i = 1; i <= vertexCnt; i++) diff[i] += diff[i - 1]; SegmentTree seg(vertexCnt); seg.build(1, 1, vertexCnt); for (int i = 1; i <= vertexCnt; i++) { for (auto& op : pending[i]) { seg.rangeAdd(1, op.first.first, op.first.second, op.second); } result[rev[i]] = vertexCnt - seg.countZeros() - (diff[i] ? 1 : 0); } for (int i = 1; i <= vertexCnt; i++) { cout << result[i] << (i == vertexCnt ? '\n' : ' '); } return 0; }

解法4:撤销可能線分ツリー

木を括弧序列で表現すると、各减算クエリは対応する加算クエリの撤销となる。スタックを用いて操作履歴を記録し、必要に応じて状態を復元する。 #include <bits/stdc++.h> using namespace std; const int MAXN = 100000; int n, q; vector<int> g[MAXN + 1]; int eulerIn[MAXN + 1], eulerOut[MAXN + 1], posToNode[MAXN + 1], tim; void eulerDFS(int v = 1, int parent = 0) { eulerIn[v] = eulerOut[v] = ++tim; posToNode[tim] = v; for (int u : g[v]) { if (u == parent) continue; eulerDFS(u, v); eulerOut[v] = eulerOut[u]; } } int ancDiff[MAXN + 2]; vector<pair<pair<int,int>, int>> ops[MAXN + 2]; class RevTree { public: struct Cell { int l, r; int sum; bool marked; }; vector<Cell> node; stack<stack<pair<int, Cell>>> history; RevTree(int n) : node(n * 4) {} void build(int idx, int l, int r) { node[idx].l = l; node[idx].r = r; node[idx].sum = r - l + 1; node[idx].marked = false; if (l == r) return; int mid = (l + r) / 2; build(idx * 2, l, mid); build(idx * 2 + 1, mid + 1, r); } void saveState(int idx) { history.top().push({idx, node[idx]}); } void pull(int idx) { saveState(idx); node[idx].sum = node[idx * 2].sum + node[idx * 2 + 1].sum; } void setZero(int idx) { saveState(idx); node[idx].sum = 0; node[idx].marked = true; } void propagate(int idx) { saveState(idx); if (node[idx].marked) { setZero(idx * 2); setZero(idx * 2 + 1); node[idx].marked = false; } } void rangeSetOne(int L, int R, int idx = 1) { if (idx == 1) history.push({}); if (L <= node[idx].l && node[idx].r <= R) { setZero(idx); return; } propagate(idx); int mid = (node[idx].l + node[idx].r) / 2; if (L <= mid) rangeSetOne(L, R, idx * 2); if (R > mid) rangeSetOne(L, R, idx * 2 + 1); pull(idx); } void undo() { auto& st = history.top(); while (!st.empty()) { auto [idx, oldState] = st.top(); st.pop(); node[idx] = oldState; } history.pop(); } int queryZeros() { return node[1].sum; } }; int ans[MAXN + 1]; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> n >> q; for (int i = 1; i < n; i++) { int x, y; cin >> x >> y; g[x].push_back(y); g[y].push_back(x); } eulerDFS(); for (int i = 1; i <= q; i++) { int a, b; cin >> a >> b; ancDiff[eulerIn[a]]++; ancDiff[eulerOut[a] + 1]--; ancDiff[eulerIn[b]]++; ancDiff[eulerOut[b] + 1]--; if (eulerIn[a] <= eulerIn[b] && eulerOut[b] <= eulerOut[a]) { ops[eulerIn[a]].push_back({{eulerIn[a], eulerOut[a]}, 1}); ops[eulerOut[a] + 1].push_back({{eulerIn[a], eulerOut[a]}, -1}); } else if (eulerIn[b] <= eulerIn[a] && eulerOut[a] <= eulerOut[b]) { ops[eulerIn[b]].push_back({{eulerIn[b], eulerOut[b]}, 1}); ops[eulerOut[b] + 1].push_back({{eulerIn[b], eulerOut[b]}, -1}); } else { ops[eulerIn[a]].push_back({{eulerIn[a], eulerOut[a]}, 1}); ops[eulerOut[a] + 1].push_back({{eulerIn[a], eulerOut[a]}, -1}); ops[eulerIn[b]].push_back({{eulerIn[a], eulerOut[a]}, 1}); ops[eulerOut[b] + 1].push_back({{eulerIn[a], eulerOut[a]}, -1}); ops[eulerIn[a]].push_back({{eulerIn[b], eulerOut[b]}, 1}); ops[eulerOut[a] + 1].push_back({{eulerIn[b], eulerOut[b]}, -1}); ops[eulerIn[b]].push_back({{eulerIn[b], eulerOut[b]}, 1}); ops[eulerOut[b] + 1].push_back({{eulerIn[b], eulerOut[b]}, -1}); } } for (int i = 1; i <= n; i++) ancDiff[i] += ancDiff[i - 1]; RevTree seg(n); seg.build(1, 1, n); for (int i = 1; i <= n; i++) { for (auto& op : ops[i]) { if (op.second == -1) { seg.undo(); } } for (auto& op : ops[i]) { if (op.second == 1) { seg.rangeSetOne(op.first.first, op.first.second); } } ans[posToNode[i]] = n - seg.queryZeros() - (ancDiff[i] ? 1 : 0); } for (int i = 1; i <= n; i++) { cout << ans[i] << (i == n ? '\n' : ' '); } return 0; }

各解法の比較

平方分割は実装か简单だが、\\( O(n\sqrt{n}) \\) の計算量で大规模な入力には不向きである。解法2と3はともに \\( O((n+m)\log n) \\) の計算量を実現し、マーク永久化や最小値追跡という観点から効率的な実装が可能である。 撤销可能線分ツリーは思想上 간단だが、オペレーションごとにスタックへのコピーが必要となり、実質的なメモリ消費は \\( O(m\log n) \\) となる。题目の要件に対しては过于複雑な実装となっている。

タグ: segment-tree sqrt-decomposition dfs-order rollback-data-structure lazy-propagation

Sat, 09 May 2026 23:34:14 +0900 投稿