serna37's Library

Logo

C++ アルゴリズムとデータ構造のライブラリ

View the Project on GitHub serna37/library-cpp

:heavy_check_mark: Wavelet Search Engine
(library/search/binary_search/wavelet_search_engine.hpp)

Wavelet Search Engine

ウェーブレット行列 + 座標圧縮での、各種問い合わせ関数を整備したもの

できること

計算量

$N$ を配列サイズ、$V$ を配列に含まれるユニークな値の種類数とします。

使い方

構築

vector<long long> A = {10, 40, 20, 10, 30};
WaveletSearchEngine engine(A);

// 区間 [1, 4) すなわち {40, 20, 10} が対象
int l = 1, r = 4;

// 個数を数える (count系)
engine.count_ge(l, r, 20); // 20以上の個数 -> 2 (40, 20)
engine.range_freq(l, r, 15, 35); // 15以上35未満の個数 -> 1 (20)

// 値を探す (find系)
// 見つからない場合は構築時に設定した INF_VAL または -INF_VAL が返ります
// engine.INF_VAL
// と比較すること
engine.find_lt(l, r, 25); // 25未満で最大の要素 -> 20
engine.find_ge(l, r, 20); // 20以上で最小の要素 -> 20

// K番目の値 (0-indexed)
engine.kth_smallest(l, r, 0); // 区間最小値 -> 10
engine.kth_largest(l, r, 0);  // 区間最大値 -> 40

Depends on

Verified with

Code

