Segment Tree Beats

このデータ構造の実装には、 yaketake08さんの解説 を大変参考にさせていただきました。

    
/*
<Segment Tree Beats>
様々なクエリに対応した万能遅延セグメント木(0-indexed)
tree.update_max(l,r,x):区間[l,r]内のiについてa[i]をmax(a[i],x)に更新
tree.update_min(l,r,x):区間[l,r]内のiについてa[i]をmin(a[i],x)に更新
tree.query_add(l,r,x):区間[l,r]内のiについてa[i]にxを加算
tree.query_val(l,r,x):区間[l,r]内のiについてa[i]をxに更新
tree.query_max(l,r):区間[l,r]に対する最大値クエリ
tree.query_min(l,r):区間[l,r]に対する最小値クエリ
tree.query_sum(l,r):区間[l,r]に対する区間和クエリ
*/
template <typename T>
struct Segment_Tree_Beats
{
private:
    const T INF = numeric_limits<T>::max();
    int n, n0;
    vector<T> fmax;   //最大値
    vector<T> smax;   //2番目に大きい値
    vector<T> fmax_c; //最大値の個数
    vector<T> fmin;   //最小値
    vector<T> smin;   //2番目に小さい値
    vector<T> fmin_c; //最小値の個数
    vector<T> sum;    //区間和
    vector<T> len;    //区間の長さ
    vector<T> ladd;   //遅延加算
    vector<T> lval;   //遅延更新
    void update_node_max(int k, T x)
    {
        sum[k] += (x - fmax[k]) * fmax_c[k];
        if (fmax[k] == fmin[k])
        {
            fmax[k] = x;
            fmin[k] = x;
        }
        else
        {
            if (fmax[k] == smin[k])
            {
                fmax[k] = x;
                smin[k] = x;
            }
            else
                fmax[k] = x;
        }
        //遅延処理
        if (lval[k] != INF && x < lval[k])
            lval[k] = x;
    }
    void update_node_min(int k, T x)
    {
        sum[k] += (x - fmin[k]) * fmin_c[k];
        if (fmin[k] == fmax[k])
        {
            fmin[k] = x;
            fmax[k] = x;
        }
        else
        {
            if (fmin[k] == smax[k])
            {
                fmin[k] = x;
                smax[k] = x;
            }
            else
                fmin[k] = x;
        }
        //遅延処理
        if (lval[k] != INF && lval[k] < x)
            lval[k] = x;
    }
    //全体加算
    void addall(int k, T x)
    {
        fmax[k] += x;
        fmin[k] += x;
        if (smax[k] != -INF)
            smax[k] += x;
        if (smin[k] != INF)
            smin[k] += x;
        sum[k] += len[k] * x;
        if (lval[k] != INF)
            lval[k] += x;
        else
            ladd[k] += x;
    }
    //全体更新
    void updateall(int k, T x)
    {
        fmax[k] = x;
        fmin[k] = x;
        smax[k] = -INF;
        smin[k] = INF;
        fmax_c[k] = len[k];
        fmin_c[k] = len[k];
        sum[k] = x * len[k];
        lval[k] = x;
        ladd[k] = 0;
    }
    //親->子への伝播
    void push(int k)
    {
        if ((k - n0 + 1) >= 0)
            return;
        //遅延伝播
        if (lval[k] != INF)
        {
            updateall(2 * k + 1, lval[k]);
            updateall(2 * k + 2, lval[k]);
            lval[k] = INF;
            return;
        }
        if (ladd[k] != 0)
        {
            addall(2 * k + 1, ladd[k]);
            addall(2 * k + 2, ladd[k]);
            ladd[k] = 0;
        }
        for (int i = 1; i <= 2; i++)
        {
            if (fmax[k] < fmax[2 * k + i])
                update_node_max(2 * k + i, fmax[k]);
            if (fmin[k] > fmin[2 * k + i])
                update_node_min(2 * k + i, fmin[k]);
        }
    }
    //子->親への伝播
    void update(int k)
    {
        sum[k] = sum[2 * k + 1] + sum[2 * k + 2];
        //最大値関連の更新(大きい方を優先)
        if (fmax[2 * k + 1] < fmax[2 * k + 2])
        {
            fmax[k] = fmax[2 * k + 2];
            fmax_c[k] = fmax_c[2 * k + 2];
            smax[k] = max(fmax[2 * k + 1], smax[2 * k + 2]);
        }
        if (fmax[2 * k + 1] > fmax[2 * k + 2])
        {
            fmax[k] = fmax[2 * k + 1];
            fmax_c[k] = fmax_c[2 * k + 1];
            smax[k] = max(fmax[2 * k + 2], smax[2 * k + 1]);
        }
        if (fmax[2 * k + 1] == fmax[2 * k + 2])
        {
            fmax[k] = fmax[2 * k + 1];
            fmax_c[k] = fmax_c[2 * k + 1] + fmax_c[2 * k + 2];
            smax[k] = max(smax[2 * k + 1], smax[2 * k + 2]);
        }
        //最小値関連の更新(小さい方を優先)
        if (fmin[2 * k + 1] < fmin[2 * k + 2])
        {
            fmin[k] = fmin[2 * k + 1];
            fmin_c[k] = fmin_c[2 * k + 1];
            smin[k] = min(fmin[2 * k + 2], smin[2 * k + 1]);
        }
        if (fmin[2 * k + 1] > fmin[2 * k + 2])
        {
            fmin[k] = fmin[2 * k + 2];
            fmin_c[k] = fmin_c[2 * k + 2];
            smin[k] = min(fmin[2 * k + 1], smin[2 * k + 2]);
        }
        if (fmin[2 * k + 1] == fmin[2 * k + 2])
        {
            fmin[k] = fmin[2 * k + 1];
            fmin_c[k] = fmin_c[2 * k + 1] + fmin_c[2 * k + 2];
            smin[k] = max(smin[2 * k + 1], smin[2 * k + 2]);
        }
    }
    //区間[a,b)内のiについてa[i]をmax(a[i],x)に更新
    void _update_max(T x, int a, int b, int k, int l, int r)
    {
        //区間をはみ出すorxが最小値以下
        if (b <= l || r <= a || fmin[k] >= x)
            return;
        //(最小値)<x<(2番目に小さい値)のとき
        if (l >= a && r <= b && smin[k] > x)
        {
            update_node_min(k, x);
            return;
        }
        //親->子->親
        push(k);
        _update_max(x, a, b, 2 * k + 1, l, (l + r) / 2);
        _update_max(x, a, b, 2 * k + 2, (l + r) / 2, r);
        update(k);
    }
    //区間[a,b)内のiについてa[i]をmin(a[i],x)に更新
    void _update_min(T x, int a, int b, int k, int l, int r)
    {
        //区間をはみ出すorxが最大値以上
        if (b <= l || r <= a || fmax[k] <= x)
            return;
        //(2番目に大きい値)<x<(最大値)のとき
        if (l >= a && r <= b && smax[k] < x)
        {
            update_node_max(k, x);
            return;
        }
        //親->子->親
        push(k);
        _update_min(x, a, b, 2 * k + 1, l, (l + r) / 2);
        _update_min(x, a, b, 2 * k + 2, (l + r) / 2, r);
        update(k);
    }
    //区間[a,b)に対する最大値クエリ
    T _query_max(int a, int b, int k, int l, int r)
    {
        if (b <= l || a >= r)
            return -INF; //区間外
        if (a <= l && r <= b)
            return fmax[k];
        push(k);
        T L = _query_max(a, b, 2 * k + 1, l, (l + r) / 2);
        T R = _query_max(a, b, 2 * k + 2, (l + r) / 2, r);
        return max(L, R);
    }
    //区間[a,b)に対する最小値クエリ
    T _query_min(int a, int b, int k, int l, int r)
    {
        if (b <= l || a >= r)
            return INF; //区間外
        if (a <= l && r <= b)
            return fmin[k];
        push(k);
        T L = _query_min(a, b, 2 * k + 1, l, (l + r) / 2);
        T R = _query_min(a, b, 2 * k + 2, (l + r) / 2, r);
        return min(L, R);
    }
    //区間[a,b)に対する区間和クエリ
    T _query_sum(int a, int b, int k, int l, int r)
    {
        if (b <= l || a >= r)
            return 0; //区間外
        if (a <= l && r <= b)
            return sum[k];
        push(k);
        T L = _query_sum(a, b, 2 * k + 1, l, (l + r) / 2);
        T R = _query_sum(a, b, 2 * k + 2, (l + r) / 2, r);
        return (L + R);
    }
    //区間[a,b)に対する区間加算クエリ
    void _query_add(T x, int a, int b, int k, int l, int r)
    {
        if (b <= l || a >= r)
            return; //区間外
        if (a <= l && r <= b)
        {
            addall(k, x);
            return;
        }
        push(k);
        _query_add(x, a, b, 2 * k + 1, l, (l + r) / 2);
        _query_add(x, a, b, 2 * k + 2, (l + r) / 2, r);
        update(k);
    }
    //区間[a,b)に対する区間更新クエリ
    void _query_val(T x, int a, int b, int k, int l, int r)
    {
        if (b <= l || a >= r)
            return; //区間外
        if (a <= l && r <= b)
        {
            updateall(k, x);
            return;
        }
        push(k);
        _query_val(x, a, b, 2 * k + 1, l, (l + r) / 2);
        _query_val(x, a, b, 2 * k + 2, (l + r) / 2, r);
        update(k);
    }

public:
    Segment_Tree_Beats(int N) : n(N), fmax(4 * N), smax(4 * N), fmax_c(4 * N), fmin(4 * N), smin(4 * N),
                                fmin_c(4 * N), sum(4 * N), len(4 * N), ladd(4 * N), lval(4 * N)
    {
        n0 = 1;
        while (n0 < n)
            n0 *= 2;
        len[0] = n0;
        for (int i = 0; i < 2 * n0; i++)
        {
            lval[i] = INF;
        }
        for (int i = 0; i < n0 - 1; i++)
        {
            len[2 * i + 1] = len[i] / 2;
            len[2 * i + 2] = len[i] / 2;
        }

        for (int i = 0; i < n; i++)
        {
            int nn = n0 - 1 + i;
            smax[nn] = -INF;
            smin[nn] = INF;
            fmax_c[nn] = 1;
            fmin_c[nn] = 1;
        }
        for (int i = n; i < n0; i++)
        {
            int nn = n0 - 1 + i;
            fmax[nn] = -INF;
            smax[nn] = -INF;
            fmin[nn] = INF;
            smin[nn] = INF;
        }
        for (int i = n0 - 2; i >= 0; i--)
        {
            update(i);
        }
    }
    Segment_Tree_Beats(int N, vector<T> &A) : n(N), fmax(4 * N), smax(4 * N), fmax_c(4 * N), fmin(4 * N), smin(4 * N),
                                              fmin_c(4 * N), sum(4 * N), len(4 * N), ladd(4 * N), lval(4 * N)
    {
        n0 = 1;
        while (n0 < n)
            n0 *= 2;
        len[0] = n0;
        for (int i = 0; i < 2 * n0; i++)
        {
            lval[i] = INF;
        }
        for (int i = 0; i < n0 - 1; i++)
        {
            len[2 * i + 1] = len[i] / 2;
            len[2 * i + 2] = len[i] / 2;
        }

        for (int i = 0; i < n; i++)
        {
            int nn = n0 - 1 + i;
            fmax[nn] = A[i];
            fmin[nn] = A[i];
            sum[nn] = A[i];
            smax[nn] = -INF;
            smin[nn] = INF;
            fmax_c[nn] = 1;
            fmin_c[nn] = 1;
        }
        for (int i = n; i < n0; i++)
        {
            int nn = n0 - 1 + i;
            fmax[nn] = -INF;
            smax[nn] = -INF;
            fmin[nn] = INF;
            smin[nn] = INF;
        }
        for (int i = n0 - 2; i >= 0; i--)
        {
            update(i);
        }
    }
    //区間[l,r]内のiについてa[i]をmax(a[i],x)に更新
    void update_max(int l, int r, T x)
    {
        _update_max(x, l, r + 1, 0, 0, n0);
    }
    //区間[l,r]内のiについてa[i]をmin(a[i],x)に更新
    void update_min(int l, int r, T x)
    {
        _update_min(x, l, r + 1, 0, 0, n0);
    }
    //区間[l,r]内のiについてa[i]にxを加算
    void query_add(int l, int r, T x)
    {
        _query_add(x, l, r + 1, 0, 0, n0);
    }
    //区間[l,r]内のiについてa[i]をxに更新
    void query_val(int l, int r, T x)
    {
        _query_val(x, l, r + 1, 0, 0, n0);
    }
    //区間[l,r]に対する最大値クエリ
    T query_max(int l, int r)
    {
        return _query_max(l, r + 1, 0, 0, n0);
    }
    //区間[l,r]に対する最小値クエリ
    T query_min(int l, int r)
    {
        return _query_min(l, r + 1, 0, 0, n0);
    }
    //区間[l,r]に対する区間和クエリ
    T query_sum(int l, int r)
    {
        return _query_sum(l, r + 1, 0, 0, n0);
    }
};
    
    
© 2020 kacho65535