C++ アルゴリズムとデータ構造のライブラリ
#include "library/sequence/wavelet_matrix.hpp"access(k): $O(logV)$ 元配列のk番目の値を取得rank(x, r): $O(logV)$ 区間[0, r)に含まれる値xの個数kth_smallest(l, r, k): $O(logV)$ 区間[l, r)内でk番目(0-indexed)に小さい値kth_largest(l, r, k): $O(logV)$ 区間[l, r)内でk番目(0-indexed)に大きい値range_freq(l, r, lower, upper): $O(logV)$ 区間[l, r)内で値が $lower \le x \lt upper$ なものの個数prev_value(l, r, upper): $O(logV)$ 区間[l, r)内でupperより小さいものの中の最大値next_value(l, r, lower): $O(logV)$ 区間[l, r)内でlowerより大きいものの中の最小値※ $V$ は扱う値のビット数
vector<int> a = {5, 2, 8, 5, 1, 3};
WaveletMatrix wm(a);
// 0〜4番目の範囲で、小さい方から0番目(最小値)を取得
int val = wm.kth_smallest(0, 4, 0); // 2
// 0〜6番目の範囲で、3以上6未満の値が何個あるか
int count = wm.range_freq(0, 6, 3, 6); // 5, 5, 3 の3個
#pragma once
#include "library/sequence/bit_dict.hpp"
struct WaveletMatrix {
int n;
int max_log;
vector<BitDict> matrix;
vector<int> mid_points; // 各段での 0 と 1 の境界線 (0の個数)
// 構築: O(N log V)
WaveletMatrix(vector<long long> v, int max_log = 32)
: n(v.size()), max_log(max_log) {
matrix.assign(max_log, BitDict(n));
mid_points.resize(max_log);
vector<long long> left(n), right(n);
for (int d = max_log - 1; d >= 0; d--) {
vector<long long> l_vals, r_vals;
for (int i = 0; i < n; i++) {
if ((v[i] >> d) & 1) {
matrix[d].set(i);
r_vals.push_back(v[i]);
} else {
l_vals.push_back(v[i]);
}
}
matrix[d].build();
mid_points[d] = l_vals.size();
// v を次の段のために並び替える (0を前、1を後に集める)
v = l_vals;
v.insert(v.end(), r_vals.begin(), r_vals.end());
}
}
// k番目の値を返す: O(log V)
long long access(int k) const {
long long res = 0;
for (int d = max_log - 1; d >= 0; d--) {
bool bit = matrix[d].access(k);
if (bit) {
res |= (1LL << d);
k = mid_points[d] + matrix[d].rank1(k);
} else {
k = matrix[d].rank0(k);
}
}
return res;
}
// [0, r) に含まれる x の個数: O(log V)
int rank(long long x, int r) const {
int l = 0;
for (int d = max_log - 1; d >= 0; d--) {
bool bit = (x >> d) & 1;
if (bit) {
l = mid_points[d] + matrix[d].rank1(l);
r = mid_points[d] + matrix[d].rank1(r);
} else {
l = matrix[d].rank0(l);
r = matrix[d].rank0(r);
}
}
return r - l;
}
// [l, r) 内で k 番目に小さい値: O(log V)
long long kth_smallest(int l, int r, int k) const {
long long res = 0;
for (int d = max_log - 1; d >= 0; d--) {
int cnt0 = matrix[d].rank0(r) - matrix[d].rank0(l);
if (k < cnt0) {
l = matrix[d].rank0(l);
r = matrix[d].rank0(r);
} else {
res |= (1LL << d);
k -= cnt0;
l = mid_points[d] + matrix[d].rank1(l);
r = mid_points[d] + matrix[d].rank1(r);
}
}
return res;
}
// [l, r) 内で k 番目に大きい値: O(log V)
long long kth_largest(int l, int r, int k) const {
return kth_smallest(l, r, (r - l) - 1 - k);
}
// [l, r) 内で [lower, upper) に含まれる要素数: O(log V)
int range_freq(int l, int r, long long lower, long long upper) const {
return count_less(l, r, upper) - count_less(l, r, lower);
}
// [l, r) 内で val 未満の要素数 (内部用補助関数)
int count_less(int l, int r, long long val) const {
int res = 0;
for (int d = max_log - 1; d >= 0; d--) {
bool bit = (val >> d) & 1;
int cnt0 = matrix[d].rank0(r) - matrix[d].rank0(l);
if (bit) {
res += cnt0; // 0のビットを持つものは確実に val より小さい
l = mid_points[d] + matrix[d].rank1(l);
r = mid_points[d] + matrix[d].rank1(r);
} else {
l = matrix[d].rank0(l);
r = matrix[d].rank0(r);
}
}
return res;
}
// [l, r) 内で upper より小さい最大値
long long prev_value(int l, int r, long long upper) const {
int cnt = count_less(l, r, upper);
return (cnt == 0) ? -1 : kth_smallest(l, r, cnt - 1);
}
// [l, r) 内で lower 以上の最小値
long long next_value(int l, int r, long long lower) const {
int cnt = count_less(l, r, lower);
return (cnt == (r - l)) ? -1 : kth_smallest(l, r, cnt);
}
};#line 2 "library/sequence/bit_dict.hpp"
struct BitDict {
using uint = uint64_t;
int n;
vector<uint> bit; // ビット列本体
vector<int> sum; // 累積和(各ワード開始時点での1の総数)
BitDict() {} // 空のコンストラクタ(ウェーブレット行列のvector確保用)
// 64ビット単位で格納するため、(n/64)+1 個の要素を確保
BitDict(int n) : n(n) { // n は扱うビット列の長さ(最大インデックス + 1)
bit.assign((n >> 6) + 1, 0);
}
// k番目のビットを1にする
void set(int k) { bit[k >> 6] |= (1ULL << (k & 63)); }
// 累積和を構築する(setの後に必ず呼ぶ)
void build() {
sum.assign(bit.size() + 1, 0);
for (int i = 0; i < (int)bit.size(); i++) {
sum[i + 1] = sum[i] + __builtin_popcountll(bit[i]);
}
}
// k番目のビットを取得
bool access(int k) const { return (bit[k >> 6] >> (k & 63)) & 1; }
// [0, k) 内の 1 の個数
int rank1(int k) const {
int idx = k >> 6;
int offset = k & 63;
uint mask = (1ULL << offset) - 1;
return sum[idx] + __builtin_popcountll(bit[idx] & mask);
}
// [0, k) 内の 0 の個数(ウェーブレット行列で多用する)
int rank0(int k) const { return k - rank1(k); }
// j番目(1-indexed)の1の位置: O(log N)
int select(int j) const {
if (j <= 0 || j > sum.back()) return -1;
int left = 0, right = n;
while (right - left > 1) {
int mid = (left + right) / 2;
if (rank1(mid) >= j)
right = mid;
else
left = mid;
}
return left;
}
};
#line 3 "library/sequence/wavelet_matrix.hpp"
struct WaveletMatrix {
int n;
int max_log;
vector<BitDict> matrix;
vector<int> mid_points; // 各段での 0 と 1 の境界線 (0の個数)
// 構築: O(N log V)
WaveletMatrix(vector<long long> v, int max_log = 32)
: n(v.size()), max_log(max_log) {
matrix.assign(max_log, BitDict(n));
mid_points.resize(max_log);
vector<long long> left(n), right(n);
for (int d = max_log - 1; d >= 0; d--) {
vector<long long> l_vals, r_vals;
for (int i = 0; i < n; i++) {
if ((v[i] >> d) & 1) {
matrix[d].set(i);
r_vals.push_back(v[i]);
} else {
l_vals.push_back(v[i]);
}
}
matrix[d].build();
mid_points[d] = l_vals.size();
// v を次の段のために並び替える (0を前、1を後に集める)
v = l_vals;
v.insert(v.end(), r_vals.begin(), r_vals.end());
}
}
// k番目の値を返す: O(log V)
long long access(int k) const {
long long res = 0;
for (int d = max_log - 1; d >= 0; d--) {
bool bit = matrix[d].access(k);
if (bit) {
res |= (1LL << d);
k = mid_points[d] + matrix[d].rank1(k);
} else {
k = matrix[d].rank0(k);
}
}
return res;
}
// [0, r) に含まれる x の個数: O(log V)
int rank(long long x, int r) const {
int l = 0;
for (int d = max_log - 1; d >= 0; d--) {
bool bit = (x >> d) & 1;
if (bit) {
l = mid_points[d] + matrix[d].rank1(l);
r = mid_points[d] + matrix[d].rank1(r);
} else {
l = matrix[d].rank0(l);
r = matrix[d].rank0(r);
}
}
return r - l;
}
// [l, r) 内で k 番目に小さい値: O(log V)
long long kth_smallest(int l, int r, int k) const {
long long res = 0;
for (int d = max_log - 1; d >= 0; d--) {
int cnt0 = matrix[d].rank0(r) - matrix[d].rank0(l);
if (k < cnt0) {
l = matrix[d].rank0(l);
r = matrix[d].rank0(r);
} else {
res |= (1LL << d);
k -= cnt0;
l = mid_points[d] + matrix[d].rank1(l);
r = mid_points[d] + matrix[d].rank1(r);
}
}
return res;
}
// [l, r) 内で k 番目に大きい値: O(log V)
long long kth_largest(int l, int r, int k) const {
return kth_smallest(l, r, (r - l) - 1 - k);
}
// [l, r) 内で [lower, upper) に含まれる要素数: O(log V)
int range_freq(int l, int r, long long lower, long long upper) const {
return count_less(l, r, upper) - count_less(l, r, lower);
}
// [l, r) 内で val 未満の要素数 (内部用補助関数)
int count_less(int l, int r, long long val) const {
int res = 0;
for (int d = max_log - 1; d >= 0; d--) {
bool bit = (val >> d) & 1;
int cnt0 = matrix[d].rank0(r) - matrix[d].rank0(l);
if (bit) {
res += cnt0; // 0のビットを持つものは確実に val より小さい
l = mid_points[d] + matrix[d].rank1(l);
r = mid_points[d] + matrix[d].rank1(r);
} else {
l = matrix[d].rank0(l);
r = matrix[d].rank0(r);
}
}
return res;
}
// [l, r) 内で upper より小さい最大値
long long prev_value(int l, int r, long long upper) const {
int cnt = count_less(l, r, upper);
return (cnt == 0) ? -1 : kth_smallest(l, r, cnt - 1);
}
// [l, r) 内で lower 以上の最小値
long long next_value(int l, int r, long long lower) const {
int cnt = count_less(l, r, lower);
return (cnt == (r - l)) ? -1 : kth_smallest(l, r, cnt);
}
};