抽象化Segment Tree

このデータ構造の実装には、634kamiさんの解説及び ふるやんさんの解説 を大変参考にさせていただきました。

    
/*
抽象化したセグメント木
ex)要素数n,long long型,区間最大クエリ(fx),点加算(gx)のセグメント木が欲しい場合
auto fx=[](long long x,long long y){return max(x,y);};
auto gx=[](long long a,long long b){return a+b;};
long long id=numeric_limits<long long>::lowest();//単位元 実際の実装では十分小さい値でよい
segment_tree<long long> tree(n,id,fx,gx);で宣言 要素は単位元で初期化される

gcdを区間に作用させるときは単位元を0にするらしい
gxを点更新とする場合にはgxを省略できる
以下宣言後できること一覧(0-indexedで考える)
build(A):長さnの配列Aで要素を初期化する
update(i,x):i番目の値をxを用いて更新させる(更新,加算など)
query(l,r):区間[l,r]に対してfxを作用させた結果を返す
tree[x],tree.at(x):x番目の値を取得する
[fxを区間最大クエリにした場合]
find_rightest_max(l,r,x):区間[l,r]で最も右にあるx以上の要素の位置を返す(なければl-1)
find_leftest_max(l,r,x):区間[l,r]で最も左にあるx以上の要素の位置を返す(なければr+1)
[fxを区間最小クエリにした場合]
find_rightest_min(l,r,x):区間[l,r]で最も右にあるx以下の要素の位置を返す(なければl-1)
find_leftest_min(l,r,x):区間[l,r]で最も左にあるx以下の要素の位置を返す(なければr+1)
*/
template <typename T>
struct segment_tree
{
private:
    using FX = function<T(T, T)>;
    int n;
    T id;
    vector<T> data;
    FX fx, gx; //区間操作, 点更新操作用
    //[a,b)に対する区間操作 kは[l,r)に対するデータを保持する
    T sub_query(int a, int b, int k, int l, int r)
    {
        if (r <= a || b <= l)
            return id;
        if (a <= l && r <= b)
            return data[k];
        T L = sub_query(a, b, 2 * k + 1, l, (l + r) / 2); //左の子
        T R = sub_query(a, b, 2 * k + 2, (l + r) / 2, r); //右の子
        return fx(L, R);
    }
    int sub_find_rightest_min(T x, int a, int b, int k, int l, int r)
    {
        if (data[k] > x || r <= a || b <= l)
            return a - 1;
        if (k >= n - 1)
            return k - (n - 1);
        int R = sub_find_rightest_min(x, a, b, 2 * k + 2, (l + r) / 2, r);
        if (R != a - 1)
            return R;
        return sub_find_rightest_min(x, a, b, 2 * k + 1, l, (l + r) / 2);
    }
    int sub_find_leftest_min(T x, int a, int b, int k, int l, int r)
    {
        if (data[k] > x || r <= a || b <= l)
            return b;
        if (k >= n - 1)
            return k - (n - 1);
        int L = sub_find_leftest_min(x, a, b, 2 * k + 1, l, (l + r) / 2);
        if (L != b)
            return L;
        return sub_find_leftest_min(x, a, b, 2 * k + 2, (l + r) / 2, r);
    }
    int sub_find_rightest_max(T x, int a, int b, int k, int l, int r)
    {
        if (data[k] < x || r <= a || b <= l)
            return a - 1;
        if (k >= n - 1)
            return k - (n - 1);
        int R = sub_find_rightest_max(x, a, b, 2 * k + 2, (l + r) / 2, r);
        if (R != a - 1)
            return R;
        return sub_find_rightest_max(x, a, b, 2 * k + 1, l, (l + r) / 2);
    }
    int sub_find_leftest_max(T x, int a, int b, int k, int l, int r)
    {
        if (data[k] < x || r <= a || b <= l)
            return b;
        if (k >= n - 1)
            return k - (n - 1);
        int L = sub_find_leftest_max(x, a, b, 2 * k + 1, l, (l + r) / 2);
        if (L != b)
            return L;
        return sub_find_leftest_max(x, a, b, 2 * k + 2, (l + r) / 2, r);
    }

public:
    segment_tree(int n0, T id0, FX fx0, FX gx0) : n(1), id(id0), fx(fx0), gx(gx0)
    {
        while (n < n0)
            n *= 2;
        data.resize(2 * n - 1, id); //単位元で初期化
    }
    //1点更新クエリの場合はgxを省略できるように
    segment_tree(int n0, T id0, FX fx0) : segment_tree(n0, id0, fx0, [](T a, T b) { return b; }) {}
    //配列Aの値で初期化する
    void build(vector<T> A)
    {
        T siz = A.size();
        for (int i = 0; i < siz; i++)
            data[i + n - 1] = A[i];
        for (int i = n - 2; i >= 0; i--)
            data[i] = fx(data[2 * i + 1], data[2 * i + 2]);
    }
    void update(int i, T x)
    {
        i += n - 1;
        data[i] = gx(data[i], x);
        while (i > 0)
        {
            i = (i - 1) / 2;                                //親へ
            data[i] = fx(data[2 * i + 1], data[2 * i + 2]); //子ノードで更新
        }
    }
    T query(int l, int r) { return sub_query(l, r + 1, 0, 0, n); } //根からスタート
    int find_rightest_min(int l, int r, T x) { return sub_find_rightest_min(x, l, r + 1, 0, 0, n); }
    int find_leftest_min(int l, int r, T x) { return sub_find_leftest_min(x, l, r + 1, 0, 0, n); }
    int find_rightest_max(int l, int r, T x) { return sub_find_rightest_max(x, l, r + 1, 0, 0, n); }
    int find_leftest_max(int l, int r, T x) { return sub_find_leftest_max(x, l, r + 1, 0, 0, n); }
    //data[x]またはdata.at(x)でアクセスできるように
    T operator[](int i) { return data[i + n - 1]; }
    T at(int i) { return data[i + n - 1]; }
};
    
    
© 2020 kacho65535