はまやんはまやんはまやん

hamayanhamayan's blog

Xor-sequences [Codeforces 教育 14 : E]

問題

http://codeforces.com/contest/691/problem/E

n 個の数 a1~an が与えられる。
ここから重複を許して、k 個の数列を作る。
隣り合う2つの数の排他的論理和を2進数にしたときの1の個数が全て3の倍数となる数列を「xor-sequence」と呼ぶ。
数列を作るパターンは全部で n^k 個あるが、xor-sequenceである数列の場合の数を数えよ。
大きくなるので、10^9+7を法とする。

1 <= n <= 100
1 <= k <= 10^18
0 <= ai <= 10^18

考察

1. 組合せの個数なので…

  • 計算で求める系かな?
  • DPかな?

2. aCbなどでは無理そう。DPを作る方針で考えてみる
3. 今回は先頭から数を決定していけば、状態をまとめられそう
4. 普通のDPを考えてみると以下のように作れそう

dp[i][j] = 数列の1~i番目が xor-sequence であり、その数列の最後がj番目の数である場合の数
更新式
dp[i + 1][j] = dp[i][j1] + dp[i][j2] + ... + dp[i][jm]
j1~jmは、j番目とj1~jm番目の排他的論理和を取ると、1の個数が3の倍数になる要素を指す

5. しかし、これでは、dp[10^18][100]であり、計算ができない!どうしよう!
6. 「行列累乗」が使えます

行列累乗とは、高速累乗の行列版で2乗2乗しながら行列掛けてくやつです(詳しくは他サイトで)。
使う条件は、

  • 漸化式が作れる!
  • O(log n)

7. 最初の段階から10^18ステップまで計算する必要があるので、O(log n)解だろうなぁという感じはあった
8. O(log n)解は手法がそんなにないので、そこから考えると分かる(知ってれば)

O(log n)

9. まず数列の各組合せについて、横に置けるか判定して、置ければ1, 置けなければ0とします -> mat[i][j]
10. そうすると、dpの1回の計算は以下のように行える

| dp'[0]   |   | mat[0][0] mat[0][1] ... mat[0][n-1]   |   | dp[0]   |
| dp'[1]   |   | mat[1][0]                             |   | dp[1]   |
| dp'[2]   | = | mat[2][0]                             | * | dp[2]   |
| ...      |   | ...                                   |   | ...     |
| dp'[n-1] |   | mat[n-1][0] ...         mat[n-1][n-1] |   | dp[n-1] |

11. 数列がk個ある場合は、この処理を k-1 回繰り返せば良いので、行列 mat の k-1 乗を高速に計算して終了です

実装

http://codeforces.com/contest/691/submission/19099641

typedef long long ll;
typedef vector<ll> vi;
typedef vector<vi> vvi;
#define MOD 1000000007
vi mul_mat_vec(vvi a, vi b)
{
	int n = b.size();
	vi ret(n, 0);
	rep(i, 0, n)
	{
		ret[i] = 0;
		rep(j, 0, n) ret[i] = (ret[i] + a[i][j] * b[j]) % MOD;
	}
	return ret;
}

vvi mul_mat_mat(vvi a, vvi b)
{
	int n = a.size();
	vvi ret(n, vi(n, 0));
	rep(i, 0, n) rep(j, 0, n)
	{
		ret[i][j] = 0;
		rep(k, 0, n) ret[i][j] = (ret[i][j] + a[i][k] * b[k][j]) % MOD;
	}
	return ret;
}
vvi fastpow(vvi x, ll n)
{
	vvi ret(x.size(), vi(x.size(), 0));
	rep(i, 0, x.size()) ret[i][i] = 1;

	while (0 < n)
	{
		if ((n % 2) == 0)
		{
			x = mul_mat_mat(x, x);
			n >>= 1;
		}
		else
		{
			ret = mul_mat_mat(ret, x);
			--n;
		}
	}
	return ret;
}
//-----------------------------------------------------------------
int n;
ll k;
vi a;
//-----------------------------------------------------------------
int main() {
	while (cin >> n) {
		scanf("%I64d", &k);
		a.resize(n);
		rep(i, 0, n) scanf("%I64d", &a[i]);

		vvi mat(n, vi(n, 0));
		rep(i, 0, n) rep(j, i, n) {
			ll e = a[i] ^ a[j];
			int cnt = 0;
			while (0 < e) {
				if (e % 2 != 0) cnt++;
				e >>= 1;
			}
			if (cnt % 3 == 0) {
				mat[i][j] = mat[j][i] = 1;
			}
		}

		vi init(n, 1);
		vi ret = mul_mat_vec(fastpow(mat, k - 1), init);
		ll ans = 0;
		rep(i, 0, n) ans += ret[i];
		ans %= MOD;
		cout << ans << endl;
	}
}