はまやんはまやんはまやん

hamayanhamayan's blog

二項係数 mod 素数を高速に計算する方法 [累積和, フェルマーの小定理, 繰り返し二乗法, コンビネーション, 10^9+7]

要望

nCk mod 10^9+7を高速に計算したい
n,k≦10^5

追記:llはlong longのことです
使ってるテンプレートはこんな感じです。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<b;i++)
using namespace std;
typedef long long ll;

高校で習うやり方でやると…

ll mod = 1000000007;
//---------------------------------------------------------------------------------------------------
ll C(int n, int k) {
    ll res = 1;

    rep(i, 0, k) res = (res * (n - i)) % mod;
    rep(i, 0, k) res = (res / (k - i)) % mod;

    return res;
}

このように書きたい感じがある。
n,kが小さければこれでも良いのだが、問題がいくつかある。
1. 計算量がO(k)なので遅い
2. mod上で割り算できない
この問題を解決していこう。

問題1の解決法:累積和

二項係数の別の計算方法として、C(n,k) = n! / { (n-k)! * k! }というものがある。
この式はwikipediaにも載っている式であるが、こちらで計算するようにする。

ll mod = 1000000007;
//---------------------------------------------------------------------------------------------------
ll C(int n, int k) {
    ll a = 1; // = n!
    ll b = 1; // = (n-k)!
    ll c = 1; // = k!

    rep(i, 0, n) a = (a * (n - i)) % mod;
    rep(i, 0, n - k) b = (b * (n - k - i)) % mod;
    rep(i, 0, k) c = (c * (k - i)) % mod;

    ll bc = (b * c) % mod;

    return (a / bc) % mod;
}

このままだと計算量はO(n)であるが、a,b,cはそれぞれ累積和で事前に計算することができる。
事前にO(n)で累積和を計算しておくことで、この部分を高速化しよう。

ll mod = 1000000007;
//---------------------------------------------------------------------------------------------------
ll f[101010];
void init() {
    f[0] = 1;
    rep(i, 1, 101010) f[i] = (f[i - 1] * i) % mod;
}
//---------------------------------------------------------------------------------------------------
ll C(int n, int k) {
    ll a = f[n]; // = n!
    ll b = f[n-k]; // = (n-k)!
    ll c = f[k]; // = k!

    ll bc = (b * c) % mod;

    return (a / bc) % mod;
}

これで計算量はだいぶ良くなった。

問題2の解決法:フェルマーの小定理と繰り返し二乗法

あとは割り算を解決しよう。
これは「ある数で割るかわりに、ある数の逆数をかける」ことで解決する。
よって、問題が「ある数の逆数はなにか」ということになるが、これは『フェルマーの小定理』で解決する。
フェルマーの小定理とは、ある数xのmod p(pは素数)上での逆数x'はx' = x ^ (p - 2)で計算できるというものである。
よって、今回はpが10^9+7なので、ある数xの逆数はx^(10^9+5)ということになる。

ll mod = 1000000007;
//---------------------------------------------------------------------------------------------------
ll f[101010];
void init() {
    f[0] = 1;
    rep(i, 1, 101010) f[i] = (f[i - 1] * i) % mod;
}
//---------------------------------------------------------------------------------------------------
ll inv(ll x) {
    ll res = 1;
    rep(i, 0, mod - 2) res = (res * x) % mod;
    return res;
}
//---------------------------------------------------------------------------------------------------
ll C(int n, int k) {
    ll a = f[n]; // = n!
    ll b = f[n-k]; // = (n-k)!
    ll c = f[k]; // = k!

    ll bc = (b * c) % mod;

    return (a * inv(bc)) % mod;
}

あともう少しである。これで逆数は求められそうだが、O(mod)は計算できない。
そこで累乗計算の高速化でお馴染みである繰り返し二乗法を用いる。
 
繰り返し二乗法とは、指数部を2の累乗の和に分割することで高速に累乗計算を行う手法である。
例えば、x^180を計算したいとする。
180は4+16+32+128であるため、x^180=x^4*x^16*x^32*x^128と4回で計算ができる。
2の累乗であれば、自分の二乗をすることですばやく作ることができる。
x -> x^2 -> x^4 -> x^8 -> x^16 -> x^32 -> x^64 -> x^128
よって、x^180であれば、x^128まで作る7回と、その途中で該当する累乗を書ける4回の13回で計算が終わる。
わかりにくい場合は、ここここも参考にすると良い
yukicoderにverifyに向く問題がある
 
