serna37's Library

Logo

C++ アルゴリズムとデータ構造のライブラリ

View the Project on GitHub serna37/library-cpp

✔ LCA
(library/graph/tree/lca.hpp)

LCA

できること

根付き木について、2頂点の最近共通祖先 (LCA) と、2頂点間のパス上の辺コストの和 (距離) を求める。

計算量

- 前計算: $O(N \log N)$
- クエリ (`get_lca`, `get_dist`): $O(\log N)$

使い方

// Example: build a weighted tree with 5 vertices, root it at 0, and query it.
int N = 5;
Graph G(N);
// Edges (cost in parentheses): 0-1(10), 1-2(20), 1-3(30), 3-4(40)
G.add_both(0, 1, 10);
G.add_both(1, 2, 20);
G.add_both(1, 3, 30);
G.add_both(3, 4, 40);

LCA tree(G, 0); // rooted at vertex 0

// Distance between vertex 2 and vertex 4
// Path: 2 - (20) - 1 - (30) - 3 - (40) - 4  => total 90
long long dis = tree.get_dist(2, 4);
// The LCA is 1
int lca = tree.get_lca(2, 4);

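Depends on

- library/dp/doubling.hpp
- library/graph/base/edge.hpp
- library/graph/base/graph.hpp
- library/graph/shortest_path/bfs.hpp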

Verified with

Code

#pragma once
#include "library/graph/shortest_path/bfs.hpp"
#include "library/dp/doubling.hpp"
// One binary-lifting step: where a jump lands and what it cost.
struct Node {
    // Sentinel vertex id meaning "no vertex" (the jump ran past the root).
    static const int e = -1;
    int to = e;          // vertex reached by the jump, or e
    long long dist = 0;  // sum of edge costs along the jump
    Node() = default;
    Node(int to, long long dist) : to(to), dist(dist) {}
    // Composes this jump with a following jump A.
    // A jump that already ended at the sentinel absorbs anything after it.
    Node operator+(const Node &A) const {
        if (to != e) return Node(A.to, dist + A.dist);
        return *this;
    }
};
// Lowest common ancestor and weighted distance queries on a rooted tree,
// implemented with binary lifting (Doubling<Node>).
//
// Edges may be stored undirected (Graph::add_both) or directed from parent to
// child (Graph::add).
// Precondition (not checked): every queried vertex is reachable from `root`;
// an unreachable vertex has depth -1 and query results for it are undefined.
struct LCA {
    vector<int> depth;  // depth[v]: number of edges from the root to v
    Doubling<Node> db;  // db.table[k][v]: landing vertex / cost of 2^k steps up
    LCA(const Graph &G, int root = 0) {
        int N = G.size();
        // 1. Reuse bfs to get each vertex's depth and BFS parent.
        auto [dis, route] = bfs(G, {root});
        depth = std::move(dis);
        // 2. Build Node(parent, cost of the parent edge) for every vertex.
        //    bfs does not report edge costs, so recover them from the graph.
        //    Scan the parent's outgoing edges (the direction bfs traversed)
        //    rather than the child's: with directed parent->child edges the
        //    child has no edge back, which used to leave its parent unset.
        //    Every edge is visited once, so this stays O(N + M).
        vector<Node> next(N, Node(Node::e, 0));
        for (int p = 0; p < N; ++p) {
            for (const auto &edge : G[p]) {
                int c = edge.to;
                // Keep only the first p->c edge: with parallel edges that is
                // the one bfs used to discover c.
                if (route[c] == p && next[c].to == Node::e) {
                    next[c] = Node(p, edge.cost);
                }
            }
        }
        // 3. Build the doubling table; depths never exceed N - 1 steps.
        db = Doubling<Node>(next, N);
    }
    // Returns the lowest common ancestor of u and v.
    int get_lca(int u, int v) const {
        if (depth[u] > depth[v]) swap(u, v);
        // Lift the deeper vertex to u's depth.
        v = db.query(v, depth[v] - depth[u]).to;
        if (u == v) return u;
        for (int k = db.log - 1; k >= 0; --k) {
            // Jumps past the root land on Node::e for both vertices, compare
            // equal, and are skipped, so we never overshoot the LCA.
            if (db.table[k][u].to != db.table[k][v].to) {
                u = db.table[k][u].to;
                v = db.table[k][v].to;
            }
        }
        return db.table[0][u].to;
    }
    // Returns the sum of edge costs on the path between u and v.
    long long get_dist(int u, int v) const {
        int lca = get_lca(u, v);
        return db.query(u, depth[u] - depth[lca]).dist +
               db.query(v, depth[v] - depth[lca]).dist;
    }
};
#line 2 "library/graph/base/edge.hpp"
// One edge of a Graph: from -> to with weight `cost`.
// `idx` is the id assigned by Graph (-1 when unassigned). The member order
// matters: bfs unpacks edges with `auto &&[from, to, cost, idx]`.
struct Edge {
    int from;
    int to;
    long long cost;
    int idx;
    Edge(int src, int dst, long long weight = 1, int id = -1)
        : from(src), to(dst), cost(weight), idx(id) {}
};
#line 3 "library/graph/base/graph.hpp"
// Adjacency-list graph: G[v] holds the edges leaving v.
// Every added edge receives an id (Edge::idx). Both directions of an
// undirected edge share one id.
struct Graph {
    // In-class initializers: `Graph() = default;` used to leave these
    // indeterminate, so size() on a default-constructed graph read garbage.
    int N = 0;               // number of vertices
    vector<vector<Edge>> G;  // adjacency lists
    int es = 0;              // number of edge ids handed out so far
    Graph() = default;
    Graph(int N) : N(N), G(N), es(0) {}
    const vector<Edge> &operator[](int v) const { return G[v]; }
    int size() const { return N; }
    // Adds the directed edge from -> to with its own id.
    void add(int from, int to, long long cost = 1) {
        G[from].push_back(Edge(from, to, cost, es++));
    }
    // Adds from -> to and to -> from, both carrying the same id.
    void add_both(int from, int to, long long cost = 1) {
        G[from].push_back(Edge(from, to, cost, es));
        G[to].push_back(Edge(to, from, cost, es++));
    }
    // Reads M edges from cin, each as "u v" (plus "cost" when weighted).
    // `padding` is added to both endpoints; the default -1 converts 1-indexed
    // input to 0-indexed vertices. `directed` selects add vs add_both.
    void read(int M, int padding = -1, bool weighted = false,
              bool directed = false) {
        for (int i = 0; i < M; i++) {
            int u, v;
            cin >> u >> v;
            u += padding, v += padding;
            long long cost = 1ll;
            if (weighted) cin >> cost;
            if (directed) {
                add(u, v, cost);
            } else {
                add_both(u, v, cost);
            }
        }
    }
};
#line 3 "library/graph/shortest_path/bfs.hpp"
pair<vector<int>, vector<int>> bfs(const Graph &G,
                                   const vector<int> &starts = {0}) {
    int N = G.size();
    queue<int> q;
    vector<int> dis(N, -1), route(N, -1);
    for (auto &&v : starts) q.push(v), dis[v] = 0;
    while (!q.empty()) {
        int v = q.front();
        q.pop();
        for (auto &&[from, to, cost, idx] : G[v]) {
            if (~dis[to]) continue;
            dis[to] = dis[from] + 1;
            q.push(to);
            route[to] = v;
        }
    }
    return {dis, route};
}
#line 2 "library/dp/doubling.hpp"
// Binary lifting ("doubling") over a functional graph.
//
// T must provide a static `T::e` meaning "no successor", a member `int to`
// (successor vertex or T::e), a default constructor, and an `operator+` where
// `a + b` means "take jump a, then jump b".
//
// table[k][v] is the result of taking 2^k steps from v.
template <typename T> struct Doubling {
    int N = 0, log = 0;
    vector<vector<T>> table;
    Doubling() {}
    // next[v]: the single-step transition from v.
    // max_steps: the largest step count query() must support.
    Doubling(const vector<T> &next, long long max_steps) {
        N = next.size();
        // Smallest log with 2^log > max_steps. Stopping at 63 keeps the shift
        // from reaching the sign bit of long long, which is UB.
        while (log < 63 && (1ll << log) <= max_steps) ++log;
        // Always keep level 0: with max_steps < 1 the table used to be empty
        // and `table[0] = next` wrote out of bounds.
        if (log == 0) log = 1;
        table.assign(log, vector<T>(N, T()));
        table[0] = next;
        for (int k = 0; k + 1 < log; ++k) {
            for (int v = 0; v < N; ++v) {
                const T &half = table[k][v];
                if (half.to == T::e) {
                    // Already ran off the end; stay there.
                    table[k + 1][v] = half;
                } else {
                    table[k + 1][v] = half + table[k][half.to];
                }
            }
        }
    }
    // Returns the result of taking `steps` steps from v (`.to == T::e` once
    // the walk runs off the end). Only bits 0..log-1 of steps are read, so
    // steps must be below 2^log.
    T query(int v, long long steps) const {
        T res;
        res.to = v;
        for (int k = 0; k < log; ++k) {
            if ((steps >> k) & 1) {
                if (res.to == T::e) break;
                res = res + table[k][res.to];
            }
        }
        return res;
    }
};
#line 4 "library/graph/tree/lca.hpp"
// One binary-lifting step: where a jump lands and what it cost.
struct Node {
    // Sentinel vertex id meaning "no vertex" (the jump ran past the root).
    static const int e = -1;
    int to = e;          // vertex reached by the jump, or e
    long long dist = 0;  // sum of edge costs along the jump
    Node() = default;
    Node(int to, long long dist) : to(to), dist(dist) {}
    // Composes this jump with a following jump A.
    // A jump that already ended at the sentinel absorbs anything after it.
    Node operator+(const Node &A) const {
        if (to != e) return Node(A.to, dist + A.dist);
        return *this;
    }
};
// Lowest common ancestor and weighted distance queries on a rooted tree,
// implemented with binary lifting (Doubling<Node>).
//
// Edges may be stored undirected (Graph::add_both) or directed from parent to
// child (Graph::add).
// Precondition (not checked): every queried vertex is reachable from `root`;
// an unreachable vertex has depth -1 and query results for it are undefined.
struct LCA {
    vector<int> depth;  // depth[v]: number of edges from the root to v
    Doubling<Node> db;  // db.table[k][v]: landing vertex / cost of 2^k steps up
    LCA(const Graph &G, int root = 0) {
        int N = G.size();
        // 1. Reuse bfs to get each vertex's depth and BFS parent.
        auto [dis, route] = bfs(G, {root});
        depth = std::move(dis);
        // 2. Build Node(parent, cost of the parent edge) for every vertex.
        //    bfs does not report edge costs, so recover them from the graph.
        //    Scan the parent's outgoing edges (the direction bfs traversed)
        //    rather than the child's: with directed parent->child edges the
        //    child has no edge back, which used to leave its parent unset.
        //    Every edge is visited once, so this stays O(N + M).
        vector<Node> next(N, Node(Node::e, 0));
        for (int p = 0; p < N; ++p) {
            for (const auto &edge : G[p]) {
                int c = edge.to;
                // Keep only the first p->c edge: with parallel edges that is
                // the one bfs used to discover c.
                if (route[c] == p && next[c].to == Node::e) {
                    next[c] = Node(p, edge.cost);
                }
            }
        }
        // 3. Build the doubling table; depths never exceed N - 1 steps.
        db = Doubling<Node>(next, N);
    }
    // Returns the lowest common ancestor of u and v.
    int get_lca(int u, int v) const {
        if (depth[u] > depth[v]) swap(u, v);
        // Lift the deeper vertex to u's depth.
        v = db.query(v, depth[v] - depth[u]).to;
        if (u == v) return u;
        for (int k = db.log - 1; k >= 0; --k) {
            // Jumps past the root land on Node::e for both vertices, compare
            // equal, and are skipped, so we never overshoot the LCA.
            if (db.table[k][u].to != db.table[k][v].to) {
                u = db.table[k][u].to;
                v = db.table[k][v].to;
            }
        }
        return db.table[0][u].to;
    }
    // Returns the sum of edge costs on the path between u and v.
    long long get_dist(int u, int v) const {
        int lca = get_lca(u, v);
        return db.query(u, depth[u] - depth[lca]).dist +
               db.query(v, depth[v] - depth[lca]).dist;
    }
};
Back to top page