輪郭をなぞるだけのブログ

浅学菲才のためにおそらく嘘も多い

漸化式の一般項の値を高速に求める

隣接 k 項間漸化式の一般項 a_{n} の値を高速に求める手法を最近知ったのでメモをしておく。
一般項を求めるのは、 k \times k 行列で行列累乗を用いると、O(k^{3} logn) で求めることができる。
ただ、きたまさ法と呼ばれる手法を用いれば、O(k^{2} logn) で求めることができる。

表記など怪しいので、詳しくは参考文献を見ることをお薦めする。

k 項の初期値と、隣接 k 項の関係が以下のように与えられているとする。
a = \{a_{0}, a_{1}, ..., a_{k-1}\},\ d = \{d_{0}, d_{1}, ..., d_{k-1}\}
a_{k} = d_{0}a_{0} + d_{1}a_{1} + ... + d_{k-1}a_{k-1}
ただし、a は初期値のベクトル、da_{k} と隣接 k 項間の関係を書き表すための係数ベクトルである。( 一般的な表記とは異なっているので注意。)

ここで、次のような k 次元ベクトル f(n) を導入する。
f(n) = \{x_{0}, x_{1}, ..., x_{k-1}\}
このとき、a_{n} は以下のように書けるものと定義する。
a_{n} = x_{0}a_{0} + x_{1}a_{1} + ... + x_{k-1}a_{k-1} \ (n \geq k)
n = k のとき、f(k) = d である。

a_{n} を求めるために、f(n) を求めることを考える。
f(N) が分かっているときに、f(2N) が求めることができれば、ダブリングの要領で、k から始めて O(logn) 回の計算で f(n) を求めることができる。

まず、f(N) が分かっているとき、f(N+1)O(k) で求めることができる。
f(N)N 項前から k 個取った項との関係を表すものであるので、
a_{N} = x_{0}a_{0} + x_{1}a_{1} + ... + x_{k-1}a_{k-1}
のとき、すべての a の添え字に +1 をして、
a_{N+1} = x_{0}a_{1} + x_{1}a_{2} + ... + x_{k-1}a_{k}
これに a_{k} の式を代入してまとめると、
a_{N+1} = d_{0}x_{k-1} \cdot a_{0} + (x_{1} + d_{1}x_{k-1}) \cdot a_{1} + ... + (x_{k-2} + d_{k-1}x_{k-1}) \cdot a_{k-1}
すなわち、
f(N+1) = \{d_{0}x_{k-1}, x_{1} + d_{1}x_{k-1}, ..., x_{k-2} + d_{k-1}x_{k-1}\}
となる。

次に、f(N) が分かっているときに、f(2N) を求めることを考える。
a_{N} = x_{0}a_{0} + x_{1}a_{1} + ... + x_{k-1}a_{k-1}
のとき、すべての a の添え字に +N をすると、
a_{2N} = x_{0}a_{N} + x_{1}a_{N+1} + ... + x_{k-1}a_{N-k-1}
となるので、
f(N),\ f(N+1),\ ...,\ f(N+k-1)k 個のベクトルが分かっていれば、f(2N) を求めることができる。
k 個のベクトルは、f(N) から始めて O(k) で求めることができる。
各計算は前述の方法と同様 O(k) であるので、合計で O(k^{2}) となる。
そして、f(2N) の計算では、各ベクトルから a_{0} の係数、a_{1} の係数... を足し込めばいいので、これも O(k^{2}) で求めることができる。

よって、全体で O(k^{2} logn) の計算量で、a_{n} を求めるための係数ベクトルが求まる。

コード
#include <iostream>
#include <vector>
using namespace std;

template<typename T>
struct Kitamasa{
    vector<T> a;    // 初期値ベクトル
    vector<T> d;    // 係数ベクトル
    int k;
    
    Kitamasa(vector<T>& a, vector<T>& d) : a(a), d(d), k((int)a.size()) {}
    
    // a_n の係数ベクトルを求める
    vector<T> dfs(int64_t n){
        if(n == k)  return d;
        
        vector<T> res(k);
        if(n & 1 || n < k * 2){
            vector<T> x = dfs(n - 1);
            for(int i = 0; i < k; ++i)  res[i] = d[i] * x[k - 1];
            for(int i = 0; i + 1 < k; ++i)  res[i + 1] += x[i];
        }
        else{
            vector<vector<T>> x(k, vector<T>(k));
            x[0] = dfs(n >> 1);
            for(int i = 0; i + 1 < k; ++i){
                for(int j = 0; j < k; ++j)  x[i + 1][j] = d[j] * x[i][k - 1];
                for(int j = 0; j + 1 < k; ++j)  x[i + 1][j + 1] += x[i][j];
            }
            for(int i = 0; i < k; ++i){
                for(int j = 0; j < k; ++j){
                    res[j] += x[0][i] * x[i][j];
                }
            }
        }

        return res;
    }

    // a_n を求める
    T calc(int64_t n){
        vector<T> x = dfs(n);
        T res = 0;
        for(int i = 0; i < k; ++i)  res += x[i] * a[i];
        return res;
    }
};
関連問題

T - フィボナッチ

雑記

世の中天才しかいないな…