#pragma once
#include "library/sequence/compressor.hpp"
#include "library/sequence/wavelet_matrix.hpp"
struct WaveletSearchEngine {
    int n;
    Compressor<long long> cp;
    WaveletMatrix wm;
    const long long INF_VAL = 1e18;
    WaveletSearchEngine(const vector<long long> &v)
        : n(v.size()), cp(v), wm({}, 0) {
        vector<int> compressed_all = cp.get_all();
        vector<long long> wm_input(compressed_all.begin(),
                                   compressed_all.end());
        int max_log = 0;
        while ((1LL << max_log) <= (int)cp.size()) max_log++;
        if (max_log == 0) max_log = 1; // 要素が空または1種類の場合
        wm = WaveletMatrix(wm_input, max_log);
    }
    // --- 基本操作 ---
    // i 番目の元の値を返す
    long long access(int i) const { return cp.get_val(wm.access(i)); }
    // [l, r) 内に x が何個含まれるか
    int count_x(int l, int r, long long x) const {
        int id = cp.get_id(x);
        if (id >= cp.size() || cp.get_val(id) != x) return 0;
        return wm.rank(id, r) - wm.rank(id, l);
    }
    // --- 二分探索系統合インターフェース (bi_..._cnt) ---
    // [l, r) 内で x 未満 (Less Than) の個数
    int count_lt(int l, int r, long long x) const {
        return wm.count_less(l, r, cp.get_id(x));
    }
    // [l, r) 内で x 以下 (Less Equal) の個数
    int count_le(int l, int r, long long x) const {
        return wm.count_less(l, r, cp.get_upper_id(x));
    }
    // [l, r) 内で x 以上 (Greater Equal) の個数
    int count_ge(int l, int r, long long x) const {
        return (r - l) - count_lt(l, r, x);
    }
    // [l, r) 内で x より大きい (Greater Than) の個数
    int count_gt(int l, int r, long long x) const {
        return (r - l) - count_le(l, r, x);
    }
    // --- 値検索系統合インターフェース (bi_..._val) ---
    // [l, r) 内で k 番目に小さい値 (0-indexed)
    long long kth_smallest(int l, int r, int k) const {
        if (k < 0 || k >= (r - l)) return INF_VAL;
        return cp.get_val(wm.kth_smallest(l, r, k));
    }
    // [l, r) 内で k 番目に大きい値 (0-indexed)
    long long kth_largest(int l, int r, int k) const {
        if (k < 0 || k >= (r - l)) return -INF_VAL;
        return cp.get_val(wm.kth_largest(l, r, k));
    }
    // [l, r) 内で x 未満の最大値 (Less Than Value)
    long long find_lt(int l, int r, long long x) const {
        int cnt = count_lt(l, r, x);
        return (cnt == 0) ? -INF_VAL : kth_smallest(l, r, cnt - 1);
    }
    // [l, r) 内で x 以下の最大値 (Less Equal Value)
    long long find_le(int l, int r, long long x) const {
        int cnt = count_le(l, r, x);
        return (cnt == 0) ? -INF_VAL : kth_smallest(l, r, cnt - 1);
    }
    // [l, r) 内で x 以上の最小値 (Greater Equal Value)
    long long find_ge(int l, int r, long long x) const {
        int cnt = count_lt(l, r, x);
        return (cnt == (r - l)) ? INF_VAL : kth_smallest(l, r, cnt);
    }
    // [l, r) 内で x より大きい最小値 (Greater Than Value)
    long long find_gt(int l, int r, long long x) const {
        int cnt = count_le(l, r, x);
        return (cnt == (r - l)) ? INF_VAL : kth_smallest(l, r, cnt);
    }
    // --- 応用クエリ ---
    // [l, r) 内で [lower, upper) に含まれる要素数
    int range_freq(int l, int r, long long lower, long long upper) const {
        if (lower >= upper) return 0;
        return wm.count_less(l, r, cp.get_id(upper)) -
               wm.count_less(l, r, cp.get_id(lower));
    }
};
#line 2 "library/sequence/compressor.hpp"
template <typename T> struct Compressor {
    vector<T> origin, dict;
    Compressor(const vector<T> &v) : origin(v), dict(v) {
        sort(dict.begin(), dict.end());
        dict.erase(unique(dict.begin(), dict.end()), dict.end());
    }
    int size() const { return dict.size(); }
    // 値 -> ID (圧縮)
    int get_id(T x) const {
        return lower_bound(dict.begin(), dict.end(), x) - dict.begin();
    }
    // 値 -> ID (upper_bound版)
    int get_upper_id(T x) const {
        return upper_bound(dict.begin(), dict.end(), x) - dict.begin();
    }
    // ID -> 値 (復元)
    T get_val(int id) const { return dict[id]; }
    // すべて圧縮
    vector<int> get_all() {
        vector<int> res;
        for (auto &&x : origin) res.emplace_back(get_id(x));
        return res;
    }
};
#line 2 "library/sequence/bit_dict.hpp"
struct BitDict {
    using uint = uint64_t;
    int n;
    vector<uint> bit; // ビット列本体
    vector<int> sum;  // 累積和(各ワード開始時点での1の総数)
    BitDict() {}      // 空のコンストラクタ(ウェーブレット行列のvector確保用)
    // 64ビット単位で格納するため、(n/64)+1 個の要素を確保
    BitDict(int n) : n(n) { // n は扱うビット列の長さ(最大インデックス + 1)
        bit.assign((n >> 6) + 1, 0);
    }
    // k番目のビットを1にする
    void set(int k) { bit[k >> 6] |= (1ULL << (k & 63)); }
    // 累積和を構築する(setの後に必ず呼ぶ)
    void build() {
        sum.assign(bit.size() + 1, 0);
        for (int i = 0; i < (int)bit.size(); i++) {
            sum[i + 1] = sum[i] + __builtin_popcountll(bit[i]);
        }
    }
    // k番目のビットを取得
    bool access(int k) const { return (bit[k >> 6] >> (k & 63)) & 1; }
    // [0, k) 内の 1 の個数
    int rank1(int k) const {
        int idx = k >> 6;
        int offset = k & 63;
        uint mask = (1ULL << offset) - 1;
        return sum[idx] + __builtin_popcountll(bit[idx] & mask);
    }
    // [0, k) 内の 0 の個数(ウェーブレット行列で多用する)
    int rank0(int k) const { return k - rank1(k); }
    // j番目(1-indexed)の1の位置: O(log N)
    int select(int j) const {
        if (j <= 0 || j > sum.back()) return -1;
        int left = 0, right = n;
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (rank1(mid) >= j)
                right = mid;
            else
                left = mid;
        }
        return left;
    }
};
#line 3 "library/sequence/wavelet_matrix.hpp"
struct WaveletMatrix {
    int n;
    int max_log;
    vector<BitDict> matrix;
    vector<int> mid_points; // 各段での 0 と 1 の境界線 (0の個数)
    // 構築: O(N log V)
    WaveletMatrix(vector<long long> v, int max_log = 32)
        : n(v.size()), max_log(max_log) {
        matrix.assign(max_log, BitDict(n));
        mid_points.resize(max_log);
        vector<long long> left(n), right(n);
        for (int d = max_log - 1; d >= 0; d--) {
            vector<long long> l_vals, r_vals;
            for (int i = 0; i < n; i++) {
                if ((v[i] >> d) & 1) {
                    matrix[d].set(i);
                    r_vals.push_back(v[i]);
                } else {
                    l_vals.push_back(v[i]);
                }
            }
            matrix[d].build();
            mid_points[d] = l_vals.size();
            // v を次の段のために並び替える (0を前、1を後に集める)
            v = l_vals;
            v.insert(v.end(), r_vals.begin(), r_vals.end());
        }
    }
    // k番目の値を返す: O(log V)
    long long access(int k) const {
        long long res = 0;
        for (int d = max_log - 1; d >= 0; d--) {
            bool bit = matrix[d].access(k);
            if (bit) {
                res |= (1LL << d);
                k = mid_points[d] + matrix[d].rank1(k);
            } else {
                k = matrix[d].rank0(k);
            }
        }
        return res;
    }
    // [0, r) に含まれる x の個数: O(log V)
    int rank(long long x, int r) const {
        int l = 0;
        for (int d = max_log - 1; d >= 0; d--) {
            bool bit = (x >> d) & 1;
            if (bit) {
                l = mid_points[d] + matrix[d].rank1(l);
                r = mid_points[d] + matrix[d].rank1(r);
            } else {
                l = matrix[d].rank0(l);
                r = matrix[d].rank0(r);
            }
        }
        return r - l;
    }
    // [l, r) 内で k 番目に小さい値: O(log V)
    long long kth_smallest(int l, int r, int k) const {
        long long res = 0;
        for (int d = max_log - 1; d >= 0; d--) {
            int cnt0 = matrix[d].rank0(r) - matrix[d].rank0(l);
            if (k < cnt0) {
                l = matrix[d].rank0(l);
                r = matrix[d].rank0(r);
            } else {
                res |= (1LL << d);
                k -= cnt0;
                l = mid_points[d] + matrix[d].rank1(l);
                r = mid_points[d] + matrix[d].rank1(r);
            }
        }
        return res;
    }
    // [l, r) 内で k 番目に大きい値: O(log V)
    long long kth_largest(int l, int r, int k) const {
        return kth_smallest(l, r, (r - l) - 1 - k);
    }
    // [l, r) 内で [lower, upper) に含まれる要素数: O(log V)
    int range_freq(int l, int r, long long lower, long long upper) const {
        return count_less(l, r, upper) - count_less(l, r, lower);
    }
    // [l, r) 内で val 未満の要素数 (内部用補助関数)
    int count_less(int l, int r, long long val) const {
        int res = 0;
        for (int d = max_log - 1; d >= 0; d--) {
            bool bit = (val >> d) & 1;
            int cnt0 = matrix[d].rank0(r) - matrix[d].rank0(l);
            if (bit) {
                res += cnt0; // 0のビットを持つものは確実に val より小さい
                l = mid_points[d] + matrix[d].rank1(l);
                r = mid_points[d] + matrix[d].rank1(r);
            } else {
                l = matrix[d].rank0(l);
                r = matrix[d].rank0(r);
            }
        }
        return res;
    }
    // [l, r) 内で upper より小さい最大値
    long long prev_value(int l, int r, long long upper) const {
        int cnt = count_less(l, r, upper);
        return (cnt == 0) ? -1 : kth_smallest(l, r, cnt - 1);
    }
    // [l, r) 内で lower 以上の最小値
    long long next_value(int l, int r, long long lower) const {
        int cnt = count_less(l, r, lower);
        return (cnt == (r - l)) ? -1 : kth_smallest(l, r, cnt);
    }
};
#line 4 "library/search/binary_search/wavelet_search_engine.hpp"
struct WaveletSearchEngine {
    int n;
    Compressor<long long> cp;
    WaveletMatrix wm;
    const long long INF_VAL = 1e18;
    WaveletSearchEngine(const vector<long long> &v)
        : n(v.size()), cp(v), wm({}, 0) {
        vector<int> compressed_all = cp.get_all();
        vector<long long> wm_input(compressed_all.begin(),
                                   compressed_all.end());
        int max_log = 0;
        while ((1LL << max_log) <= (int)cp.size()) max_log++;
        if (max_log == 0) max_log = 1; // 要素が空または1種類の場合
        wm = WaveletMatrix(wm_input, max_log);
    }
    // --- 基本操作 ---
    // i 番目の元の値を返す
    long long access(int i) const { return cp.get_val(wm.access(i)); }
    // [l, r) 内に x が何個含まれるか
    int count_x(int l, int r, long long x) const {
        int id = cp.get_id(x);
        if (id >= cp.size() || cp.get_val(id) != x) return 0;
        return wm.rank(id, r) - wm.rank(id, l);
    }
    // --- 二分探索系統合インターフェース (bi_..._cnt) ---
    // [l, r) 内で x 未満 (Less Than) の個数
    int count_lt(int l, int r, long long x) const {
        return wm.count_less(l, r, cp.get_id(x));
    }
    // [l, r) 内で x 以下 (Less Equal) の個数
    int count_le(int l, int r, long long x) const {
        return wm.count_less(l, r, cp.get_upper_id(x));
    }
    // [l, r) 内で x 以上 (Greater Equal) の個数
    int count_ge(int l, int r, long long x) const {
        return (r - l) - count_lt(l, r, x);
    }
    // [l, r) 内で x より大きい (Greater Than) の個数
    int count_gt(int l, int r, long long x) const {
        return (r - l) - count_le(l, r, x);
    }
    // --- 値検索系統合インターフェース (bi_..._val) ---
    // [l, r) 内で k 番目に小さい値 (0-indexed)
    long long kth_smallest(int l, int r, int k) const {
        if (k < 0 || k >= (r - l)) return INF_VAL;
        return cp.get_val(wm.kth_smallest(l, r, k));
    }
    // [l, r) 内で k 番目に大きい値 (0-indexed)
    long long kth_largest(int l, int r, int k) const {
        if (k < 0 || k >= (r - l)) return -INF_VAL;
        return cp.get_val(wm.kth_largest(l, r, k));
    }
    // [l, r) 内で x 未満の最大値 (Less Than Value)
    long long find_lt(int l, int r, long long x) const {
        int cnt = count_lt(l, r, x);
        return (cnt == 0) ? -INF_VAL : kth_smallest(l, r, cnt - 1);
    }
    // [l, r) 内で x 以下の最大値 (Less Equal Value)
    long long find_le(int l, int r, long long x) const {
        int cnt = count_le(l, r, x);
        return (cnt == 0) ? -INF_VAL : kth_smallest(l, r, cnt - 1);
    }
    // [l, r) 内で x 以上の最小値 (Greater Equal Value)
    long long find_ge(int l, int r, long long x) const {
        int cnt = count_lt(l, r, x);
        return (cnt == (r - l)) ? INF_VAL : kth_smallest(l, r, cnt);
    }
    // [l, r) 内で x より大きい最小値 (Greater Than Value)
    long long find_gt(int l, int r, long long x) const {
        int cnt = count_le(l, r, x);
        return (cnt == (r - l)) ? INF_VAL : kth_smallest(l, r, cnt);
    }
    // --- 応用クエリ ---
    // [l, r) 内で [lower, upper) に含まれる要素数
    int range_freq(int l, int r, long long lower, long long upper) const {
        if (lower >= upper) return 0;
        return wm.count_less(l, r, cp.get_id(upper)) -
               wm.count_less(l, r, cp.get_id(lower));
    }
};
Back to top page