serna37's Library

Logo

C++ アルゴリズムとデータ構造のライブラリ

View the Project on GitHub serna37/library-cpp

:heavy_check_mark: 文字列複数 統合検索エンジン
(library/string/finds.hpp)

文字列複数 統合検索エンジン

できること

計算量

以下3種類を自動で使い分けます

使い方

auto pos = finds(T, {S});
vector<int> idxs = pos[S];

auto pos = finds("abracadabra", {"abr", "ra", "a"});
// pos["abr"] == {0, 7}
// pos["ra"] == {2, 9}
// pos["a"] == {0, 3, 5, 7, 10}

Depends on

Verified with

Code

#pragma once
#include "library/string/rolling_hash.hpp"
#include "library/string/aho_corasick.hpp"
// 統合検索エンジン:戻り値を map<検索単語, 出現インデックスのリスト> で返却
map<string, vector<int>> finds(const string &T,
                               const vector<string> &patterns) {
    map<string, vector<int>> res_map;
    if (patterns.empty() || T.empty()) return res_map;
    // 1. パターンが1つだけで、かつ短い場合(ナイーブ)
    if (patterns.size() == 1 && patterns[0].size() < 10) {
        const string &S = patterns[0];
        size_t p = T.find(S);
        while (p != string::npos) {
            res_map[S].push_back((int)p);
            p = T.find(S, p + 1);
        }
        return res_map;
    }
    // 2. パターンが1つだけだが、長い場合(Rolling Hash)
    if (patterns.size() == 1) {
        vector<int> res;
        string P = patterns[0];
        int t = (int)T.size(), p = (int)P.size();
        if (p <= t) {
            RollingHash rht(T);
            RollingHash rhp(P);
            long long hash_p = rhp.get(0, p);
            for (int i = 0; i <= t - p; ++i) {
                if (rht.get(i, i + p) == hash_p) res.emplace_back(i);
            }
        }
        res_map[patterns[0]] = res;
        return res_map;
    }
    // 3. パターンが複数ある場合(Aho-Corasick)
    AhoCorasick<128, 0> ac;
    for (int i = 0; i < (int)patterns.size(); ++i) {
        if (patterns[i].empty()) continue;
        ac.add(patterns[i], i);
        res_map
            [patterns
                 [i]]; // ヒットしなかった単語もキーとして存在させる場合はここで初期化
    }
    ac.build();
    int now = 0;
    for (int i = 0; i < (int)T.size(); ++i) {
        now = ac.next(now, T[i]);
        int temp = now;
        while (temp > 0 && ac.count[temp] > 0) {
            for (int id : ac.nodes[temp].accept) {
                res_map[patterns[id]].push_back(i - (int)patterns[id].size() +
                                                1);
            }
            temp = ac.failure[temp];
        }
    }
    return res_map;
}
#line 2 "library/various/random.hpp"
#include <chrono>
#include <random>
inline long long random(long long a, long long b) {
    if (a >= b) return a;
    static mt19937 mt(chrono::steady_clock::now().time_since_epoch().count());
    uniform_int_distribution<long long> dist(a, b - 1);
    return dist(mt);
}
#line 3 "library/string/rolling_hash.hpp"
struct RollingHash {
    static const long long MOD = (1LL << 61) - 1;
    static inline long long base = 0;
    vector<long long> hash_sum;
    // 基数をメルセンヌツイスタの乱数で初期化する
    static void init_base() {
        if (base != 0) return;
        base = random(2, MOD - 1);
    }
    // 文字列からハッシュの累積和を構築する
    RollingHash(const string &s) {
        init_base();
        int n = s.size();
        hash_sum.assign(n + 1, 0);
        for (int i = 0; i < n; i++) {
            hash_sum[i + 1] = mul(hash_sum[i], base) + s[i];
            if (hash_sum[i + 1] >= MOD) hash_sum[i + 1] -= MOD;
        }
    }
    // 2^61-1 用の高速な掛け算
    static long long mul(long long a, long long b) {
        __int128_t res = (__int128_t)a * b;
        long long ans = (long long)(res >> 61) + (long long)(res & MOD);
        if (ans >= MOD) ans -= MOD;
        return ans;
    }
    // 累乗テーブルの管理
    static const vector<long long> &get_pow(int n) {
        static vector<long long> pow_memo = {1};
        while ((int)pow_memo.size() <= n) {
            pow_memo.push_back(mul(pow_memo.back(), base));
        }
        return pow_memo;
    }
    // s[l, r) のハッシュを取得
    long long get(int l, int r) const {
        long long res = hash_sum[r] - mul(hash_sum[l], get_pow(r - l)[r - l]);
        if (res < 0) res += MOD;
        return res;
    }
    // 2つのハッシュ(a, b)を結合する。bの長さが b_len
    static long long merge(long long a_hash, long long b_hash, int b_len) {
        long long res = mul(a_hash, get_pow(b_len)[b_len]) + b_hash;
        if (res >= MOD) res -= MOD;
        return res;
    }
};
#line 2 "library/string/trie.hpp"
template <int char_size, int margin> struct Trie {
    struct Node {
        vector<int> nxt;
        vector<int> accept; // その地点で終わる単語のIDリスト
        int exist;          // その地点を接頭辞として持つ単語の数
        Node() : nxt(char_size, -1), exist(0) {}
    };
    vector<Node> nodes;
    Trie() { nodes.emplace_back(); }
    int size() const { return (int)nodes.size(); }
    // 単語の追加
    virtual void add(const string &s, int id = -1) {
        int now = 0;
        for (char c : s) {
            int x = c - margin;
            if (nodes[now].nxt[x] == -1) {
                nodes[now].nxt[x] = (int)nodes.size();
                nodes.emplace_back();
            }
            now = nodes[now].nxt[x];
            nodes[now].exist++;
        }
        if (id != -1) nodes[now].accept.push_back(id);
    }
    // 単一の単語の検索 (完全一致)
    bool search(const string &s) const {
        int now = 0;
        for (char c : s) {
            int x = c - margin;
            if (nodes[now].nxt[x] == -1) return false;
            now = nodes[now].nxt[x];
        }
        return !nodes[now].accept.empty();
    }
    // 接頭辞検索:s を接頭辞として持つ単語の数を返す
    int count_prefix(const string &s) const {
        int now = 0;
        for (char c : s) {
            int x = c - margin;
            if (nodes[now].nxt[x] == -1) return 0;
            now = nodes[now].nxt[x];
        }
        return nodes[now].exist;
    }
};
#line 3 "library/string/aho_corasick.hpp"
template <int char_size, int margin>
struct AhoCorasick : Trie<char_size, margin> {
    using Trie<char_size, margin>::nodes;
    vector<int> failure; // 失敗リンク
    vector<int> count;   // そのノードでマッチするパターンの総数
    AhoCorasick() : Trie<char_size, margin>() {}
    // 失敗リンクの構築と遷移関数の最適化
    void build() {
        int n = (int)nodes.size();
        failure.assign(n, 0);
        count.assign(n, 0);
        for (int i = 0; i < n; i++) {
            count[i] = (int)nodes[i].accept.size();
        }
        queue<int> que;
        for (int i = 0; i < char_size; i++) {
            if (nodes[0].nxt[i] != -1) {
                que.push(nodes[0].nxt[i]);
            } else {
                nodes[0].nxt[i] = 0;
            }
        }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            for (int i = 0; i < char_size; i++) {
                int &next_node = nodes[now].nxt[i];
                int fail_link = nodes[failure[now]].nxt[i];
                if (next_node != -1) {
                    failure[next_node] = fail_link;
                    count[next_node] += count[fail_link];
                    que.push(next_node);
                } else {
                    next_node = fail_link;
                }
            }
        }
    }
    // 次の状態へ遷移
    int next(int now, char c) const { return nodes[now].nxt[c - margin]; }
    // 文字列全体を走査して総ヒット数を返す
    long long count_all(const string &s) const {
        long long res = 0;
        int now = 0;
        for (char c : s) {
            now = next(now, c);
            res += count[now];
        }
        return res;
    }
};
#line 4 "library/string/finds.hpp"
// 統合検索エンジン:戻り値を map<検索単語, 出現インデックスのリスト> で返却
map<string, vector<int>> finds(const string &T,
                               const vector<string> &patterns) {
    map<string, vector<int>> res_map;
    if (patterns.empty() || T.empty()) return res_map;
    // 1. パターンが1つだけで、かつ短い場合(ナイーブ)
    if (patterns.size() == 1 && patterns[0].size() < 10) {
        const string &S = patterns[0];
        size_t p = T.find(S);
        while (p != string::npos) {
            res_map[S].push_back((int)p);
            p = T.find(S, p + 1);
        }
        return res_map;
    }
    // 2. パターンが1つだけだが、長い場合(Rolling Hash)
    if (patterns.size() == 1) {
        vector<int> res;
        string P = patterns[0];
        int t = (int)T.size(), p = (int)P.size();
        if (p <= t) {
            RollingHash rht(T);
            RollingHash rhp(P);
            long long hash_p = rhp.get(0, p);
            for (int i = 0; i <= t - p; ++i) {
                if (rht.get(i, i + p) == hash_p) res.emplace_back(i);
            }
        }
        res_map[patterns[0]] = res;
        return res_map;
    }
    // 3. パターンが複数ある場合(Aho-Corasick)
    AhoCorasick<128, 0> ac;
    for (int i = 0; i < (int)patterns.size(); ++i) {
        if (patterns[i].empty()) continue;
        ac.add(patterns[i], i);
        res_map
            [patterns
                 [i]]; // ヒットしなかった単語もキーとして存在させる場合はここで初期化
    }
    ac.build();
    int now = 0;
    for (int i = 0; i < (int)T.size(); ++i) {
        now = ac.next(now, T[i]);
        int temp = now;
        while (temp > 0 && ac.count[temp] > 0) {
            for (int id : ac.nodes[temp].accept) {
                res_map[patterns[id]].push_back(i - (int)patterns[id].size() +
                                                1);
            }
            temp = ac.failure[temp];
        }
    }
    return res_map;
}
Back to top page