多項式の基本と高速変換技法

多項式の定義と表現形式

多項式とは、有限個の項からなる式 \(f(x) = \sum_{i=0}^{n} a_i x^i\) のことを指す。各項の係数は \(a_i\) で表され、最高次の項の次数を「次数(degree)」と呼ぶ。

多項式の表現方法には主に2種類ある:

  • 係数表現:上記のように係数列 \((a_0, a_1, ..., a_n)\) で表す。
  • 点値表現\(n+1\) 個の異なる点 \((x_i, f(x_i))\) で多項式を一意に決定する。

係数表現から点値表現への変換を評価(evaluation)、逆の操作を補間(interpolation)という。

ラグランジュ補間法

与えられた \(n+1\)\((x_i, y_i)\) を通る \(n\) 次多項式を構成する手法。直接的な連立方程式解法は \(O(n^3)\) だが、ラグランジュの基底多項式を用いることで \(O(n^2)\) で実現できる。

補間多項式は次式で与えられる:

\[ f(x) = \sum_{i=1}^{n+1} y_i \prod_{\substack{j=1 \\ j \ne i}}^{n+1} \frac{x - x_j}{x_i - x_j} \]

特に \(x_i = i\) のような等間隔の場合、階乗とその逆元を前計算することで \(O(n)\) 補間が可能となる。

ラグランジュ補間の実装例
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MOD = 998244353;

ll modpow(ll a, ll n) {
    ll res = 1;
    while (n) {
        if (n & 1) res = res * a % MOD;
        a = a * a % MOD;
        n >>= 1;
    }
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n; ll k;
    cin >> n >> k;
    vector<pair<ll, ll>> pts(n);
    for (auto& [x, y] : pts) cin >> x >> y;

    ll ans = 0;
    for (int i = 0; i < n; ++i) {
        ll num = pts[i].second, den = 1;
        for (int j = 0; j < n; ++j) {
            if (i == j) continue;
            num = num * ((k - pts[j].first + MOD) % MOD) % MOD;
            den = den * ((pts[i].first - pts[j].first + MOD) % MOD) % MOD;
        }
        ans = (ans + num * modpow(den, MOD - 2)) % MOD;
    }
    cout << ans << '\n';
}

高速フーリエ変換(FFT)

多項式乗算は係数表現では \(O(n^2)\) かかるが、点値表現では各点で値を掛けるだけで済む(\(O(n)\))。この性質を利用して、以下の3ステップで高速化する:

  1. DFT:係数表現 → 点値表現(単位円上の特定点で評価)
  2. 点ごとの乗算
  3. IDFT:点値表現 → 係数表現(補間)

ここで用いられるのが n 次の原始単位根 \(\omega_n = e^{2\pi i / n}\)。重要な性質として:

  • \(\omega_n^{2k} = \omega_{n/2}^k\)
  • \(\omega_n^{n/2 + k} = -\omega_n^k\)\(n\) が偶数のとき)

これにより、再帰的に偶奇に分割して計算でき、全体で \(O(n \log n)\) となる。

実装上は、ビット反転順にデータを並べ替える「バタフライ操作」を事前に行い、再帰を避けた反復的実装が一般的である。

FFTによる多項式乗算
#include <bits/stdc++.h>
using namespace std;

const double PI = acos(-1.0);

struct Complex {
    double real, imag;
    Complex(double r = 0, double i = 0) : real(r), imag(i) {}
    Complex operator+(const Complex& o) const { return {real + o.real, imag + o.imag}; }
    Complex operator-(const Complex& o) const { return {real - o.real, imag - o.imag}; }
    Complex operator*(const Complex& o) const {
        return {real * o.real - imag * o.imag, real * o.imag + imag * o.real};
    }
};

void fft(vector<Complex>& a, bool invert) {
    int n = a.size();
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * PI / len * (invert ? -1 : 1);
        Complex wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            Complex w(1);
            for (int j = 0; j < len / 2; ++j) {
                Complex u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w = w * wlen;
            }
        }
    }

    if (invert) {
        for (auto& x : a) {
            x.real /= n;
            x.imag /= n;
        }
    }
}

