桁DPとは何か?
桁DP(Digit Dynamic Programming)は、通常ある区間[L, R]内で特定の制約を満たす数字の数を統計するために使用されます。LとRのデータ範囲が大きいため、DP(動的計画法)で統計する必要があることが多いです。
上限Rの処理テクニック
数値の比較ルールから、現在の桁の取りうる値の範囲は、前方の桁の値に依存することがわかります。
もし前方のすべての桁がRと同じ場合、現在の桁の値はRの現在の桁の値を超えることはできません。逆に、前方の桁の中にRより小さい桁が存在する場合、現在の桁の値は制限されず、0〜9のいずれかの数値を取ることができます。
全体として、現在の数は常にRの対応する接頭辞を超えてはなりません。
したがって、状態には上限に関するマークubが必要であり、これは前方の桁がRと**異なることがあったかどうか**を示します。
遷移時、前方の桁または現在の桁がRと異なる場合、遷移後のマークは1になり、そうでなければ0になります。
ub=0の場合を個別に考える
桁数がkの場合、ub=0の状態は1つしかないため、この状態を個別に取り出して議論し、個別に遷移させることを考えることができます。
これにより、コードがより明確になります……でしょうか?
実はそうではありません。
複数のデータ制約[L_i, R_i]がある場合、この方法は包含原理(容斥)に発展し、実装の難易度が**指数関数的に増加**します。
したがって、この方法の使用は推奨されません。
下限Lの処理テクニック
下限Lの処理テクニック?間違いないですね、通常は包含原理を使うので、下限Lは存在しないのです!
そうですね、その通りですが、少し待ってください。
包含原理を使わない場合、下限Lの処理状況を考慮する必要があります。簡単に言えば、lbの下限マークを追加し、上限マークと同じ操作を行います。
では、下限を追加する方案と、通常の包含原理のどちらが良いのでしょうか?
答えは前者です。
なぜでしょうか?これほど簡単に結論を出すとは、理由もないのでしょうか。後者は最も一般的に使用される方法ですよ!
その通りですが、後者は下限マークlbを削除したと思われがちですが、実際はそうではありません。
なぜなら、後者は**先頭ゼロ**を考慮する必要があるからです。先頭ゼロを考慮することは、実質的に下限マークを考慮することと同等であり、**ただL=0であるだけ**です。
また、同様に、複数の範囲制限がある場合、この方法は再び包含原理に発展し、非常に面倒になります。
したがって、前者の書き方を推奨します。(もちろん、包含原理の使用を妨げるものではありません!使いたい場合は使ってください!)
古典的な問題の分析
ここでは、P2657 windy数の問題を例として説明します。
入力時は、文字列として統一して読み込み、Lに必要な先頭ゼロを追加します。
dp配列の状態をどのように定義するかを考えます。
状態には以下の要素が存在すると考えられます:
- 現在列挙している第i桁
- 下限マークxおよび上限マークy
- 現在の桁の数字k
以上の要素から、4次元のdp配列を定義できます。
次に遷移を考えます。
まず、i(現在列挙している桁数)、x(下限マーク)、y(上限マーク)および**前の桁の数字**jを列挙します。なぜこれを列挙するのでしょうか?それは遷移を補助するためです。
次に、現在の桁の数字kを列挙します。この数字kの取りうる範囲はxとyに依存します。x=1の場合、kの下限は制限されません。そうでなければ、Lによって制限されます。同様に、y=1の場合、kの上限は制限されません。そうでなければ、Rによって制限されます。
これで遷移できます。私の遷移は拡散型です。
新しいxとy(uとvとします)を即座に計算できます。遷移式は非常に簡単で、dp[i+1][u][v][k] = dp[i+1][u][v][k] + dp[i][x][y][j]ですが、遷移には条件|j-k| ≥ 2を満たす必要があります。これは問題の制約だからです!
しかし、この条件は少し不足しています。現在まだ先頭ゼロの範囲内にある場合、この条件を保証する必要はありません。したがって、x=0かつiがまだ先頭ゼロの範囲内にある場合を特別に判定する必要があります。
これでこの問題は解決します。
コード例1: Windy数
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
string low, high;
int len, diff;
ll dp[20][2][2][10];
int main() {
cin >> low >> high;
len = high.size();
diff = len - low.size();
while (low.size() < len) low = "0" + low;
dp[0][0][0][0] = 1;
for (int i = 0; i < len; i++) {
for (int tight_low = 0; tight_low < 2; tight_low++) {
for (int tight_high = 0; tight_high < 2; tight_high++) {
for (int prev_digit = 0; prev_digit <= 9; prev_digit++) {
if (dp[i][tight_low][tight_high][prev_digit] == 0) continue;
int min_digit = (tight_low ? low[i] - '0' : 0);
int max_digit = (tight_high ? high[i] - '0' : 9);
for (int curr_digit = min_digit; curr_digit <= max_digit; curr_digit++) {
int new_tight_low = tight_low || (curr_digit > (low[i] - '0'));
int new_tight_high = tight_high || (curr_digit < (high[i] - '0'));
bool valid = true;
if (!(tight_low && i <= diff) && abs(prev_digit - curr_digit) < 2) {
valid = false;
}
if (valid) {
dp[i+1][new_tight_low][new_tight_high][curr_digit] +=
dp[i][tight_low][tight_high][prev_digit];
}
}
}
}
}
}
ll result = 0;
for (int tight_low = 0; tight_low < 2; tight_low++) {
for (int tight_high = 0; tight_high < 2; tight_high++) {
for (int digit = 0; digit <= 9; digit++) {
result += dp[len][tight_low][tight_high][digit];
}
}
}
cout << result << endl;
return 0;
}
第二の問題の詳細解説
この問題はP6218 Round Numbers Sです。
この問題は比較的簡単です。
まず、DP処理時に二進数を使用するため、lとrを最初に二進数に変換します(文字列として保存)。
次に、DPの状態設計を考えます。
二進数では最大31桁なので、大胆にdp[i][p][q][x][y]と定義し、これは現在i桁まで考慮し、その中にp個の**有効な**0とq個の1があり、下限マークがx、上限マークがyであるときの「丸数」の数を表します。
どのような0が**有効**と呼ばれるのでしょうか?明らかに、先頭ゼロは無効です。さもなければ、DPする意味がありません。十分な先頭ゼロを追加すれば、すべての数が「丸数」となってしまいます!
したがって、これはこの問題の落とし穴の一つであり、pに先頭ゼロを含めないように特別に判定する必要があります。
その他の点では、遷移は前の問題と似ています。kを列挙する必要がありますが、今回は**jを列挙する必要はありません**。なぜならjは遷移に何の役割も果たさないからです。
最後に答えを累積する際には、xとy、およびpとqを列挙します。もちろん、p ≥ qである答えのみを統計します。
コード例2: Round Numbers
#include <iostream>
#include <string>
#include <algorithm>
using namespace std;
typedef long long ll;
ll l, r;
string L, R;
int n, diff;
ll dp[40][40][40][2][2];
int main() {
cin >> l >> r;
// Convert to binary strings
while (l) {
L = (char)(l % 2 + '0') + L;
l >>= 1;
}
while (r) {
R = (char)(r % 2 + '0') + R;
r >>= 1;
}
n = R.size();
diff = n - L.size();
while (L.size() < n) L = "0" + L;
dp[0][0][0][0][0] = 1;
for (int i = 0; i < n; i++) {
for (int tight_low = 0; tight_low < 2; tight_low++) {
for (int tight_high = 0; tight_high < 2; tight_high++) {
for ( int zeros = 0; zeros <= i; zeros++) {
for (int ones = 0; zeros + ones <= i; ones++) {
if (dp[i][zeros][ones][tight_low][tight_high] == 0) continue;
int min_bit = (tight_low ? L[i] - '0' : 0);
int max_bit = (tight_high ? R[i] - '0' : 1);
for (int bit = min_bit; bit <= max_bit; bit++) {
int new_zeros = zeros + (!bit && !(tight_low && i <= diff));
int new_ones = ones + bit;
int new_tight_low = tight_low || (bit > (L[i] - '0'));
int new_tight_high = tight_high || (bit < (R[i] - '0'));
dp[i+1][new_zeros][new_ones][new_tight_low][new_tight_high] +=
dp[i][zeros][ones][tight_low][tight_high];
}
}
}
}
}
}
ll result = 0;
for (int zeros = 0; zeros <= n; zeros++) {
for (int ones = 0; zeros + ones <= n && ones <= zeros; ones++) {
for (int tight_low = 0; tight_low < 2; tight_low++) {
for (int tight_high = 0; tight_high < 2; tight_high++) {
result += dp[n][zeros][ones][tight_low][tight_high];
}
}
}
}
cout << result << endl;
return 0;
}
第三の問題の詳細解説
この問題は「同種分布」と呼ばれます。
各桁の数字の合計はDPの進行に伴って絶えず変化しています。これでは処理が難しく、取り扱いが困難です。
どうすればよいでしょうか?
実は非常に簡単です。データ範囲はそれほど大きくないので、少し大胆なアプローチを取ります。**最終的な各桁の数字の合計**を列挙する大きなループを外側に追加します。これにより記録が可能になります。
再述すると、dp[i][f][p][x]を定義し、これは現在i桁まで考慮し、現在の各桁の数字の合計がf、現在の値の modulo の結果がp、上下限の状況がxに圧縮して保存されている状態を表します。
主な遷移の難しさは新しいx、f、pを作成することにあります。xは通常の操作なので詳述しません。fは非常に簡単で、現在の数字kを追加するだけです。pはまず10を掛けてからkを加え、その後moduloを取る必要があります。
全体として、それほど難しくないでしょう。
コード例3: 同種分布
#include <iostream>
#include <string>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
string low, high;
int n, diff;
ll dp[20][200][200][4], result;
void init(int mod_value) {
memset(dp, 0, sizeof(dp));
dp[0][0][0][0] = 1;
return;
}
int main() {
cin >> low >> high;
n = high.size();
diff = n - low.size();
while (low.size() < n) low = "0" + low;
for (int mod = 1; mod <= n * 9; mod++) {
init(mod);
for (int i = 0; i < n; i++) {
for (int sum = 0; sum <= mod; sum++) {
for (int flags = 0; flags < 4; flags++) {
for (int remainder = 0; remainder <= mod; remainder++) {
if (dp[i][sum][remainder][flags] == 0) continue;
int min_digit = (flags & 1 ? 0 : low[i] - '0');
int max_digit = (flags & 2 ? 9 : high[i] - '0');
for (int digit = min_digit; digit <= max_digit && sum + digit <= mod; digit++) {
int new_flags = (flags & 1) || (digit > (low[i] - '0')) ?
(flags | 1) : flags;
new_flags = (flags & 2) || (digit < (high[i] - '0')) ?
(new_flags | 2) : new_flags;
int new_sum = sum + digit;
int new_remainder = (remainder * 10 + digit) % mod;
dp[i+1][new_sum][new_remainder][new_flags] +=
dp[i][sum][remainder][flags];
}
}
}
}
}
for (int flags = 0; flags < 4; flags++) {
result += dp[n][mod][0][flags];
}
}
cout << result << endl;
return 0;
}
まとめ
桁DPは、多くの人がメモ化探索を好んで使用します。しかし、私は通常の递归アプローチがより明確で便利だと考えています。
桁DPの基本的なパターンはほとんど同じであり、すべての問題に同じテンプレートを適用できます。中間の詳細を少し修正したり、いくつかのループを追加したりするだけで、別の問題を解決できます。
その中で、上下限のマークが最も重要な要素です。
補足として、桁DPは通常次元数が多いため、状態圧縮とよく組み合わせられます。様々なものを一つの二進数に圧縮することで、コードを簡潔にし、定数を減らすことができます。