抽象化遅延Segment Tree

このデータ構造の実装には、634kamiさんの解説を大変参考にさせていただきました。

    
// Generic lazy-propagation segment tree.
//
// Template parameters:
//   X - type of the leaf values and of the aggregates stored in each node.
//   M - type of a pending ("lazy") operation attached to a node.
//
// Behaviour is injected through the constructor:
//   fx(X, X) -> X : combines the aggregates of a node's left and right child.
//   fa(X, M) -> X : applies a pending operation to a node's aggregate. It is
//                   not given the width of the node's segment.
//   fm(M, M) -> M : merges two operations; always called as
//                   fm(already_pending, newly_arrived).
//   ex            : value returned for parts of the tree outside a query; it
//                   is also the initial value of every node.
//   em            : marker for "no pending operation"; M must support
//                   operator== so a node can be compared against it.
//
// The public update()/query() take CLOSED ranges [l, r] (both ends included).
template <typename X, typename M>
struct lazy_segment_tree
{
private:
    using FX = std::function<X(X, X)>;
    using FA = std::function<X(X, M)>;
    using FM = std::function<M(M, M)>;
    int n; // number of leaves, rounded up to a power of two
    FX fx;
    FA fa;
    FM fm;
    // Intentionally not const: const data members would implicitly delete the
    // copy/move assignment operators and make the tree impossible to reassign.
    X ex;
    M em;
    std::vector<X> node; // heap layout: children of k are 2k+1 and 2k+2
    std::vector<M> lazy; // pending operation per node, em when there is none

    // Resolves the pending operation of node k: forwards it to both children
    // (when k has children), applies it to node[k], then clears it.
    void eval(int k)
    {
        if (lazy[k] == em)
            return;
        // Indices 0 .. n-2 are internal nodes; n-1 .. 2n-2 are leaves.
        if (k <= n - 2)
        {
            lazy[2 * k + 1] = fm(lazy[2 * k + 1], lazy[k]);
            lazy[2 * k + 2] = fm(lazy[2 * k + 2], lazy[k]);
        }
        node[k] = fa(node[k], lazy[k]);
        lazy[k] = em;
    }

    // Applies x to the half-open range [a, b); node k covers [l, r).
    void _update(int a, int b, const M &x, int k, int l, int r)
    {
        // Evaluate before the range test so that the parent can safely
        // recombine both children afterwards, even an untouched one.
        eval(k);
        if (b <= l || r <= a)
            return;
        if (a <= l && r <= b)
        {
            lazy[k] = fm(lazy[k], x);
            eval(k);
            return;
        }
        const int mid = (l + r) / 2;
        _update(a, b, x, 2 * k + 1, l, mid);
        _update(a, b, x, 2 * k + 2, mid, r);
        node[k] = fx(node[2 * k + 1], node[2 * k + 2]);
    }

    // Returns the aggregate of the half-open range [a, b); node k covers
    // [l, r). Parts of the tree that do not intersect [a, b) contribute ex.
    X _query(int a, int b, int k, int l, int r)
    {
        eval(k);
        if (b <= l || r <= a)
            return ex;
        if (a <= l && r <= b)
            return node[k];
        const int mid = (l + r) / 2;
        X L = _query(a, b, 2 * k + 1, l, mid);
        X R = _query(a, b, 2 * k + 2, mid, r);
        return fx(L, R);
    }

public:
    // Creates a tree able to hold at least n0 elements. Every node starts as
    // ex0 and every pending operation as em0.
    lazy_segment_tree(int n0, FX fx0, FA fa0, FM fm0, X ex0, M em0)
        : n(1), fx(fx0), fa(fa0), fm(fm0), ex(ex0), em(em0)
    {
        while (n < n0)
            n *= 2;
        node.resize(2 * n - 1, ex);
        lazy.resize(2 * n - 1, em);
    }

    // Writes x straight into leaf i (0-based). Ancestors are NOT recomputed
    // and pending operations are not consulted, so call build() afterwards.
    // i is not range-checked; it must satisfy 0 <= i < n (the padded size).
    void set(int i, X x) { node[i + (n - 1)] = x; }

    // Recomputes every internal node bottom-up from the current leaves.
    // Pending operations are left untouched.
    void build()
    {
        for (int i = n - 2; i >= 0; i--)
            node[i] = fx(node[2 * i + 1], node[2 * i + 2]);
    }

    // Copies A into leaves 0 .. A.size()-1 and rebuilds the internal nodes.
    // A.size() must not exceed the padded leaf count; this is not checked.
    void build(const std::vector<X> &A)
    {
        const int siz = static_cast<int>(A.size());
        for (int i = 0; i < siz; i++)
            set(i, A[i]);
        build();
    }

    // Applies operation x to every position in the closed range [l, r].
    void update(int l, int r, M x) { _update(l, r + 1, x, 0, 0, n); }

    // Returns the aggregate over the closed range [l, r].
    X query(int l, int r) { return _query(l, r + 1, 0, 0, n); }
};
    
    
© 2020 kacho65535