serna37's Library

Logo

C++ アルゴリズムとデータ構造のライブラリ

View the Project on GitHub serna37/library-cpp

:heavy_check_mark: Aho Corasick
(library/string/aho_corasick.hpp)

Aho Corasick

できること

オートマトンとは

5つの組からなる。 $M = (Q, \sum, \delta, q_0, F)$

$M$ 自動改札なら
$Q$ { 閉鎖, 開放 }
$\sum$ { ICカードタッチ, 通り抜ける }
$\delta$ { (閉でタッチ) => 開にする, (閉で通り抜け) => 閉のまま, etc… }
$q_0$ 閉鎖
$F$ { 閉鎖 }

計算量

追加した文字列の長さの総和を $L$ 、文字種数を $\sum$ とする

使い方

// char_size=26, margin='a' と仮定
AhoCorasick<26, 'a'> ac;
vector<string> patterns = {"hers", "she", "he", "his"};
// Trie木に追加してからビルド
for (int i = 0; i < (int)patterns.size(); ++i) ac.add(patterns[i], i);
ac.build();

string text = "ushershehis";

// 文字列中の全ヒット数を数える
long long total = ac.count_all(text); // 6 (she, he, hers, she, he, his)

// 1文字ずつ遷移する
int now = 0;
for (char c : text) {
    now = ac.next(now, c);
    int hits = ac.count[now]; // その地点で終わるパターンの数
}

Depends on

Required by

Verified with

Code

#pragma once
#include "library/string/trie.hpp"
template <int char_size, int margin>
struct AhoCorasick : Trie<char_size, margin> {
    using Trie<char_size, margin>::nodes;
    vector<int> failure; // 失敗リンク
    vector<int> count;   // そのノードでマッチするパターンの総数
    AhoCorasick() : Trie<char_size, margin>() {}
    // 失敗リンクの構築と遷移関数の最適化
    void build() {
        int n = (int)nodes.size();
        failure.assign(n, 0);
        count.assign(n, 0);
        for (int i = 0; i < n; i++) {
            count[i] = (int)nodes[i].accept.size();
        }
        queue<int> que;
        for (int i = 0; i < char_size; i++) {
            if (nodes[0].nxt[i] != -1) {
                que.push(nodes[0].nxt[i]);
            } else {
                nodes[0].nxt[i] = 0;
            }
        }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            for (int i = 0; i < char_size; i++) {
                int &next_node = nodes[now].nxt[i];
                int fail_link = nodes[failure[now]].nxt[i];
                if (next_node != -1) {
                    failure[next_node] = fail_link;
                    count[next_node] += count[fail_link];
                    que.push(next_node);
                } else {
                    next_node = fail_link;
                }
            }
        }
    }
    // 次の状態へ遷移
    int next(int now, char c) const { return nodes[now].nxt[c - margin]; }
    // 文字列全体を走査して総ヒット数を返す
    long long count_all(const string &s) const {
        long long res = 0;
        int now = 0;
        for (char c : s) {
            now = next(now, c);
            res += count[now];
        }
        return res;
    }
};
#line 2 "library/string/trie.hpp"
template <int char_size, int margin> struct Trie {
    struct Node {
        vector<int> nxt;
        vector<int> accept; // その地点で終わる単語のIDリスト
        int exist;          // その地点を接頭辞として持つ単語の数
        Node() : nxt(char_size, -1), exist(0) {}
    };
    vector<Node> nodes;
    Trie() { nodes.emplace_back(); }
    int size() const { return (int)nodes.size(); }
    // 単語の追加
    virtual void add(const string &s, int id = -1) {
        int now = 0;
        for (char c : s) {
            int x = c - margin;
            if (nodes[now].nxt[x] == -1) {
                nodes[now].nxt[x] = (int)nodes.size();
                nodes.emplace_back();
            }
            now = nodes[now].nxt[x];
            nodes[now].exist++;
        }
        if (id != -1) nodes[now].accept.push_back(id);
    }
    // 単一の単語の検索 (完全一致)
    bool search(const string &s) const {
        int now = 0;
        for (char c : s) {
            int x = c - margin;
            if (nodes[now].nxt[x] == -1) return false;
            now = nodes[now].nxt[x];
        }
        return !nodes[now].accept.empty();
    }
    // 接頭辞検索:s を接頭辞として持つ単語の数を返す
    int count_prefix(const string &s) const {
        int now = 0;
        for (char c : s) {
            int x = c - margin;
            if (nodes[now].nxt[x] == -1) return 0;
            now = nodes[now].nxt[x];
        }
        return nodes[now].exist;
    }
};
#line 3 "library/string/aho_corasick.hpp"
template <int char_size, int margin>
struct AhoCorasick : Trie<char_size, margin> {
    using Trie<char_size, margin>::nodes;
    vector<int> failure; // 失敗リンク
    vector<int> count;   // そのノードでマッチするパターンの総数
    AhoCorasick() : Trie<char_size, margin>() {}
    // 失敗リンクの構築と遷移関数の最適化
    void build() {
        int n = (int)nodes.size();
        failure.assign(n, 0);
        count.assign(n, 0);
        for (int i = 0; i < n; i++) {
            count[i] = (int)nodes[i].accept.size();
        }
        queue<int> que;
        for (int i = 0; i < char_size; i++) {
            if (nodes[0].nxt[i] != -1) {
                que.push(nodes[0].nxt[i]);
            } else {
                nodes[0].nxt[i] = 0;
            }
        }
        while (!que.empty()) {
            int now = que.front();
            que.pop();
            for (int i = 0; i < char_size; i++) {
                int &next_node = nodes[now].nxt[i];
                int fail_link = nodes[failure[now]].nxt[i];
                if (next_node != -1) {
                    failure[next_node] = fail_link;
                    count[next_node] += count[fail_link];
                    que.push(next_node);
                } else {
                    next_node = fail_link;
                }
            }
        }
    }
    // 次の状態へ遷移
    int next(int now, char c) const { return nodes[now].nxt[c - margin]; }
    // 文字列全体を走査して総ヒット数を返す
    long long count_all(const string &s) const {
        long long res = 0;
        int now = 0;
        for (char c : s) {
            now = next(now, c);
            res += count[now];
        }
        return res;
    }
};
Back to top page