多項式の定義と表現形式
多項式とは、有限個の項からなる式 \(f(x) = \sum_{i=0}^{n} a_i x^i\) のことを指す。各項の係数は \(a_i\) で表され、最高次の項の次数を「次数(degree)」と呼ぶ。
多項式の表現方法には主に2種類ある:
- 係数表現:上記のように係数列 \((a_0, a_1, ..., a_n)\) で表す。
- 点値表現:\(n+1\) 個の異なる点 \((x_i, f(x_i))\) で多項式を一意に決定する。
係数表現から点値表現への変換を評価(evaluation)、逆の操作を補間(interpolation)という。
ラグランジュ補間法
与えられた \(n+1\) 点 \((x_i, y_i)\) を通る \(n\) 次多項式を構成する手法。直接的な連立方程式解法は \(O(n^3)\) だが、ラグランジュの基底多項式を用いることで \(O(n^2)\) で実現できる。
補間多項式は次式で与えられる:
特に \(x_i = i\) のような等間隔の場合、階乗とその逆元を前計算することで \(O(n)\) 補間が可能となる。
ラグランジュ補間の実装例
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const ll MOD = 998244353;
ll modpow(ll a, ll n) {
ll res = 1;
while (n) {
if (n & 1) res = res * a % MOD;
a = a * a % MOD;
n >>= 1;
}
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n; ll k;
cin >> n >> k;
vector<pair<ll, ll>> pts(n);
for (auto& [x, y] : pts) cin >> x >> y;
ll ans = 0;
for (int i = 0; i < n; ++i) {
ll num = pts[i].second, den = 1;
for (int j = 0; j < n; ++j) {
if (i == j) continue;
num = num * ((k - pts[j].first + MOD) % MOD) % MOD;
den = den * ((pts[i].first - pts[j].first + MOD) % MOD) % MOD;
}
ans = (ans + num * modpow(den, MOD - 2)) % MOD;
}
cout << ans << '\n';
}
高速フーリエ変換(FFT)
多項式乗算は係数表現では \(O(n^2)\) かかるが、点値表現では各点で値を掛けるだけで済む(\(O(n)\))。この性質を利用して、以下の3ステップで高速化する:
- DFT:係数表現 → 点値表現(単位円上の特定点で評価)
- 点ごとの乗算
- IDFT:点値表現 → 係数表現(補間)
ここで用いられるのが n 次の原始単位根 \(\omega_n = e^{2\pi i / n}\)。重要な性質として:
- \(\omega_n^{2k} = \omega_{n/2}^k\)
- \(\omega_n^{n/2 + k} = -\omega_n^k\)(\(n\) が偶数のとき)
これにより、再帰的に偶奇に分割して計算でき、全体で \(O(n \log n)\) となる。
実装上は、ビット反転順にデータを並べ替える「バタフライ操作」を事前に行い、再帰を避けた反復的実装が一般的である。
FFTによる多項式乗算
#include <bits/stdc++.h>
using namespace std;
const double PI = acos(-1.0);
struct Complex {
double real, imag;
Complex(double r = 0, double i = 0) : real(r), imag(i) {}
Complex operator+(const Complex& o) const { return {real + o.real, imag + o.imag}; }
Complex operator-(const Complex& o) const { return {real - o.real, imag - o.imag}; }
Complex operator*(const Complex& o) const {
return {real * o.real - imag * o.imag, real * o.imag + imag * o.real};
}
};
void fft(vector<Complex>& a, bool invert) {
int n = a.size();
for (int i = 1, j = 0; i < n; ++i) {
int bit = n >> 1;
for (; j & bit; bit >>= 1) j ^= bit;
j ^= bit;
if (i < j) swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
double ang = 2 * PI / len * (invert ? -1 : 1);
Complex wlen(cos(ang), sin(ang));
for (int i = 0; i < n; i += len) {
Complex w(1);
for (int j = 0; j < len / 2; ++j) {
Complex u = a[i + j], v = a[i + j + len / 2] * w;
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
w = w * wlen;
}
}
}
if (invert) {
for (auto& x : a) {
x.real /= n;
x.imag /= n;
}
}
}
vector<long long> multiply(const vector<int>& A, const vector<int>& B) {
vector<Complex> fa(A.begin(), A.end()), fb(B.begin(), B.end());
int n = 1;
while (n < (int)(A.size() + B.size())) n <<= 1;
fa.resize(n); fb.resize(n);
fft(fa, false); fft(fb, false);
for (int i = 0; i < n; ++i) fa[i] = fa[i] * fb[i];
fft(fa, true);
vector<long long> result(n);
for (int i = 0; i < n; ++i)
result[i] = (long long)round(fa[i].real);
return result;
}
高速メビウス変換(FMT)と高速ウォルシュ変換(FWT)
FFTを拡張し、ビット演算(OR, AND, XOR)に基づく畳み込みを高速に処理する手法。
OR/AND 畳み込み(FMT)
集合の包含関係に基づき、高次元の累積和(or の場合は部分集合和、and の場合は上位集合和)を用いる。
- FMT(順変換):
- OR: \(F(S) = \sum_{T \subseteq S} f(T)\)
- AND: \(F(S) = \sum_{T \supseteq S} f(T)\)
- 逆変換:包除原理により符号を反転させて差分を取る。
各ビットごとにマージを行い、\(O(n \log n)\) で処理可能。
XOR 畳み込み(FWT)
次のような線形変換を用いる:
この変換は XOR 畳み込みに対して準同型性を持ち、次のように再帰的に計算できる:
(a0, a1) → (a0 + a1, a0 - a1)
逆変換では結果を \(1/n\) 倍する必要がある。
FMT/FWT 実装例
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const ll MOD = 998244353;
ll modpow(ll a, ll n) {
ll res = 1;
while (n) {
if (n & 1) res = res * a % MOD;
a = a * a % MOD;
n >>= 1;
}
return res;
}
struct Poly {
vector<ll> a;
int n;
Poly(int size = 0) : n(size), a(size) {}
void fmt_or(bool inv) {
for (int len = 1; len < n; len <<= 1) {
for (int i = 0; i < n; i += len << 1) {
for (int j = 0; j < len; ++j) {
if (!inv) {
a[i + j + len] = (a[i + j + len] + a[i + j]) % MOD;
} else {
a[i + j + len] = (a[i + j + len] - a[i + j] + MOD) % MOD;
}
}
}
}
}
void fmt_and(bool inv) {
for (int len = 1; len < n; len <<= 1) {
for (int i = 0; i < n; i += len << 1) {
for (int j = 0; j < len; ++j) {
if (!inv) {
a[i + j] = (a[i + j] + a[i + j + len]) % MOD;
} else {
a[i + j] = (a[i + j] - a[i + j + len] + MOD) % MOD;
}
}
}
}
}
void fwt_xor(bool inv) {
for (int len = 1; len < n; len <<= 1) {
for (int i = 0; i < n; i += len << 1) {
for (int j = 0; j < len; ++j) {
ll u = a[i + j], v = a[i + j + len];
a[i + j] = (u + v) % MOD;
a[i + j + len] = (u - v + MOD) % MOD;
}
}
}
if (inv) {
ll inv_n = modpow(n, MOD - 2);
for (ll& x : a) x = x * inv_n % MOD;
}
}
};
Poly operator|(const Poly& A, const Poly& B) {
Poly C(A.n);
Poly X = A, Y = B;
X.fmt_or(false); Y.fmt_or(false);
for (int i = 0; i < X.n; ++i) C.a[i] = X.a[i] * Y.a[i] % MOD;
C.fmt_or(true);
return C;
}
Poly operator&(const Poly& A, const Poly& B) {
Poly C(A.n);
Poly X = A, Y = B;
X.fmt_and(false); Y.fmt_and(false);
for (int i = 0; i < X.n; ++i) C.a[i] = X.a[i] * Y.a[i] % MOD;
C.fmt_and(true);
return C;
}
Poly operator^(const Poly& A, const Poly& B) {
Poly C(A.n);
Poly X = A, Y = B;
X.fwt_xor(false); Y.fwt_xor(false);
for (int i = 0; i < X.n; ++i) C.a[i] = X.a[i] * Y.a[i] % MOD;
C.fwt_xor(true);
return C;
}