これを採用すると以下のようになる

ll mod = 1000000007;
//---------------------------------------------------------------------------------------------------
ll f[101010];
void init() {
    f[0] = 1;
    rep(i, 1, 101010) f[i] = (f[i - 1] * i) % mod;
}
//---------------------------------------------------------------------------------------------------
ll inv(ll x) {
    ll res = 1;
    ll k = mod - 2;
    ll y = x;
    while (k) {
        if (k & 1) res = (res * y) % mod;
        y = (y * y) % mod;
        k /= 2;
    }
    return res;
}
//---------------------------------------------------------------------------------------------------
ll C(int n, int k) {
    ll a = f[n]; // = n!
    ll b = f[n-k]; // = (n-k)!
    ll c = f[k]; // = k!

    ll bc = (b * c) % mod;

    return (a * inv(bc)) % mod;
}

これでO(logMOD)なので実用的になった

より高速化

逆数をメモ化しておくことで、前計算O(NlogMOD), クエリO(1)が実現できる

ll mod = 1000000007;
//---------------------------------------------------------------------------------------------------
ll f[101010], rf[101010];
ll inv(ll x) {
    ll res = 1;
    ll k = mod - 2;
    ll y = x;
    while (k) {
        if (k & 1) res = (res * y) % mod;
        y = (y * y) % mod;
        k /= 2;
    }
    return res;
}
void init() {
    f[0] = 1;
    rep(i, 1, 101010) f[i] = (f[i - 1] * i) % mod;
    rep(i, 0, 101010) rf[i] = inv(f[i]);
}
//---------------------------------------------------------------------------------------------------
ll C(int n, int k) {
    ll a = f[n]; // = n!
    ll b = rf[n-k]; // = (n-k)!
    ll c = rf[k]; // = k!

    ll bc = (b * c) % mod;

    return (a * bc) % mod;
}

よく使うのでライブラリとしておくことをおすすめする

前計算をもっと高速化する(krotonさん)

ご指摘いただいた、より高速に前計算する方法を紹介する。
この方法は「(x!)^(-1) * x=((x-1)!)^(-1)」であることを利用する。
O(NlogMOD)のlogMODは逆数を計算するときにかかる計算量であるが、これを1回にすることができる。
具体的には
1. (N!)^(-1)を求める O(logMOD)
2. (N!)^(-1) * Nをして((N - 1)!)^(-1)を求める O(1)
3. ((N - 1)!)^(-1) * (N - 1)をして((N - 2)!)^(-1)を求める O(1)
4. これを繰り返して全ての階乗の逆数を求める
これで計算量はO(N+logMOD)となる。

void init() {
    f[0] = 1;
    rep(i, 1, 101010) f[i] = (f[i - 1] * i) % mod;
    rf[101010 - 1] = inv(f[101010 - 1]);
    rrep(i, 101010 - 2, 0) rf[i] = (rf[i + 1] * (i + 1)) % mod;
}

補足1:nPkとnHkの実装

ll P(int n, int k) {
    ll a = f[n]; // = n!
    ll b = rf[n - k]; // = (n-k)!
    return (a * b) % mod;
}
ll H(int n, int k) {
    return C(n + k - 1, k);
}

補足2: 入力チェック

この問題を通すには入力チェックを適切に行う必要がある。

ll C(int n, int k) {
    if (k < 0 || n < k) return 0;

    ll a = f[n]; // = n!
    ll b = rf[n - k]; // = (n-k)!
    ll c = rf[k]; // = k!

    ll bc = (b * c) % mod;

    return (a * bc) % mod;
}
ll P(int n, int k) {
    if (k < 0 || n < k) return 0;

    ll a = f[n]; // = n!
    ll b = rf[n - k]; // = (n-k)!
    return (a * b) % mod;
}
ll H(int n, int k) {
    if (n == 0 && k == 0) return 1;
    return C(n + k - 1, k);
}

問題