vector<long long> multiply(const vector<int>& A, const vector<int>& B) {
    vector<Complex> fa(A.begin(), A.end()), fb(B.begin(), B.end());
    int n = 1;
    while (n < (int)(A.size() + B.size())) n <<= 1;
    fa.resize(n); fb.resize(n);

    fft(fa, false); fft(fb, false);
    for (int i = 0; i < n; ++i) fa[i] = fa[i] * fb[i];
    fft(fa, true);

    vector<long long> result(n);
    for (int i = 0; i < n; ++i)
        result[i] = (long long)round(fa[i].real);
    return result;
}

高速メビウス変換(FMT)と高速ウォルシュ変換(FWT)

FFTを拡張し、ビット演算(OR, AND, XOR)に基づく畳み込みを高速に処理する手法。

OR/AND 畳み込み(FMT)

集合の包含関係に基づき、高次元の累積和(or の場合は部分集合和、and の場合は上位集合和)を用いる。

  • FMT(順変換)
    • OR: \(F(S) = \sum_{T \subseteq S} f(T)\)
    • AND: \(F(S) = \sum_{T \supseteq S} f(T)\)
  • 逆変換:包除原理により符号を反転させて差分を取る。

各ビットごとにマージを行い、\(O(n \log n)\) で処理可能。

XOR 畳み込み(FWT)

次のような線形変換を用いる:

\[ F(S) = \sum_{T} (-1)^{|S \cap T|} f(T) \]

この変換は XOR 畳み込みに対して準同型性を持ち、次のように再帰的に計算できる:

(a0, a1) → (a0 + a1, a0 - a1)

逆変換では結果を \(1/n\) 倍する必要がある。

FMT/FWT 実装例
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const ll MOD = 998244353;

ll modpow(ll a, ll n) {
    ll res = 1;
    while (n) {
        if (n & 1) res = res * a % MOD;
        a = a * a % MOD;
        n >>= 1;
    }
    return res;
}

struct Poly {
    vector<ll> a;
    int n;

    Poly(int size = 0) : n(size), a(size) {}

    void fmt_or(bool inv) {
        for (int len = 1; len < n; len <<= 1) {
            for (int i = 0; i < n; i += len << 1) {
                for (int j = 0; j < len; ++j) {
                    if (!inv) {
                        a[i + j + len] = (a[i + j + len] + a[i + j]) % MOD;
                    } else {
                        a[i + j + len] = (a[i + j + len] - a[i + j] + MOD) % MOD;
                    }
                }
            }
        }
    }

    void fmt_and(bool inv) {
        for (int len = 1; len < n; len <<= 1) {
            for (int i = 0; i < n; i += len << 1) {
                for (int j = 0; j < len; ++j) {
                    if (!inv) {
                        a[i + j] = (a[i + j] + a[i + j + len]) % MOD;
                    } else {
                        a[i + j] = (a[i + j] - a[i + j + len] + MOD) % MOD;
                    }
                }
            }
        }
    }

    void fwt_xor(bool inv) {
        for (int len = 1; len < n; len <<= 1) {
            for (int i = 0; i < n; i += len << 1) {
                for (int j = 0; j < len; ++j) {
                    ll u = a[i + j], v = a[i + j + len];
                    a[i + j] = (u + v) % MOD;
                    a[i + j + len] = (u - v + MOD) % MOD;
                }
            }
        }
        if (inv) {
            ll inv_n = modpow(n, MOD - 2);
            for (ll& x : a) x = x * inv_n % MOD;
        }
    }
};

Poly operator|(const Poly& A, const Poly& B) {
    Poly C(A.n);
    Poly X = A, Y = B;
    X.fmt_or(false); Y.fmt_or(false);
    for (int i = 0; i < X.n; ++i) C.a[i] = X.a[i] * Y.a[i] % MOD;
    C.fmt_or(true);
    return C;
}

Poly operator&(const Poly& A, const Poly& B) {
    Poly C(A.n);
    Poly X = A, Y = B;
    X.fmt_and(false); Y.fmt_and(false);
    for (int i = 0; i < X.n; ++i) C.a[i] = X.a[i] * Y.a[i] % MOD;
    C.fmt_and(true);
    return C;
}

Poly operator^(const Poly& A, const Poly& B) {
    Poly C(A.n);
    Poly X = A, Y = B;
    X.fwt_xor(false); Y.fwt_xor(false);
    for (int i = 0; i < X.n; ++i) C.a[i] = X.a[i] * Y.a[i] % MOD;
    C.fwt_xor(true);
    return C;
}

タグ: 多項式 FFT FWT ラグランジュ補間 FMT

5月20日 17:34 投稿