serna37's Library

Logo

C++ アルゴリズムとデータ構造のライブラリ

View the Project on GitHub serna37/library-cpp

✔ モジュロ演算
(library/number/mod/montgomery_mod_int.hpp)

モジュロ演算

できること

Montgomery 乗算を用いた、奇数の法 MOD (MOD < 2^30) 上の剰余演算 (加減乗除・累乗・比較・標準入出力)。

計算量

加算・減算・乗算・比較: $O(1)$、累乗 `pow(n)`: $O(\log n)$、除算・逆元 `inv()`: $O(\log \mathrm{MOD})$

使い方

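// 1. 宣言と初期化。法 MOD には任意の奇数 (MOD < 2^30) を指定できる。よく使う 2 つはあらかじめ定義済み
//    (割り算・逆元 inv を使う場合は MOD を素数にすること)
modint998244353 a = 10;
modint998244353 b = 20;       // 四則演算は同じ法の型どうしでのみ行える
modint1000000007 d = 30;      // 法 1000000007 版
using mint = modint<MOD>;     // MOD はコンパイル時定数
// using mint = modint<3457>; // 定数を直接書いてもよい (同じスコープで mint を二重に定義はできない)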

// 2. 標準入出力に対応。入力した値は自動的に mod が取られ、出力時は自動的に val() の値が出力される
modint998244353 c;
cin >> c;
cout << c << endl;

// 3. 四則演算
modint998244353 res_add = a + b;
modint998244353 res_mul = a * b;
modint998244353 res_div = a / b; // 内部で逆元(inv)を掛けています

// 4. 累乗 (pow) は内部で、二分累乗を行う
modint998244353 res_pow = a.pow(100); // a の 100 乗

Required by

Verified with

Code

#pragma once
/**
 * Integer modulo a compile-time odd modulus `mod_` (< 2^30), implemented with
 * Montgomery multiplication using R = 2^32.
 *
 * Representation: `x` holds a * R mod `mod_` in the lazily reduced range
 * [0, 2 * mod_); `val()` converts back to the canonical residue in [0, mod_).
 *
 * @tparam mod_ odd modulus below 2^30 (checked by static_assert).
 * @tparam fast when true, the integer constructor skips the `% mod` step, so
 *              the caller must pass a value with 0 <= a < mod_.
 */
template <uint32_t mod_, bool fast = false> struct MontgomeryModInt {
  private:
    using mint = MontgomeryModInt;
    using i32 = int32_t;
    using i64 = int64_t;
    using u32 = uint32_t;
    using u64 = uint64_t;
    // mod_^{-1} mod 2^32 by Newton iteration; each step doubles the number of
    // correct low bits, so 4 steps are enough for 32 bits.
    static constexpr u32 get_r() {
        u32 ret = mod_;
        for (i32 i = 0; i < 4; i++) ret *= 2 - mod_ * ret;
        return ret;
    }
    static constexpr u32 r = get_r();
    // 2^64 mod mod_ == R^2 mod mod_; reduce(a * n2) yields a in Montgomery form.
    static constexpr u32 n2 = -u64(mod_) % mod_;
    static_assert(r * mod_ == 1, "invalid, r * mod != 1");
    static_assert(mod_ < (1 << 30), "invalid, mod >= 2 ^ 30");
    static_assert((mod_ & 1) == 1, "invalid, mod % 2 == 0");
    u32 x; // Montgomery form, kept in [0, 2 * mod_)

  public:
    MontgomeryModInt() : x{} {}
    /// Converts an integer; negative values are allowed unless `fast` is set.
    MontgomeryModInt(const i64 &a)
        : x(reduce(u64(fast ? a : (a % mod() + mod())) * n2)) {}
    /// Montgomery reduction: returns a value congruent to b * R^{-1} in
    /// (0, 2 * mod_], valid for b < mod_ * 2^32.
    static constexpr u32 reduce(const u64 &b) {
        return u32(b >> 32) + mod() - u32((u64(u32(b) * r) * mod()) >> 32);
    }
    mint &operator+=(const mint &p) {
        // Subtract 2 * mod_ up front; if the wrapped result is negative when
        // viewed as signed, add it back so x stays in [0, 2 * mod_).
        if (i32(x += p.x - 2 * mod()) < 0) x += 2 * mod();
        return *this;
    }
    mint &operator-=(const mint &p) {
        if (i32(x -= p.x) < 0) x += 2 * mod();
        return *this;
    }
    mint &operator*=(const mint &p) {
        x = reduce(u64(x) * p.x);
        return *this;
    }
    mint &operator/=(const mint &p) {
        *this *= p.inv();
        return *this;
    }
    mint operator-() const { return mint() - *this; }
    mint operator+(const mint &p) const { return mint(*this) += p; }
    mint operator-(const mint &p) const { return mint(*this) -= p; }
    mint operator*(const mint &p) const { return mint(*this) *= p; }
    mint operator/(const mint &p) const { return mint(*this) /= p; }
    /// Compares residues; x may carry an extra mod_ so both sides are
    /// normalised into [0, mod_) first.
    bool operator==(const mint &p) const {
        return (x >= mod() ? x - mod() : x) ==
               (p.x >= mod() ? p.x - mod() : p.x);
    }
    bool operator!=(const mint &p) const { return !(*this == p); }
    /// Canonical value in [0, mod_).
    u32 val() const {
        u32 ret = reduce(x);
        return ret >= mod() ? ret - mod() : ret;
    }
    /// this^n by binary exponentiation, O(log n).
    mint pow(u64 n) const {
        mint ret(1), mul(*this);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }
    /// Multiplicative inverse via the extended Euclidean algorithm,
    /// O(log mod_). Unlike pow(mod_ - 2) this does not require mod_ to be
    /// prime: it is correct whenever gcd(val(), mod_) == 1. When no inverse
    /// exists the result is meaningless (a zero value yields 0).
    mint inv() const {
        i64 a = val(), b = mod(), u = 1, v = 0;
        while (b != 0) {
            const i64 t = a / b;
            const i64 next_a = a - t * b;
            a = b;
            b = next_a;
            const i64 next_u = u - t * v;
            u = v;
            v = next_u;
        }
        // Normalise to [0, mod_) so the constructor is safe even when `fast`.
        if (u < 0) u += mod();
        return mint(u);
    }
    friend std::ostream &operator<<(std::ostream &os, const mint &p) {
        return os << p.val();
    }
    /// Reads a signed 64-bit integer and reduces it. The target is assigned
    /// only when extraction succeeds, so a failed read never uses an
    /// indeterminate value and leaves `a` unchanged.
    friend std::istream &operator>>(std::istream &is, mint &a) {
        i64 t = 0;
        if (is >> t) a = mint(t);
        return is;
    }
    static constexpr u32 mod() { return mod_; }
};
// Convenience alias: the default (non-`fast`) Montgomery modint for modulus M.
template <uint32_t M> using modint = MontgomeryModInt<M, false>;
// Predefined types for the two most frequently used prime moduli.
using modint998244353 = modint<998244353U>;
using modint1000000007 = modint<1000000007U>;
#line 2 "library/number/mod/montgomery_mod_int.hpp"
/**
 * Integer modulo a compile-time odd modulus `mod_` (< 2^30), implemented with
 * Montgomery multiplication using R = 2^32.
 *
 * Representation: `x` holds a * R mod `mod_` in the lazily reduced range
 * [0, 2 * mod_); `val()` converts back to the canonical residue in [0, mod_).
 *
 * @tparam mod_ odd modulus below 2^30 (checked by static_assert).
 * @tparam fast when true, the integer constructor skips the `% mod` step, so
 *              the caller must pass a value with 0 <= a < mod_.
 */
template <uint32_t mod_, bool fast = false> struct MontgomeryModInt {
  private:
    using mint = MontgomeryModInt;
    using i32 = int32_t;
    using i64 = int64_t;
    using u32 = uint32_t;
    using u64 = uint64_t;
    // mod_^{-1} mod 2^32 by Newton iteration; each step doubles the number of
    // correct low bits, so 4 steps are enough for 32 bits.
    static constexpr u32 get_r() {
        u32 ret = mod_;
        for (i32 i = 0; i < 4; i++) ret *= 2 - mod_ * ret;
        return ret;
    }
    static constexpr u32 r = get_r();
    // 2^64 mod mod_ == R^2 mod mod_; reduce(a * n2) yields a in Montgomery form.
    static constexpr u32 n2 = -u64(mod_) % mod_;
    static_assert(r * mod_ == 1, "invalid, r * mod != 1");
    static_assert(mod_ < (1 << 30), "invalid, mod >= 2 ^ 30");
    static_assert((mod_ & 1) == 1, "invalid, mod % 2 == 0");
    u32 x; // Montgomery form, kept in [0, 2 * mod_)

  public:
    MontgomeryModInt() : x{} {}
    /// Converts an integer; negative values are allowed unless `fast` is set.
    MontgomeryModInt(const i64 &a)
        : x(reduce(u64(fast ? a : (a % mod() + mod())) * n2)) {}
    /// Montgomery reduction: returns a value congruent to b * R^{-1} in
    /// (0, 2 * mod_], valid for b < mod_ * 2^32.
    static constexpr u32 reduce(const u64 &b) {
        return u32(b >> 32) + mod() - u32((u64(u32(b) * r) * mod()) >> 32);
    }
    mint &operator+=(const mint &p) {
        // Subtract 2 * mod_ up front; if the wrapped result is negative when
        // viewed as signed, add it back so x stays in [0, 2 * mod_).
        if (i32(x += p.x - 2 * mod()) < 0) x += 2 * mod();
        return *this;
    }
    mint &operator-=(const mint &p) {
        if (i32(x -= p.x) < 0) x += 2 * mod();
        return *this;
    }
    mint &operator*=(const mint &p) {
        x = reduce(u64(x) * p.x);
        return *this;
    }
    mint &operator/=(const mint &p) {
        *this *= p.inv();
        return *this;
    }
    mint operator-() const { return mint() - *this; }
    mint operator+(const mint &p) const { return mint(*this) += p; }
    mint operator-(const mint &p) const { return mint(*this) -= p; }
    mint operator*(const mint &p) const { return mint(*this) *= p; }
    mint operator/(const mint &p) const { return mint(*this) /= p; }
    /// Compares residues; x may carry an extra mod_ so both sides are
    /// normalised into [0, mod_) first.
    bool operator==(const mint &p) const {
        return (x >= mod() ? x - mod() : x) ==
               (p.x >= mod() ? p.x - mod() : p.x);
    }
    bool operator!=(const mint &p) const { return !(*this == p); }
    /// Canonical value in [0, mod_).
    u32 val() const {
        u32 ret = reduce(x);
        return ret >= mod() ? ret - mod() : ret;
    }
    /// this^n by binary exponentiation, O(log n).
    mint pow(u64 n) const {
        mint ret(1), mul(*this);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }
    /// Multiplicative inverse via the extended Euclidean algorithm,
    /// O(log mod_). Unlike pow(mod_ - 2) this does not require mod_ to be
    /// prime: it is correct whenever gcd(val(), mod_) == 1. When no inverse
    /// exists the result is meaningless (a zero value yields 0).
    mint inv() const {
        i64 a = val(), b = mod(), u = 1, v = 0;
        while (b != 0) {
            const i64 t = a / b;
            const i64 next_a = a - t * b;
            a = b;
            b = next_a;
            const i64 next_u = u - t * v;
            u = v;
            v = next_u;
        }
        // Normalise to [0, mod_) so the constructor is safe even when `fast`.
        if (u < 0) u += mod();
        return mint(u);
    }
    friend std::ostream &operator<<(std::ostream &os, const mint &p) {
        return os << p.val();
    }
    /// Reads a signed 64-bit integer and reduces it. The target is assigned
    /// only when extraction succeeds, so a failed read never uses an
    /// indeterminate value and leaves `a` unchanged.
    friend std::istream &operator>>(std::istream &is, mint &a) {
        i64 t = 0;
        if (is >> t) a = mint(t);
        return is;
    }
    static constexpr u32 mod() { return mod_; }
};
// Convenience alias: the default (non-`fast`) Montgomery modint for modulus M.
template <uint32_t M> using modint = MontgomeryModInt<M, false>;
// Predefined types for the two most frequently used prime moduli.
using modint998244353 = modint<998244353U>;
using modint1000000007 = modint<1000000007U>;
Back to top page